From 9132a28e9cf2a09359a80e104a558c77a0c0abea Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 23 Jul 2021 19:15:54 +0000 Subject: [PATCH] Checking point after continuing to flesh out mock server --- client.go | 11 ++- client_test.go | 100 +++++++++++----------- socket.go | 17 +++- test/server.go | 226 +++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 299 insertions(+), 55 deletions(-) create mode 100644 test/server.go diff --git a/client.go b/client.go index 435dd6775..a8a1e3864 100644 --- a/client.go +++ b/client.go @@ -2,6 +2,7 @@ package liveshare import ( "context" + "crypto/tls" "fmt" "golang.org/x/crypto/ssh" @@ -10,6 +11,7 @@ import ( // A Client capable of joining a liveshare connection type Client struct { connection Connection + tlsConfig *tls.Config ssh *sshSession rpc *rpcClient @@ -43,9 +45,16 @@ func WithConnection(connection Connection) ClientOption { } } +func WithTLSConfig(tlsConfig *tls.Config) ClientOption { + return func(c *Client) error { + c.tlsConfig = tlsConfig + return nil + } +} + // Join is a method that joins the client to the liveshare session func (c *Client) Join(ctx context.Context) (err error) { - clientSocket := newSocket(c.connection) + clientSocket := newSocket(c.connection, c.tlsConfig) if err := clientSocket.connect(ctx); err != nil { return fmt.Errorf("error connecting websocket: %v", err) } diff --git a/client_test.go b/client_test.go index 86f732637..bf77e3dce 100644 --- a/client_test.go +++ b/client_test.go @@ -1,12 +1,14 @@ package liveshare import ( + "context" + "crypto/tls" "fmt" - "net/http" - "net/http/httptest" + "strings" "testing" - "github.com/gorilla/websocket" + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" ) func TestNewClient(t *testing.T) { @@ -39,53 +41,51 @@ func TestNewClientWithInvalidConnection(t *testing.T) { } } -var upgrader = websocket.Upgrader{} - -func newMockLiveShareServer() *httptest.Server { - endpoint := func(w http.ResponseWriter, req *http.Request) { - c, err := upgrader.Upgrade(w, req, nil) - if err != nil { - fmt.Println(err) - return - } - defer c.Close() - - for { - mt, message, err := c.ReadMessage() - if err != nil { - fmt.Println(err) - break - } - - err = c.WriteMessage(mt, message) - if err != nil { - fmt.Println(err) - break - } - - } +func TestClientJoin(t *testing.T) { + sessionToken := "session-token" + joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) { + return 1, nil } - return httptest.NewTLSServer(http.HandlerFunc(endpoint)) -} - -func TestClientJoin(t *testing.T) { - // server := newMockLiveShareServer() - // defer server.Close() - - // connection := Connection{ - // SessionID: "session-id", - // SessionToken: "session-token", - // RelaySAS: "relay-sas", - // RelayEndpoint: "sb" + strings.TrimPrefix(server.URL, "https"), - // } - - // client, err := NewClient(WithConnection(connection)) - // if err != nil { - // t.Errorf("error creating new client: %v", err) - // } - // ctx := context.Background() - // if err := client.Join(ctx); err != nil { - // t.Errorf("error joining client: %v", err) - // } + server, err := livesharetest.NewServer( + livesharetest.WithPassword(sessionToken), + livesharetest.WithService("workspace.joinWorkspace", joinWorkspace), + ) + if err != nil { + t.Errorf("error creating liveshare server: %v", err) + } + defer server.Close() + + ctx := context.Background() + connection := Connection{ + SessionID: "session-id", + SessionToken: sessionToken, + RelaySAS: "relay-sas", + RelayEndpoint: "sb" + strings.TrimPrefix(server.URL(), "https"), + } + + tlsConfig := WithTLSConfig(&tls.Config{InsecureSkipVerify: true}) + client, err := NewClient(WithConnection(connection), tlsConfig) + if err != nil { + t.Errorf("error creating new client: %v", err) + } + + clientErr := make(chan error) + go func() { + if err := client.Join(ctx); err != nil { + clientErr <- fmt.Errorf("error joining client: %v", err) + return + } + + ctx.Done() + }() + + select { + case err := <-server.Err(): + t.Errorf("error from server: %v", err) + case err := <-clientErr: + t.Errorf("error from client: %v", err) + case <-ctx.Done(): + return + } } diff --git a/socket.go b/socket.go index c3f75b9db..e4f80a0cf 100644 --- a/socket.go +++ b/socket.go @@ -2,9 +2,11 @@ package liveshare import ( "context" + "crypto/tls" "errors" "io" "net" + "net/http" "sync" "time" @@ -12,19 +14,26 @@ import ( ) type socket struct { - addr string + addr string + tlsConfig *tls.Config + conn *websocket.Conn readMutex sync.Mutex writeMutex sync.Mutex reader io.Reader } -func newSocket(clientConn Connection) *socket { - return &socket{addr: clientConn.uri("connect")} +func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { + return &socket{addr: clientConn.uri("connect"), tlsConfig: tlsConfig} } func (s *socket) connect(ctx context.Context) error { - ws, _, err := websocket.DefaultDialer.Dial(s.addr, nil) + dialer := websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: 45 * time.Second, + TLSClientConfig: s.tlsConfig, + } + ws, _, err := dialer.Dial(s.addr, nil) if err != nil { return err } diff --git a/test/server.go b/test/server.go new file mode 100644 index 000000000..ed8666cce --- /dev/null +++ b/test/server.go @@ -0,0 +1,226 @@ +package livesharetest + +import ( + "context" + "errors" + "fmt" + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "path/filepath" + "sync" + "time" + + "github.com/gorilla/websocket" + "github.com/sourcegraph/jsonrpc2" + "golang.org/x/crypto/ssh" +) + +type Server struct { + password string + services map[string]RpcHandleFunc + + sshConfig *ssh.ServerConfig + httptestServer *httptest.Server + errCh chan error +} + +func NewServer(opts ...ServerOption) (*Server, error) { + server := new(Server) + + for _, o := range opts { + if err := o(server); err != nil { + return nil, err + } + } + + server.sshConfig = &ssh.ServerConfig{ + PasswordCallback: sshPasswordCallback(server.password), + } + b, err := ioutil.ReadFile(filepath.Join("test", "private.key")) + if err != nil { + return nil, fmt.Errorf("error reading private.key: %v", err) + } + privateKey, err := ssh.ParsePrivateKey(b) + if err != nil { + return nil, fmt.Errorf("error parsing key: %v", err) + } + server.sshConfig.AddHostKey(privateKey) + + server.errCh = make(chan error) + server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server))) + return server, nil +} + +type ServerOption func(*Server) error + +func WithPassword(password string) ServerOption { + return func(s *Server) error { + s.password = password + return nil + } +} + +func WithService(serviceName string, handler RpcHandleFunc) ServerOption { + return func(s *Server) error { + if s.services == nil { + s.services = make(map[string]RpcHandleFunc) + } + + s.services[serviceName] = handler + return nil + } +} + +func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { + return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { + if string(password) == serverPassword { + return nil, nil + } + return nil, errors.New("password rejected") + } +} + +func (s *Server) Close() { + s.httptestServer.Close() +} + +func (s *Server) URL() string { + return s.httptestServer.URL +} + +func (s *Server) Err() <-chan error { + return s.errCh +} + +var upgrader = websocket.Upgrader{} + +func newConnection(server *Server) http.HandlerFunc { + return func(w http.ResponseWriter, req *http.Request) { + c, err := upgrader.Upgrade(w, req, nil) + if err != nil { + server.errCh <- fmt.Errorf("error upgrading connection: %v", err) + return + } + defer c.Close() + + socketConn := newSocketConn(c) + _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) + if err != nil { + server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err) + return + } + go ssh.DiscardRequests(reqs) + + for newChannel := range chans { + ch, reqs, err := newChannel.Accept() + if err != nil { + server.errCh <- fmt.Errorf("error accepting new channel: %v", err) + return + } + go ssh.DiscardRequests(reqs) + go handleNewChannel(server, ch) + } + } +} + +func handleNewChannel(server *Server, channel ssh.Channel) { + stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) + jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server)) +} + +type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error) + +type rpcHandler struct { + server *Server +} + +func newRpcHandler(server *Server) *rpcHandler { + return &rpcHandler{server} +} + +func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { + handler, found := r.server.services[req.Method] + if !found { + r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method) + return + } + + result, err := handler(req) + if err != nil { + r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err) + return + } + + if err := conn.Reply(ctx, req.ID, result); err != nil { + r.server.errCh <- fmt.Errorf("error replying: %v", err) + } +} + +type socketConn struct { + *websocket.Conn + + reader io.Reader + writeMutex sync.Mutex + readMutex sync.Mutex +} + +func newSocketConn(conn *websocket.Conn) *socketConn { + return &socketConn{Conn: conn} +} + +func (s *socketConn) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + msgType, r, err := s.Conn.NextReader() + if err != nil { + return 0, fmt.Errorf("error getting next reader: %v", err) + } + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type") + } + s.reader = r + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socketConn) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + w, err := s.Conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, fmt.Errorf("error getting next writer: %v", err) + } + + n, err := w.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing: %v", err) + } + + if err := w.Close(); err != nil { + return 0, fmt.Errorf("error closing writer: %v", err) + } + + return n, nil +} + +func (s *socketConn) SetDeadline(deadline time.Time) error { + if err := s.Conn.SetReadDeadline(deadline); err != nil { + return err + } + return s.Conn.SetWriteDeadline(deadline) +}