diff --git a/connection_test.go b/connection_test.go new file mode 100644 index 000000000..e952290be --- /dev/null +++ b/connection_test.go @@ -0,0 +1,33 @@ +package liveshare + +import "testing" + +func TestConnectionValid(t *testing.T) { + conn := Connection{"sess-id", "sess-token", "sas", "endpoint"} + if err := conn.validate(); err != nil { + t.Error(err) + } +} + +func TestConnectionInvalid(t *testing.T) { + conn := Connection{"", "sess-token", "sas", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "", "sas", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "sess-token", "", "endpoint"} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"sess-id", "sess-token", "sas", ""} + if err := conn.validate(); err == nil { + t.Error(err) + } + conn = Connection{"", "", "", ""} + if err := conn.validate(); err == nil { + t.Error(err) + } +} diff --git a/port_forwarder.go b/port_forwarder.go index 6d459b4d6..0a049d586 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -54,8 +54,12 @@ func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) { copyConn := func(writer io.Writer, reader io.Reader) { if _, err := io.Copy(writer, reader); err != nil { + fmt.Println(err) channel.Close() conn.Close() + if err != io.EOF { + l.errCh <- fmt.Errorf("tunnel connection: %v", err) + } } } diff --git a/port_forwarder_test.go b/port_forwarder_test.go index e3e219705..33a33b39b 100644 --- a/port_forwarder_test.go +++ b/port_forwarder_test.go @@ -1 +1,103 @@ package liveshare + +import ( + "bytes" + "context" + "errors" + "fmt" + "io" + "net" + "testing" + "time" + + livesharetest "github.com/github/go-liveshare/test" + "github.com/sourcegraph/jsonrpc2" +) + +func TestNewPortForwarder(t *testing.T) { + testServer, client, err := makeMockJoinedClient() + if err != nil { + t.Errorf("create mock client: %v", err) + } + defer testServer.Close() + server, err := NewServer(client) + if err != nil { + t.Errorf("create new server: %v", err) + } + pf := NewPortForwarder(client, server, 80) + if pf == nil { + t.Error("port forwarder is nil") + } +} + +func TestPortForwarderStart(t *testing.T) { + streamName, streamCondition := "stream-name", "stream-condition" + serverSharing := func(req *jsonrpc2.Request) (interface{}, error) { + return Port{StreamName: streamName, StreamCondition: streamCondition}, nil + } + getStream := func(req *jsonrpc2.Request) (interface{}, error) { + return "stream-id", nil + } + + stream := bytes.NewBufferString("stream-data") + testServer, client, err := makeMockJoinedClient( + livesharetest.WithService("serverSharing.startSharing", serverSharing), + livesharetest.WithService("streamManager.getStream", getStream), + livesharetest.WithStream("stream-id", stream), + ) + if err != nil { + t.Errorf("create mock client: %v", err) + } + defer testServer.Close() + + server, err := NewServer(client) + if err != nil { + t.Errorf("create new server: %v", err) + } + + ctx, _ := context.WithCancel(context.Background()) + pf := NewPortForwarder(client, server, 8000) + done := make(chan error) + + go func() { + if err := server.StartSharing(ctx, "http", 8000); err != nil { + done <- fmt.Errorf("start sharing: %v", err) + } + if err := pf.Start(ctx); err != nil { + done <- err + } + done <- nil + }() + + go func() { + var conn net.Conn + retries := 0 + for conn == nil && retries < 2 { + conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second) + time.Sleep(1 * time.Second) + } + if conn == nil { + done <- errors.New("failed to connect to forwarded port") + } + b := make([]byte, len("stream-data")) + if _, err := conn.Read(b); err != nil && err != io.EOF { + done <- fmt.Errorf("reading stream: %v", err) + } + if string(b) != "stream-data" { + done <- fmt.Errorf("stream data is not expected value, got: %v", string(b)) + } + if _, err := conn.Write([]byte("new-data")); err != nil { + done <- fmt.Errorf("writing to stream: %v", err) + } + done <- nil + }() + + select { + case err := <-testServer.Err(): + t.Errorf("error from server: %v", err) + case err := <-done: + if err != nil { + t.Errorf("error from client: %v", err) + } + } +} diff --git a/server.go b/server.go index 6f17d5ac5..7e8c8b1cb 100644 --- a/server.go +++ b/server.go @@ -7,12 +7,14 @@ import ( "strconv" ) +// A Server represents the liveshare host and container server type Server struct { client *Client port int streamName, streamCondition string } +// NewServer creates a new Server with a given Client func NewServer(client *Client) (*Server, error) { if !client.hasJoined() { return nil, errors.New("client must join before creating server") @@ -21,6 +23,7 @@ func NewServer(client *Client) (*Server, error) { return &Server{client: client}, nil } +// Port represents an open port on the container type Port struct { SourcePort int `json:"sourcePort"` DestinationPort int `json:"destinationPort"` @@ -33,6 +36,7 @@ type Port struct { HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"` } +// StartSharing tells the liveshare host to start sharing the port from the container func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error { s.port = port @@ -49,8 +53,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er return nil } +// Ports is a slice of Port pointers type Ports []*Port +// GetSharedServers returns a list of available/open ports from the container func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { var response Ports if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil { @@ -60,6 +66,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) { return response, nil } +// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly +// via the Browse URL func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error { if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil { return err diff --git a/server_test.go b/server_test.go index 7c9ee4288..b91fbfddc 100644 --- a/server_test.go +++ b/server_test.go @@ -23,7 +23,7 @@ func TestNewServerWithNotJoinedClient(t *testing.T) { } } -func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { +func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) { connection := Connection{ SessionID: "session-id", SessionToken: "session-token", @@ -54,7 +54,7 @@ func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Ser } func TestNewServer(t *testing.T) { - testServer, client, err := newMockJoinedClient() + testServer, client, err := makeMockJoinedClient() defer testServer.Close() if err != nil { t.Errorf("error creating mock joined client: %v", err) @@ -95,7 +95,7 @@ func TestServerStartSharing(t *testing.T) { } return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.startSharing", startSharing), ) defer testServer.Close() @@ -138,7 +138,7 @@ func TestServerGetSharedServers(t *testing.T) { getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) { return Ports{&sharedServer}, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.getSharedServers", getSharedServers), ) if err != nil { @@ -206,7 +206,7 @@ func TestServerUpdateSharedVisibility(t *testing.T) { } return nil, nil } - testServer, client, err := newMockJoinedClient( + testServer, client, err := makeMockJoinedClient( livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility), ) if err != nil { diff --git a/socket.go b/socket.go index e4f80a0cf..8744eeb96 100644 --- a/socket.go +++ b/socket.go @@ -3,11 +3,9 @@ package liveshare import ( "context" "crypto/tls" - "errors" "io" "net" "net/http" - "sync" "time" "github.com/gorilla/websocket" @@ -17,10 +15,8 @@ type socket struct { addr string tlsConfig *tls.Config - conn *websocket.Conn - readMutex sync.Mutex - writeMutex sync.Mutex - reader io.Reader + conn *websocket.Conn + reader io.Reader } func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket { @@ -42,19 +38,12 @@ func (s *socket) connect(ctx context.Context) error { } func (s *socket) Read(b []byte) (int, error) { - s.readMutex.Lock() - defer s.readMutex.Unlock() - if s.reader == nil { - messageType, reader, err := s.conn.NextReader() + _, reader, err := s.conn.NextReader() if err != nil { return 0, err } - if messageType != websocket.BinaryMessage { - return 0, errors.New("unexpected websocket message type") - } - s.reader = reader } @@ -71,9 +60,6 @@ func (s *socket) Read(b []byte) (int, error) { } func (s *socket) Write(b []byte) (int, error) { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() - nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage) if err != nil { return 0, err diff --git a/test/server.go b/test/server.go index abb7ac96a..a52d31ab9 100644 --- a/test/server.go +++ b/test/server.go @@ -9,6 +9,7 @@ import ( "net/http" "net/http/httptest" "path/filepath" + "strings" "sync" "time" @@ -21,6 +22,7 @@ type Server struct { password string services map[string]RpcHandleFunc relaySAS string + streams map[string]io.ReadWriter sshConfig *ssh.ServerConfig httptestServer *httptest.Server @@ -50,7 +52,7 @@ func NewServer(opts ...ServerOption) (*Server, error) { server.sshConfig.AddHostKey(privateKey) server.errCh = make(chan error) - server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server))) + server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server))) return server, nil } @@ -81,6 +83,16 @@ func WithRelaySAS(sas string) ServerOption { } } +func WithStream(name string, stream io.ReadWriter) ServerOption { + return func(s *Server) error { + if s.streams == nil { + s.streams = make(map[string]io.ReadWriter) + } + s.streams[name] = stream + return nil + } +} + func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if string(password) == serverPassword { @@ -104,7 +116,7 @@ func (s *Server) Err() <-chan error { var upgrader = websocket.Upgrader{} -func newConnection(server *Server) http.HandlerFunc { +func makeConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { if server.relaySAS != "" { // validate the sas key @@ -135,12 +147,48 @@ func newConnection(server *Server) http.HandlerFunc { server.errCh <- fmt.Errorf("error accepting new channel: %v", err) return } - go ssh.DiscardRequests(reqs) + go handleNewRequests(server, ch, reqs) go handleNewChannel(server, ch) } } } +func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { + for req := range reqs { + if req.WantReply { + if err := req.Reply(true, nil); err != nil { + server.errCh <- fmt.Errorf("error replying to channel request: %v", err) + } + } + if strings.HasPrefix(req.Type, "stream-transport") { + forwardStream(server, req.Type, channel) + } + } +} + +func forwardStream(server *Server, streamName string, channel ssh.Channel) { + simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") + stream, found := server.streams[simpleStreamName] + if !found { + server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName) + return + } + + copy := func(dst io.Writer, src io.Reader) { + if _, err := io.Copy(dst, src); err != nil { + fmt.Println(err) + server.errCh <- fmt.Errorf("io copy: %v", err) + return + } + } + + go copy(stream, channel) + go copy(channel, stream) + + for { + } +} + func handleNewChannel(server *Server, channel ssh.Channel) { stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server))