From fbf0d286729dd355889a8caad0b80be71e4ae601 Mon Sep 17 00:00:00 2001 From: Jose Garcia Date: Fri, 6 Aug 2021 01:03:03 +0000 Subject: [PATCH] port forwarding err handling and test refactors --- example/main.go | 120 ---------------------------------------------- port_forwarder.go | 21 +++++--- test/server.go | 105 +++++++++++----------------------------- test/socket.go | 77 +++++++++++++++++++++++++++++ 4 files changed, 121 insertions(+), 202 deletions(-) delete mode 100644 example/main.go create mode 100644 test/socket.go diff --git a/example/main.go b/example/main.go deleted file mode 100644 index e9347bd14..000000000 --- a/example/main.go +++ /dev/null @@ -1,120 +0,0 @@ -package main - -import ( - "bufio" - "context" - "flag" - "fmt" - "log" - "os" - "time" - - "github.com/github/go-liveshare" -) - -var workspaceIdFlag = flag.String("w", "", "workspace session id") - -func init() { - flag.Parse() -} - -func main() { - liveShare, err := liveshare.New( - liveshare.WithWorkspaceID(*workspaceIdFlag), - liveshare.WithToken(os.Getenv("CODESPACE_TOKEN")), - ) - if err != nil { - log.Fatal(fmt.Errorf("error creating liveshare: %v", err)) - } - - ctx := context.Background() - liveShareClient := liveShare.NewClient() - if err := liveShareClient.Join(ctx); err != nil { - log.Fatal(fmt.Errorf("error joining liveshare with client: %v", err)) - } - - terminal, err := liveShareClient.NewTerminal() - if err != nil { - log.Fatal(fmt.Errorf("error creating liveshare terminal")) - } - - containerID, err := getContainerID(ctx, terminal) - if err != nil { - log.Fatal(fmt.Errorf("error getting container id: %v", err)) - } - - if err := setupSSH(ctx, terminal, containerID); err != nil { - log.Fatal(fmt.Errorf("error setting up ssh: %v", err)) - } - - fmt.Println("Starting server...") - - server, err := liveShareClient.NewServer() - if err != nil { - log.Fatal(fmt.Errorf("error creating server: %v", err)) - } - - fmt.Println("Starting sharing...") - if err := server.StartSharing(ctx, "sshd", 2222); err != nil { - log.Fatal(fmt.Errorf("error server sharing: %v", err)) - } - - portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, 2222) - - fmt.Println("Listening on port 2222") - if err := portForwarder.Start(ctx); err != nil { - log.Fatal(fmt.Errorf("error forwarding port: %v", err)) - } -} - -func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID string) error { - cmd := terminal.NewCommand( - "/", - fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c \"echo -e \\\"testpwd1\\ntestpwd1\\n\\\" | sudo passwd codespace;/usr/local/share/ssh-init.sh\"", containerID), - ) - stream, err := cmd.Run(ctx) - if err != nil { - return fmt.Errorf("error running command: %v", err) - } - - scanner := bufio.NewScanner(stream) - scanner.Scan() - - fmt.Println("> Debug:", scanner.Text()) - if err := scanner.Err(); err != nil { - return fmt.Errorf("error scanning stream: %v", err) - } - - if err := stream.Close(); err != nil { - return fmt.Errorf("error closing stream: %v", err) - } - - time.Sleep(2 * time.Second) - - return nil -} - -func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) { - cmd := terminal.NewCommand( - "/", - "/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running", - ) - stream, err := cmd.Run(ctx) - if err != nil { - return "", fmt.Errorf("error running command: %v", err) - } - - scanner := bufio.NewScanner(stream) - scanner.Scan() - - containerID := scanner.Text() - if err := scanner.Err(); err != nil { - return "", fmt.Errorf("error scanning stream: %v", err) - } - - if err := stream.Close(); err != nil { - return "", fmt.Errorf("error closing stream: %v", err) - } - - return containerID, nil -} diff --git a/port_forwarder.go b/port_forwarder.go index 0a049d586..3a73e3fce 100644 --- a/port_forwarder.go +++ b/port_forwarder.go @@ -33,13 +33,22 @@ func (l *PortForwarder) Start(ctx context.Context) error { return fmt.Errorf("error listening on tcp port: %v", err) } - for { - conn, err := ln.Accept() - if err != nil { - return fmt.Errorf("error accepting incoming connection: %v", err) - } + go func() { + for { + conn, err := ln.Accept() + if err != nil { + l.errCh <- fmt.Errorf("error accepting incoming connection: %v", err) + } - go l.handleConnection(ctx, conn) + go l.handleConnection(ctx, conn) + } + }() + + select { + case err := <-l.errCh: + return err + case <-ctx.Done(): + return ln.Close() } return nil diff --git a/test/server.go b/test/server.go index a52d31ab9..159a2a982 100644 --- a/test/server.go +++ b/test/server.go @@ -5,19 +5,43 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http" "net/http/httptest" - "path/filepath" "strings" - "sync" - "time" "github.com/gorilla/websocket" "github.com/sourcegraph/jsonrpc2" "golang.org/x/crypto/ssh" ) +const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY----- +MIIEogIBAAKCAQEAp/Jmzy/HaPNx5Bug09FX5Q/KGY4G9c4DfplhWrn31OQCqNiT +ZSLd46rdXC75liHzE7e5Ic0RJN61cYN9SNArjvEXx2vvs7szhwO7LonwPOvpYpUf +daayrgbr6S46plpx+hEZ1kO/6BqMgFuvnkIVThrEyx5b48ll8zgDABsYrKF8/p1V +SjGfb+bLwjn1NtnZF2prBG5P4ZtMR06HaPglLqBJhmc0ZMG5IZGUE7ew/VrPDqdC +f1v4XvvGiU4BLoKYy4QOhyrCGh9Uk/9u0Ea56M2bh4RqwhbpR8m7TYJZ0DVMLbGW +8C+4lCWp+xRyBNxAQh8qeQVCxYl02hPE4bXLGQIDAQABAoIBAEoVPk6UZ+UexhV2 +LnphNOFhFqgxI1bYWmhE5lHsCKuLLLUoW9RYDgL4gw6/1e7o6N3AxFRpre9Soj0B +YIl28k/qf6/DKAhjQnaDKdV8mVF2Swvmdesi7lyfxv6kGtD4wqApXPlMB2IuG94f +E5e+1MEQQ9DJgoU3eNZR1dj9GuRC3PyzPcNNJ2R/MMGFw3sOOVcLOgAukotoicuL +0SiL51rHPQu8a5/darH9EltN1GFeceJSDDhgqMP5T8Tp7g/c3//H6szon4H9W+uN +Z3UrImJ+teJjFOaVDqN93+J2eQSUk0lCPGQCd4U9I4AGDGyU6ucdcLQ58Aha9gmU +uQwkfKUCgYEA0UkuPOSDE9dbXe+yhsbOwMb1kKzJYgFDKjRTSP7D9BOMZu4YyASo +J95R4DWjePlDopafG2tNJoWX+CwUl7Uld1R3Ex6xHBa2B7hwZj860GZtr7D4mdWc +DTVjczAjp4P0K1MIFYQui1mVJterkjKuePiI6q/27L1c2jIa/39BWBcCgYEAzW8R +MFZamVw3eA2JYSpBuqhQgE5gX5IWrmVJZSUhpAQTNG/A4nxf7WGtjy9p99tm0RMb +ld05+sOmNLrzw8Pq8SBpFOd+MAca7lPLS1A2CoaAHbOqRqrzVcZ4EZ2jB3WjoLoq +yctwslGb9KmrhBCdcwT48aPAYUIJCZdqEen2xE8CgYBoMowvywGrvjwCH9X9njvP +5P7cAfrdrY04FQcmP5lmCtmLYZ267/6couaWv33dPBU9fMpIh3rI5BiOebvi8FBw +AgCq50v8lR4Z5+0mKvLoUSbpIy4SwTRJqzwRXHVT8LF/ZH6Q39egj4Bf716/kjYl +im/4kJVatsjk5a9lZ4EsDwKBgERkJ3rKJNtNggHrr8KzSLKVekdc0GTAw+BHRAny +NKLf4Gzij3pXIbBrhlZW2JZ1amNMUzCvN7AuFlUTsDeKL9saiSE2eCIRG3wgVVu7 +VmJmqJw6xgNEwkHaEvr6Wd4P4euOTtRjcB9NX/gxzDHpPiGelCoN8+vtCgkxaVSR +aV+tAoGAO4HtLOfBAVDNbVXa27aJAjQSUq8qfkwUNJNz+rwgpVQahfiVkyqAPCQM +IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E +Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U= +-----END RSA PRIVATE KEY-----` + type Server struct { password string services map[string]RpcHandleFunc @@ -41,11 +65,7 @@ func NewServer(opts ...ServerOption) (*Server, error) { server.sshConfig = &ssh.ServerConfig{ PasswordCallback: sshPasswordCallback(server.password), } - b, err := ioutil.ReadFile(filepath.Join("test", "private.key")) - if err != nil { - return nil, fmt.Errorf("error reading private.key: %v", err) - } - privateKey, err := ssh.ParsePrivateKey(b) + privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) if err != nil { return nil, fmt.Errorf("error parsing key: %v", err) } @@ -221,70 +241,3 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr r.server.errCh <- fmt.Errorf("error replying: %v", err) } } - -type socketConn struct { - *websocket.Conn - - reader io.Reader - writeMutex sync.Mutex - readMutex sync.Mutex -} - -func newSocketConn(conn *websocket.Conn) *socketConn { - return &socketConn{Conn: conn} -} - -func (s *socketConn) Read(b []byte) (int, error) { - s.readMutex.Lock() - defer s.readMutex.Unlock() - - if s.reader == nil { - msgType, r, err := s.Conn.NextReader() - if err != nil { - return 0, fmt.Errorf("error getting next reader: %v", err) - } - if msgType != websocket.BinaryMessage { - return 0, fmt.Errorf("invalid message type") - } - s.reader = r - } - - bytesRead, err := s.reader.Read(b) - if err != nil { - s.reader = nil - - if err == io.EOF { - err = nil - } - } - - return bytesRead, err -} - -func (s *socketConn) Write(b []byte) (int, error) { - s.writeMutex.Lock() - defer s.writeMutex.Unlock() - - w, err := s.Conn.NextWriter(websocket.BinaryMessage) - if err != nil { - return 0, fmt.Errorf("error getting next writer: %v", err) - } - - n, err := w.Write(b) - if err != nil { - return 0, fmt.Errorf("error writing: %v", err) - } - - if err := w.Close(); err != nil { - return 0, fmt.Errorf("error closing writer: %v", err) - } - - return n, nil -} - -func (s *socketConn) SetDeadline(deadline time.Time) error { - if err := s.Conn.SetReadDeadline(deadline); err != nil { - return err - } - return s.Conn.SetWriteDeadline(deadline) -} diff --git a/test/socket.go b/test/socket.go new file mode 100644 index 000000000..9a2d92491 --- /dev/null +++ b/test/socket.go @@ -0,0 +1,77 @@ +package livesharetest + +import ( + "fmt" + "io" + "sync" + "time" + + "github.com/gorilla/websocket" +) + +type socketConn struct { + *websocket.Conn + + reader io.Reader + writeMutex sync.Mutex + readMutex sync.Mutex +} + +func newSocketConn(conn *websocket.Conn) *socketConn { + return &socketConn{Conn: conn} +} + +func (s *socketConn) Read(b []byte) (int, error) { + s.readMutex.Lock() + defer s.readMutex.Unlock() + + if s.reader == nil { + msgType, r, err := s.Conn.NextReader() + if err != nil { + return 0, fmt.Errorf("error getting next reader: %v", err) + } + if msgType != websocket.BinaryMessage { + return 0, fmt.Errorf("invalid message type") + } + s.reader = r + } + + bytesRead, err := s.reader.Read(b) + if err != nil { + s.reader = nil + + if err == io.EOF { + err = nil + } + } + + return bytesRead, err +} + +func (s *socketConn) Write(b []byte) (int, error) { + s.writeMutex.Lock() + defer s.writeMutex.Unlock() + + w, err := s.Conn.NextWriter(websocket.BinaryMessage) + if err != nil { + return 0, fmt.Errorf("error getting next writer: %v", err) + } + + n, err := w.Write(b) + if err != nil { + return 0, fmt.Errorf("error writing: %v", err) + } + + if err := w.Close(); err != nil { + return 0, fmt.Errorf("error closing writer: %v", err) + } + + return n, nil +} + +func (s *socketConn) SetDeadline(deadline time.Time) error { + if err := s.Conn.SetReadDeadline(deadline); err != nil { + return err + } + return s.Conn.SetWriteDeadline(deadline) +}