From 958990cef83defd3c70278d3b4597c9165641128 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Thu, 23 Sep 2021 11:47:52 -0400 Subject: [PATCH] More linter fixes --- internal/liveshare/session_test.go | 6 +----- internal/liveshare/test/server.go | 16 +++++++++------- internal/liveshare/test/socket.go | 6 +++--- 3 files changed, 13 insertions(+), 15 deletions(-) diff --git a/internal/liveshare/session_test.go b/internal/liveshare/session_test.go index 0ffdfe136..c830c33b1 100644 --- a/internal/liveshare/session_test.go +++ b/internal/liveshare/session_test.go @@ -71,11 +71,7 @@ func TestServerStartSharing(t *testing.T) { testServer, session, err := makeMockSession( livesharetest.WithService("serverSharing.startSharing", startSharing), ) - defer func() { - if err := testServer.Close(); err != nil { - t.Errorf("failed to close test server: %w", err) - } - }() + defer testServer.Close() if err != nil { t.Errorf("error creating mock session: %v", err) diff --git a/internal/liveshare/test/server.go b/internal/liveshare/test/server.go index 8f80d1bce..9b898dafb 100644 --- a/internal/liveshare/test/server.go +++ b/internal/liveshare/test/server.go @@ -138,6 +138,9 @@ var upgrader = websocket.Upgrader{} func makeConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + if server.relaySAS != "" { // validate the sas key sasParam := req.URL.Query().Get("sb-hc-token") @@ -167,13 +170,13 @@ func makeConnection(server *Server) http.HandlerFunc { server.errCh <- fmt.Errorf("error accepting new channel: %w", err) return } - go handleNewRequests(server, ch, reqs) + go handleNewRequests(ctx, server, ch, reqs) go handleNewChannel(server, ch) } } } -func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { +func handleNewRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { for req := range reqs { if req.WantReply { if err := req.Reply(true, nil); err != nil { @@ -181,16 +184,16 @@ func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Req } } if strings.HasPrefix(req.Type, "stream-transport") { - forwardStream(server, req.Type, channel) + forwardStream(ctx, server, req.Type, channel) } } } -func forwardStream(server *Server, streamName string, channel ssh.Channel) { +func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) { simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") stream, found := server.streams[simpleStreamName] if !found { - server.errCh <- fmt.Errorf("stream '%w' not found", simpleStreamName) + server.errCh <- fmt.Errorf("stream '%s' not found", simpleStreamName) return } @@ -205,8 +208,7 @@ func forwardStream(server *Server, streamName string, channel ssh.Channel) { go copy(stream, channel) go copy(channel, stream) - for { - } + <-ctx.Done() // TODO(josebalius): improve this } func handleNewChannel(server *Server, channel ssh.Channel) { diff --git a/internal/liveshare/test/socket.go b/internal/liveshare/test/socket.go index 9a2d92491..0a7a8baf0 100644 --- a/internal/liveshare/test/socket.go +++ b/internal/liveshare/test/socket.go @@ -28,7 +28,7 @@ func (s *socketConn) Read(b []byte) (int, error) { if s.reader == nil { msgType, r, err := s.Conn.NextReader() if err != nil { - return 0, fmt.Errorf("error getting next reader: %v", err) + return 0, fmt.Errorf("error getting next reader: %w", err) } if msgType != websocket.BinaryMessage { return 0, fmt.Errorf("invalid message type") @@ -54,7 +54,7 @@ func (s *socketConn) Write(b []byte) (int, error) { w, err := s.Conn.NextWriter(websocket.BinaryMessage) if err != nil { - return 0, fmt.Errorf("error getting next writer: %v", err) + return 0, fmt.Errorf("error getting next writer: %w", err) } n, err := w.Write(b) @@ -63,7 +63,7 @@ func (s *socketConn) Write(b []byte) (int, error) { } if err := w.Close(); err != nil { - return 0, fmt.Errorf("error closing writer: %v", err) + return 0, fmt.Errorf("error closing writer: %w", err) } return n, nil