Checking point after continuing to flesh out mock server
This commit is contained in:
parent
b9cd9af7fa
commit
9132a28e9c
4 changed files with 299 additions and 55 deletions
11
client.go
11
client.go
|
|
@ -2,6 +2,7 @@ package liveshare
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
|
@ -10,6 +11,7 @@ import (
|
|||
// A Client capable of joining a liveshare connection
|
||||
type Client struct {
|
||||
connection Connection
|
||||
tlsConfig *tls.Config
|
||||
|
||||
ssh *sshSession
|
||||
rpc *rpcClient
|
||||
|
|
@ -43,9 +45,16 @@ func WithConnection(connection Connection) ClientOption {
|
|||
}
|
||||
}
|
||||
|
||||
func WithTLSConfig(tlsConfig *tls.Config) ClientOption {
|
||||
return func(c *Client) error {
|
||||
c.tlsConfig = tlsConfig
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Join is a method that joins the client to the liveshare session
|
||||
func (c *Client) Join(ctx context.Context) (err error) {
|
||||
clientSocket := newSocket(c.connection)
|
||||
clientSocket := newSocket(c.connection, c.tlsConfig)
|
||||
if err := clientSocket.connect(ctx); err != nil {
|
||||
return fmt.Errorf("error connecting websocket: %v", err)
|
||||
}
|
||||
|
|
|
|||
100
client_test.go
100
client_test.go
|
|
@ -1,12 +1,14 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
livesharetest "github.com/github/go-liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestNewClient(t *testing.T) {
|
||||
|
|
@ -39,53 +41,51 @@ func TestNewClientWithInvalidConnection(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
func newMockLiveShareServer() *httptest.Server {
|
||||
endpoint := func(w http.ResponseWriter, req *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, req, nil)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
break
|
||||
}
|
||||
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
break
|
||||
}
|
||||
|
||||
}
|
||||
func TestClientJoin(t *testing.T) {
|
||||
sessionToken := "session-token"
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
return httptest.NewTLSServer(http.HandlerFunc(endpoint))
|
||||
}
|
||||
|
||||
func TestClientJoin(t *testing.T) {
|
||||
// server := newMockLiveShareServer()
|
||||
// defer server.Close()
|
||||
|
||||
// connection := Connection{
|
||||
// SessionID: "session-id",
|
||||
// SessionToken: "session-token",
|
||||
// RelaySAS: "relay-sas",
|
||||
// RelayEndpoint: "sb" + strings.TrimPrefix(server.URL, "https"),
|
||||
// }
|
||||
|
||||
// client, err := NewClient(WithConnection(connection))
|
||||
// if err != nil {
|
||||
// t.Errorf("error creating new client: %v", err)
|
||||
// }
|
||||
// ctx := context.Background()
|
||||
// if err := client.Join(ctx); err != nil {
|
||||
// t.Errorf("error joining client: %v", err)
|
||||
// }
|
||||
server, err := livesharetest.NewServer(
|
||||
livesharetest.WithPassword(sessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("error creating liveshare server: %v", err)
|
||||
}
|
||||
defer server.Close()
|
||||
|
||||
ctx := context.Background()
|
||||
connection := Connection{
|
||||
SessionID: "session-id",
|
||||
SessionToken: sessionToken,
|
||||
RelaySAS: "relay-sas",
|
||||
RelayEndpoint: "sb" + strings.TrimPrefix(server.URL(), "https"),
|
||||
}
|
||||
|
||||
tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true})
|
||||
client, err := NewClient(WithConnection(connection), tlsConfig)
|
||||
if err != nil {
|
||||
t.Errorf("error creating new client: %v", err)
|
||||
}
|
||||
|
||||
clientErr := make(chan error)
|
||||
go func() {
|
||||
if err := client.Join(ctx); err != nil {
|
||||
clientErr <- fmt.Errorf("error joining client: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
ctx.Done()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-server.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-clientErr:
|
||||
t.Errorf("error from client: %v", err)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
|
|
|
|||
17
socket.go
17
socket.go
|
|
@ -2,9 +2,11 @@ package liveshare
|
|||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -12,19 +14,26 @@ import (
|
|||
)
|
||||
|
||||
type socket struct {
|
||||
addr string
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
|
||||
conn *websocket.Conn
|
||||
readMutex sync.Mutex
|
||||
writeMutex sync.Mutex
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newSocket(clientConn Connection) *socket {
|
||||
return &socket{addr: clientConn.uri("connect")}
|
||||
func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket {
|
||||
return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig}
|
||||
}
|
||||
|
||||
func (s *socket) connect(ctx context.Context) error {
|
||||
ws, _, err := websocket.DefaultDialer.Dial(s.addr, nil)
|
||||
dialer := websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: 45 * time.Second,
|
||||
TLSClientConfig: s.tlsConfig,
|
||||
}
|
||||
ws, _, err := dialer.Dial(s.addr, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
|||
226
test/server.go
Normal file
226
test/server.go
Normal file
|
|
@ -0,0 +1,226 @@
|
|||
package livesharetest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
password string
|
||||
services map[string]RpcHandleFunc
|
||||
|
||||
sshConfig *ssh.ServerConfig
|
||||
httptestServer *httptest.Server
|
||||
errCh chan error
|
||||
}
|
||||
|
||||
func NewServer(opts ...ServerOption) (*Server, error) {
|
||||
server := new(Server)
|
||||
|
||||
for _, o := range opts {
|
||||
if err := o(server); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
server.sshConfig = &ssh.ServerConfig{
|
||||
PasswordCallback: sshPasswordCallback(server.password),
|
||||
}
|
||||
b, err := ioutil.ReadFile(filepath.Join("test", "private.key"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading private.key: %v", err)
|
||||
}
|
||||
privateKey, err := ssh.ParsePrivateKey(b)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing key: %v", err)
|
||||
}
|
||||
server.sshConfig.AddHostKey(privateKey)
|
||||
|
||||
server.errCh = make(chan error)
|
||||
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server)))
|
||||
return server, nil
|
||||
}
|
||||
|
||||
type ServerOption func(*Server) error
|
||||
|
||||
func WithPassword(password string) ServerOption {
|
||||
return func(s *Server) error {
|
||||
s.password = password
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithService(serviceName string, handler RpcHandleFunc) ServerOption {
|
||||
return func(s *Server) error {
|
||||
if s.services == nil {
|
||||
s.services = make(map[string]RpcHandleFunc)
|
||||
}
|
||||
|
||||
s.services[serviceName] = handler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
|
||||
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if string(password) == serverPassword {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.New("password rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Close() {
|
||||
s.httptestServer.Close()
|
||||
}
|
||||
|
||||
func (s *Server) URL() string {
|
||||
return s.httptestServer.URL
|
||||
}
|
||||
|
||||
func (s *Server) Err() <-chan error {
|
||||
return s.errCh
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
func newConnection(server *Server) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, req, nil)
|
||||
if err != nil {
|
||||
server.errCh <- fmt.Errorf("error upgrading connection: %v", err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
socketConn := newSocketConn(c)
|
||||
_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
|
||||
if err != nil {
|
||||
server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err)
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
for newChannel := range chans {
|
||||
ch, reqs, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
server.errCh <- fmt.Errorf("error accepting new channel: %v", err)
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
go handleNewChannel(server, ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleNewChannel(server *Server, channel ssh.Channel) {
|
||||
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
|
||||
jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server))
|
||||
}
|
||||
|
||||
type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error)
|
||||
|
||||
type rpcHandler struct {
|
||||
server *Server
|
||||
}
|
||||
|
||||
func newRpcHandler(server *Server) *rpcHandler {
|
||||
return &rpcHandler{server}
|
||||
}
|
||||
|
||||
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
|
||||
handler, found := r.server.services[req.Method]
|
||||
if !found {
|
||||
r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := handler(req)
|
||||
if err != nil {
|
||||
r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.Reply(ctx, req.ID, result); err != nil {
|
||||
r.server.errCh <- fmt.Errorf("error replying: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type socketConn struct {
|
||||
*websocket.Conn
|
||||
|
||||
reader io.Reader
|
||||
writeMutex sync.Mutex
|
||||
readMutex sync.Mutex
|
||||
}
|
||||
|
||||
func newSocketConn(conn *websocket.Conn) *socketConn {
|
||||
return &socketConn{Conn: conn}
|
||||
}
|
||||
|
||||
func (s *socketConn) Read(b []byte) (int, error) {
|
||||
s.readMutex.Lock()
|
||||
defer s.readMutex.Unlock()
|
||||
|
||||
if s.reader == nil {
|
||||
msgType, r, err := s.Conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next reader: %v", err)
|
||||
}
|
||||
if msgType != websocket.BinaryMessage {
|
||||
return 0, fmt.Errorf("invalid message type")
|
||||
}
|
||||
s.reader = r
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socketConn) Write(b []byte) (int, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
|
||||
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next writer: %v", err)
|
||||
}
|
||||
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error writing: %v", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return 0, fmt.Errorf("error closing writer: %v", err)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *socketConn) SetDeadline(deadline time.Time) error {
|
||||
if err := s.Conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Conn.SetWriteDeadline(deadline)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue