More linter fixes

This commit is contained in:
Jose Garcia 2021-09-23 11:47:52 -04:00
parent d0c65e5490
commit 958990cef8
3 changed files with 13 additions and 15 deletions

View file

@ -71,11 +71,7 @@ func TestServerStartSharing(t *testing.T) {
testServer, session, err := makeMockSession(
livesharetest.WithService("serverSharing.startSharing", startSharing),
)
defer func() {
if err := testServer.Close(); err != nil {
t.Errorf("failed to close test server: %w", err)
}
}()
defer testServer.Close()
if err != nil {
t.Errorf("error creating mock session: %v", err)

View file

@ -138,6 +138,9 @@ var upgrader = websocket.Upgrader{}
func makeConnection(server *Server) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
if server.relaySAS != "" {
// validate the sas key
sasParam := req.URL.Query().Get("sb-hc-token")
@ -167,13 +170,13 @@ func makeConnection(server *Server) http.HandlerFunc {
server.errCh <- fmt.Errorf("error accepting new channel: %w", err)
return
}
go handleNewRequests(server, ch, reqs)
go handleNewRequests(ctx, server, ch, reqs)
go handleNewChannel(server, ch)
}
}
}
func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) {
func handleNewRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) {
for req := range reqs {
if req.WantReply {
if err := req.Reply(true, nil); err != nil {
@ -181,16 +184,16 @@ func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Req
}
}
if strings.HasPrefix(req.Type, "stream-transport") {
forwardStream(server, req.Type, channel)
forwardStream(ctx, server, req.Type, channel)
}
}
}
func forwardStream(server *Server, streamName string, channel ssh.Channel) {
func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) {
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
stream, found := server.streams[simpleStreamName]
if !found {
server.errCh <- fmt.Errorf("stream '%w' not found", simpleStreamName)
server.errCh <- fmt.Errorf("stream '%s' not found", simpleStreamName)
return
}
@ -205,8 +208,7 @@ func forwardStream(server *Server, streamName string, channel ssh.Channel) {
go copy(stream, channel)
go copy(channel, stream)
for {
}
<-ctx.Done() // TODO(josebalius): improve this
}
func handleNewChannel(server *Server, channel ssh.Channel) {

View file

@ -28,7 +28,7 @@ func (s *socketConn) Read(b []byte) (int, error) {
if s.reader == nil {
msgType, r, err := s.Conn.NextReader()
if err != nil {
return 0, fmt.Errorf("error getting next reader: %v", err)
return 0, fmt.Errorf("error getting next reader: %w", err)
}
if msgType != websocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type")
@ -54,7 +54,7 @@ func (s *socketConn) Write(b []byte) (int, error) {
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, fmt.Errorf("error getting next writer: %v", err)
return 0, fmt.Errorf("error getting next writer: %w", err)
}
n, err := w.Write(b)
@ -63,7 +63,7 @@ func (s *socketConn) Write(b []byte) (int, error) {
}
if err := w.Close(); err != nil {
return 0, fmt.Errorf("error closing writer: %v", err)
return 0, fmt.Errorf("error closing writer: %w", err)
}
return n, nil