diff --git a/cmd/ghcs/output/logger.go b/cmd/ghcs/output/logger.go index a2aa68ba1..6ad7513f1 100644 --- a/cmd/ghcs/output/logger.go +++ b/cmd/ghcs/output/logger.go @@ -3,6 +3,7 @@ package output import ( "fmt" "io" + "sync" ) // NewLogger returns a Logger that will write to the given stdout/stderr writers. @@ -19,6 +20,7 @@ func NewLogger(stdout, stderr io.Writer, disabled bool) *Logger { // If not enabled, Print functions will noop but Error functions will continue // to write to the stderr writer. type Logger struct { + mu sync.Mutex // guards the writers out io.Writer errout io.Writer enabled bool @@ -29,6 +31,9 @@ func (l *Logger) Print(v ...interface{}) (int, error) { if !l.enabled { return 0, nil } + + l.mu.Lock() + defer l.mu.Unlock() return fmt.Fprint(l.out, v...) } @@ -37,6 +42,9 @@ func (l *Logger) Println(v ...interface{}) (int, error) { if !l.enabled { return 0, nil } + + l.mu.Lock() + defer l.mu.Unlock() return fmt.Fprintln(l.out, v...) } @@ -45,15 +53,22 @@ func (l *Logger) Printf(f string, v ...interface{}) (int, error) { if !l.enabled { return 0, nil } + + l.mu.Lock() + defer l.mu.Unlock() return fmt.Fprintf(l.out, f, v...) } // Errorf writes the formatted arguments to the stderr writer. func (l *Logger) Errorf(f string, v ...interface{}) (int, error) { + l.mu.Lock() + defer l.mu.Unlock() return fmt.Fprintf(l.errout, f, v...) } // Errorln writes the arguments to the stderr writer with a newline at the end. func (l *Logger) Errorln(v ...interface{}) (int, error) { + l.mu.Lock() + defer l.mu.Unlock() return fmt.Fprintln(l.errout, v...) } diff --git a/internal/liveshare/test/server.go b/internal/liveshare/test/server.go index 9b898dafb..058080b56 100644 --- a/internal/liveshare/test/server.go +++ b/internal/liveshare/test/server.go @@ -8,6 +8,7 @@ import ( "net/http" "net/http/httptest" "strings" + "sync" "github.com/gorilla/websocket" "github.com/sourcegraph/jsonrpc2" @@ -42,17 +43,19 @@ IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U= -----END RSA PRIVATE KEY-----` +// Server represents a LiveShare relay host server. type Server struct { - password string - services map[string]RPCHandleFunc - relaySAS string - streams map[string]io.ReadWriter - + password string + services map[string]RPCHandleFunc + relaySAS string + streams map[string]io.ReadWriter sshConfig *ssh.ServerConfig httptestServer *httptest.Server errCh chan error } +// NewServer creates a new Server. ServerOptions can be passed to configure +// the SSH password, backing service, secrets and more. func NewServer(opts ...ServerOption) (*Server, error) { server := new(Server) @@ -71,13 +74,15 @@ func NewServer(opts ...ServerOption) (*Server, error) { } server.sshConfig.AddHostKey(privateKey) - server.errCh = make(chan error) + server.errCh = make(chan error, 1) server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server))) return server, nil } +// ServerOption is used to configure the Server. type ServerOption func(*Server) error +// WithPassword configures the Server password for SSH. func WithPassword(password string) ServerOption { return func(s *Server) error { s.password = password @@ -85,6 +90,7 @@ func WithPassword(password string) ServerOption { } } +// WithService accepts a mock RPC service for the Server to invoke. func WithService(serviceName string, handler RPCHandleFunc) ServerOption { return func(s *Server) error { if s.services == nil { @@ -96,6 +102,7 @@ func WithService(serviceName string, handler RPCHandleFunc) ServerOption { } } +// WithRelaySAS configures the relay SAS configuration key. func WithRelaySAS(sas string) ServerOption { return func(s *Server) error { s.relaySAS = sas @@ -103,6 +110,7 @@ func WithRelaySAS(sas string) ServerOption { } } +// WithStream allows you to specify a mock data stream for the server. func WithStream(name string, stream io.ReadWriter) ServerOption { return func(s *Server) error { if s.streams == nil { @@ -122,10 +130,12 @@ func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) ( } } +// Close closes the underlying httptest Server. func (s *Server) Close() { s.httptestServer.Close() } +// URL returns the httptest Server url. func (s *Server) URL() string { return s.httptestServer.URL } @@ -145,73 +155,160 @@ func makeConnection(server *Server) http.HandlerFunc { // validate the sas key sasParam := req.URL.Query().Get("sb-hc-token") if sasParam != server.relaySAS { - server.errCh <- errors.New("error validating sas") + sendError(server.errCh, errors.New("error validating sas")) return } } c, err := upgrader.Upgrade(w, req, nil) if err != nil { - server.errCh <- fmt.Errorf("error upgrading connection: %w", err) + sendError(server.errCh, fmt.Errorf("error upgrading connection: %w", err)) return } - defer c.Close() + defer func() { + if err := c.Close(); err != nil { + sendError(server.errCh, err) + } + }() socketConn := newSocketConn(c) _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) if err != nil { - server.errCh <- fmt.Errorf("error creating new ssh conn: %w", err) + sendError(server.errCh, fmt.Errorf("error creating new ssh conn: %w", err)) return } go ssh.DiscardRequests(reqs) - for newChannel := range chans { - ch, reqs, err := newChannel.Accept() + if err := handleChannels(ctx, server, chans); err != nil { + sendError(server.errCh, err) + } + } +} + +// sendError does a non-blocking send of the error to the err channel. +func sendError(errc chan<- error, err error) { + select { + case errc <- err: + default: + // channel is blocked with a previous error, so we ignore + // this current error + } +} + +// awaitError waits for the context to finish and returns its error (if any). +// It also waits for an err to come through the err channel. +func awaitError(ctx context.Context, errc <-chan error) error { + select { + case <-ctx.Done(): + return ctx.Err() + case err := <-errc: + return err + } +} + +// handleChannels services the sshChannels channel. For each SSH channel received +// it creates a go routine to service the channel's requests. It returns on the first +// error encountered. +func handleChannels(ctx context.Context, server *Server, sshChannels <-chan ssh.NewChannel) error { + errc := make(chan error, 1) + go func() { + for sshCh := range sshChannels { + ch, reqs, err := sshCh.Accept() if err != nil { - server.errCh <- fmt.Errorf("error accepting new channel: %w", err) + sendError(errc, fmt.Errorf("failed to accept channel: %w", err)) return } - go handleNewRequests(ctx, server, ch, reqs) - go handleNewChannel(server, ch) + + go func() { + if err := handleRequests(ctx, server, ch, reqs); err != nil { + sendError(errc, fmt.Errorf("failed to handle requests: %w", err)) + } + }() + + handleChannel(server, ch) } - } + }() + return awaitError(ctx, errc) } -func handleNewRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { - for req := range reqs { - if req.WantReply { - if err := req.Reply(true, nil); err != nil { - server.errCh <- fmt.Errorf("error replying to channel request: %w", err) +// handleRequests services the SSH channel requests channel. It replies to requests and +// when stream transport requests are encountered, creates a go routine to create a +// bi-directional data stream between the channel and server stream. It returns on the first error +// encountered. +func handleRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) error { + errc := make(chan error, 1) + go func() { + for req := range reqs { + if req.WantReply { + if err := req.Reply(true, nil); err != nil { + sendError(errc, fmt.Errorf("error replying to channel request: %w", err)) + return + } + } + + if strings.HasPrefix(req.Type, "stream-transport") { + go func() { + if err := forwardStream(ctx, server, req.Type, channel); err != nil { + sendError(errc, fmt.Errorf("failed to forward stream: %w", err)) + } + }() } } - if strings.HasPrefix(req.Type, "stream-transport") { - forwardStream(ctx, server, req.Type, channel) - } - } + }() + + return awaitError(ctx, errc) } -func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) { +// concurrentStream is a concurrency safe io.ReadWriter. +type concurrentStream struct { + sync.RWMutex + stream io.ReadWriter +} + +func newConcurrentStream(rw io.ReadWriter) *concurrentStream { + return &concurrentStream{stream: rw} +} + +func (cs *concurrentStream) Read(b []byte) (int, error) { + cs.RLock() + defer cs.RUnlock() + return cs.stream.Read(b) +} + +func (cs *concurrentStream) Write(b []byte) (int, error) { + cs.Lock() + defer cs.Unlock() + return cs.stream.Write(b) +} + +// forwardStream does a bi-directional copy of the stream <-> with the SSH channel. The io.Copy +// runs until an error is encountered. +func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) (err error) { simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") stream, found := server.streams[simpleStreamName] if !found { - server.errCh <- fmt.Errorf("stream '%s' not found", simpleStreamName) - return + return fmt.Errorf("stream '%s' not found", simpleStreamName) } + defer func() { + if closeErr := channel.Close(); err == nil && closeErr != io.EOF { + err = closeErr + } + }() + errc := make(chan error, 2) copy := func(dst io.Writer, src io.Reader) { if _, err := io.Copy(dst, src); err != nil { - fmt.Println(err) - server.errCh <- fmt.Errorf("io copy: %w", err) - return + errc <- err } } - go copy(stream, channel) - go copy(channel, stream) + csStream := newConcurrentStream(stream) + go copy(csStream, channel) + go copy(channel, csStream) - <-ctx.Done() // TODO(josebalius): improve this + return awaitError(ctx, errc) } -func handleNewChannel(server *Server, channel ssh.Channel) { +func handleChannel(server *Server, channel ssh.Channel) { stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server)) } @@ -226,20 +323,22 @@ func newRPCHandler(server *Server) *rpcHandler { return &rpcHandler{server} } +// Handle satisfies the jsonrpc2 pkg handler interface. It tries to find a mocked +// RPC service method and if found, it invokes the handler and replies to the request. func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { handler, found := r.server.services[req.Method] if !found { - r.server.errCh <- fmt.Errorf("RPC Method: '%s' not serviced", req.Method) + sendError(r.server.errCh, fmt.Errorf("RPC Method: '%s' not serviced", req.Method)) return } result, err := handler(req) if err != nil { - r.server.errCh <- fmt.Errorf("error handling: '%s': %w", req.Method, err) + sendError(r.server.errCh, fmt.Errorf("error handling: '%s': %w", req.Method, err)) return } if err := conn.Reply(ctx, req.ID, result); err != nil { - r.server.errCh <- fmt.Errorf("error replying: %w", err) + sendError(r.server.errCh, fmt.Errorf("error replying: %w", err)) } }