Merge pull request #4 from github/jg/port-forwarding-errors-test-server

Port forwarding improvements & slight refactor
This commit is contained in:
Jose Garcia 2021-08-13 08:41:43 -04:00 committed by GitHub
commit 79111d85ac
4 changed files with 121 additions and 203 deletions

View file

@ -1,120 +0,0 @@
package main
import (
"bufio"
"context"
"flag"
"fmt"
"log"
"os"
"time"
"github.com/github/go-liveshare"
)
var workspaceIdFlag = flag.String("w", "", "workspace session id")
func init() {
flag.Parse()
}
func main() {
liveShare, err := liveshare.New(
liveshare.WithWorkspaceID(*workspaceIdFlag),
liveshare.WithToken(os.Getenv("CODESPACE_TOKEN")),
)
if err != nil {
log.Fatal(fmt.Errorf("error creating liveshare: %v", err))
}
ctx := context.Background()
liveShareClient := liveShare.NewClient()
if err := liveShareClient.Join(ctx); err != nil {
log.Fatal(fmt.Errorf("error joining liveshare with client: %v", err))
}
terminal, err := liveShareClient.NewTerminal()
if err != nil {
log.Fatal(fmt.Errorf("error creating liveshare terminal"))
}
containerID, err := getContainerID(ctx, terminal)
if err != nil {
log.Fatal(fmt.Errorf("error getting container id: %v", err))
}
if err := setupSSH(ctx, terminal, containerID); err != nil {
log.Fatal(fmt.Errorf("error setting up ssh: %v", err))
}
fmt.Println("Starting server...")
server, err := liveShareClient.NewServer()
if err != nil {
log.Fatal(fmt.Errorf("error creating server: %v", err))
}
fmt.Println("Starting sharing...")
if err := server.StartSharing(ctx, "sshd", 2222); err != nil {
log.Fatal(fmt.Errorf("error server sharing: %v", err))
}
portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, 2222)
fmt.Println("Listening on port 2222")
if err := portForwarder.Start(ctx); err != nil {
log.Fatal(fmt.Errorf("error forwarding port: %v", err))
}
}
func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID string) error {
cmd := terminal.NewCommand(
"/",
fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c \"echo -e \\\"testpwd1\\ntestpwd1\\n\\\" | sudo passwd codespace;/usr/local/share/ssh-init.sh\"", containerID),
)
stream, err := cmd.Run(ctx)
if err != nil {
return fmt.Errorf("error running command: %v", err)
}
scanner := bufio.NewScanner(stream)
scanner.Scan()
fmt.Println("> Debug:", scanner.Text())
if err := scanner.Err(); err != nil {
return fmt.Errorf("error scanning stream: %v", err)
}
if err := stream.Close(); err != nil {
return fmt.Errorf("error closing stream: %v", err)
}
time.Sleep(2 * time.Second)
return nil
}
func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) {
cmd := terminal.NewCommand(
"/",
"/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running",
)
stream, err := cmd.Run(ctx)
if err != nil {
return "", fmt.Errorf("error running command: %v", err)
}
scanner := bufio.NewScanner(stream)
scanner.Scan()
containerID := scanner.Text()
if err := scanner.Err(); err != nil {
return "", fmt.Errorf("error scanning stream: %v", err)
}
if err := stream.Close(); err != nil {
return "", fmt.Errorf("error closing stream: %v", err)
}
return containerID, nil
}

View file

@ -33,13 +33,22 @@ func (l *PortForwarder) Start(ctx context.Context) error {
return fmt.Errorf("error listening on tcp port: %v", err)
}
for {
conn, err := ln.Accept()
if err != nil {
return fmt.Errorf("error accepting incoming connection: %v", err)
}
go func() {
for {
conn, err := ln.Accept()
if err != nil {
l.errCh <- fmt.Errorf("error accepting incoming connection: %v", err)
}
go l.handleConnection(ctx, conn)
go l.handleConnection(ctx, conn)
}
}()
select {
case err := <-l.errCh:
return err
case <-ctx.Done():
return ln.Close()
}
return nil
@ -54,7 +63,6 @@ func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) {
copyConn := func(writer io.Writer, reader io.Reader) {
if _, err := io.Copy(writer, reader); err != nil {
fmt.Println(err)
channel.Close()
conn.Close()
if err != io.EOF {

View file

@ -5,19 +5,43 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
"github.com/sourcegraph/jsonrpc2"
"golang.org/x/crypto/ssh"
)
const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEAp/Jmzy/HaPNx5Bug09FX5Q/KGY4G9c4DfplhWrn31OQCqNiT
ZSLd46rdXC75liHzE7e5Ic0RJN61cYN9SNArjvEXx2vvs7szhwO7LonwPOvpYpUf
daayrgbr6S46plpx+hEZ1kO/6BqMgFuvnkIVThrEyx5b48ll8zgDABsYrKF8/p1V
SjGfb+bLwjn1NtnZF2prBG5P4ZtMR06HaPglLqBJhmc0ZMG5IZGUE7ew/VrPDqdC
f1v4XvvGiU4BLoKYy4QOhyrCGh9Uk/9u0Ea56M2bh4RqwhbpR8m7TYJZ0DVMLbGW
8C+4lCWp+xRyBNxAQh8qeQVCxYl02hPE4bXLGQIDAQABAoIBAEoVPk6UZ+UexhV2
LnphNOFhFqgxI1bYWmhE5lHsCKuLLLUoW9RYDgL4gw6/1e7o6N3AxFRpre9Soj0B
YIl28k/qf6/DKAhjQnaDKdV8mVF2Swvmdesi7lyfxv6kGtD4wqApXPlMB2IuG94f
E5e+1MEQQ9DJgoU3eNZR1dj9GuRC3PyzPcNNJ2R/MMGFw3sOOVcLOgAukotoicuL
0SiL51rHPQu8a5/darH9EltN1GFeceJSDDhgqMP5T8Tp7g/c3//H6szon4H9W+uN
Z3UrImJ+teJjFOaVDqN93+J2eQSUk0lCPGQCd4U9I4AGDGyU6ucdcLQ58Aha9gmU
uQwkfKUCgYEA0UkuPOSDE9dbXe+yhsbOwMb1kKzJYgFDKjRTSP7D9BOMZu4YyASo
J95R4DWjePlDopafG2tNJoWX+CwUl7Uld1R3Ex6xHBa2B7hwZj860GZtr7D4mdWc
DTVjczAjp4P0K1MIFYQui1mVJterkjKuePiI6q/27L1c2jIa/39BWBcCgYEAzW8R
MFZamVw3eA2JYSpBuqhQgE5gX5IWrmVJZSUhpAQTNG/A4nxf7WGtjy9p99tm0RMb
ld05+sOmNLrzw8Pq8SBpFOd+MAca7lPLS1A2CoaAHbOqRqrzVcZ4EZ2jB3WjoLoq
yctwslGb9KmrhBCdcwT48aPAYUIJCZdqEen2xE8CgYBoMowvywGrvjwCH9X9njvP
5P7cAfrdrY04FQcmP5lmCtmLYZ267/6couaWv33dPBU9fMpIh3rI5BiOebvi8FBw
AgCq50v8lR4Z5+0mKvLoUSbpIy4SwTRJqzwRXHVT8LF/ZH6Q39egj4Bf716/kjYl
im/4kJVatsjk5a9lZ4EsDwKBgERkJ3rKJNtNggHrr8KzSLKVekdc0GTAw+BHRAny
NKLf4Gzij3pXIbBrhlZW2JZ1amNMUzCvN7AuFlUTsDeKL9saiSE2eCIRG3wgVVu7
VmJmqJw6xgNEwkHaEvr6Wd4P4euOTtRjcB9NX/gxzDHpPiGelCoN8+vtCgkxaVSR
aV+tAoGAO4HtLOfBAVDNbVXa27aJAjQSUq8qfkwUNJNz+rwgpVQahfiVkyqAPCQM
IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E
Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U=
-----END RSA PRIVATE KEY-----`
type Server struct {
password string
services map[string]RpcHandleFunc
@ -41,11 +65,7 @@ func NewServer(opts ...ServerOption) (*Server, error) {
server.sshConfig = &ssh.ServerConfig{
PasswordCallback: sshPasswordCallback(server.password),
}
b, err := ioutil.ReadFile(filepath.Join("test", "private.key"))
if err != nil {
return nil, fmt.Errorf("error reading private.key: %v", err)
}
privateKey, err := ssh.ParsePrivateKey(b)
privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey))
if err != nil {
return nil, fmt.Errorf("error parsing key: %v", err)
}
@ -221,70 +241,3 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr
r.server.errCh <- fmt.Errorf("error replying: %v", err)
}
}
type socketConn struct {
*websocket.Conn
reader io.Reader
writeMutex sync.Mutex
readMutex sync.Mutex
}
func newSocketConn(conn *websocket.Conn) *socketConn {
return &socketConn{Conn: conn}
}
func (s *socketConn) Read(b []byte) (int, error) {
s.readMutex.Lock()
defer s.readMutex.Unlock()
if s.reader == nil {
msgType, r, err := s.Conn.NextReader()
if err != nil {
return 0, fmt.Errorf("error getting next reader: %v", err)
}
if msgType != websocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type")
}
s.reader = r
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socketConn) Write(b []byte) (int, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, fmt.Errorf("error getting next writer: %v", err)
}
n, err := w.Write(b)
if err != nil {
return 0, fmt.Errorf("error writing: %v", err)
}
if err := w.Close(); err != nil {
return 0, fmt.Errorf("error closing writer: %v", err)
}
return n, nil
}
func (s *socketConn) SetDeadline(deadline time.Time) error {
if err := s.Conn.SetReadDeadline(deadline); err != nil {
return err
}
return s.Conn.SetWriteDeadline(deadline)
}

77
test/socket.go Normal file
View file

@ -0,0 +1,77 @@
package livesharetest
import (
"fmt"
"io"
"sync"
"time"
"github.com/gorilla/websocket"
)
type socketConn struct {
*websocket.Conn
reader io.Reader
writeMutex sync.Mutex
readMutex sync.Mutex
}
func newSocketConn(conn *websocket.Conn) *socketConn {
return &socketConn{Conn: conn}
}
func (s *socketConn) Read(b []byte) (int, error) {
s.readMutex.Lock()
defer s.readMutex.Unlock()
if s.reader == nil {
msgType, r, err := s.Conn.NextReader()
if err != nil {
return 0, fmt.Errorf("error getting next reader: %v", err)
}
if msgType != websocket.BinaryMessage {
return 0, fmt.Errorf("invalid message type")
}
s.reader = r
}
bytesRead, err := s.reader.Read(b)
if err != nil {
s.reader = nil
if err == io.EOF {
err = nil
}
}
return bytesRead, err
}
func (s *socketConn) Write(b []byte) (int, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, fmt.Errorf("error getting next writer: %v", err)
}
n, err := w.Write(b)
if err != nil {
return 0, fmt.Errorf("error writing: %v", err)
}
if err := w.Close(); err != nil {
return 0, fmt.Errorf("error closing writer: %v", err)
}
return n, nil
}
func (s *socketConn) SetDeadline(deadline time.Time) error {
if err := s.Conn.SetReadDeadline(deadline); err != nil {
return err
}
return s.Conn.SetWriteDeadline(deadline)
}