port forwarding err handling and test refactors
This commit is contained in:
parent
cd99399290
commit
fbf0d28672
4 changed files with 121 additions and 202 deletions
120
example/main.go
120
example/main.go
|
|
@ -1,120 +0,0 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/github/go-liveshare"
|
||||
)
|
||||
|
||||
var workspaceIdFlag = flag.String("w", "", "workspace session id")
|
||||
|
||||
func init() {
|
||||
flag.Parse()
|
||||
}
|
||||
|
||||
func main() {
|
||||
liveShare, err := liveshare.New(
|
||||
liveshare.WithWorkspaceID(*workspaceIdFlag),
|
||||
liveshare.WithToken(os.Getenv("CODESPACE_TOKEN")),
|
||||
)
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("error creating liveshare: %v", err))
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
liveShareClient := liveShare.NewClient()
|
||||
if err := liveShareClient.Join(ctx); err != nil {
|
||||
log.Fatal(fmt.Errorf("error joining liveshare with client: %v", err))
|
||||
}
|
||||
|
||||
terminal, err := liveShareClient.NewTerminal()
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("error creating liveshare terminal"))
|
||||
}
|
||||
|
||||
containerID, err := getContainerID(ctx, terminal)
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("error getting container id: %v", err))
|
||||
}
|
||||
|
||||
if err := setupSSH(ctx, terminal, containerID); err != nil {
|
||||
log.Fatal(fmt.Errorf("error setting up ssh: %v", err))
|
||||
}
|
||||
|
||||
fmt.Println("Starting server...")
|
||||
|
||||
server, err := liveShareClient.NewServer()
|
||||
if err != nil {
|
||||
log.Fatal(fmt.Errorf("error creating server: %v", err))
|
||||
}
|
||||
|
||||
fmt.Println("Starting sharing...")
|
||||
if err := server.StartSharing(ctx, "sshd", 2222); err != nil {
|
||||
log.Fatal(fmt.Errorf("error server sharing: %v", err))
|
||||
}
|
||||
|
||||
portForwarder := liveshare.NewLocalPortForwarder(liveShareClient, server, 2222)
|
||||
|
||||
fmt.Println("Listening on port 2222")
|
||||
if err := portForwarder.Start(ctx); err != nil {
|
||||
log.Fatal(fmt.Errorf("error forwarding port: %v", err))
|
||||
}
|
||||
}
|
||||
|
||||
func setupSSH(ctx context.Context, terminal *liveshare.Terminal, containerID string) error {
|
||||
cmd := terminal.NewCommand(
|
||||
"/",
|
||||
fmt.Sprintf("/usr/bin/docker exec -t %s /bin/bash -c \"echo -e \\\"testpwd1\\ntestpwd1\\n\\\" | sudo passwd codespace;/usr/local/share/ssh-init.sh\"", containerID),
|
||||
)
|
||||
stream, err := cmd.Run(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error running command: %v", err)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
scanner.Scan()
|
||||
|
||||
fmt.Println("> Debug:", scanner.Text())
|
||||
if err := scanner.Err(); err != nil {
|
||||
return fmt.Errorf("error scanning stream: %v", err)
|
||||
}
|
||||
|
||||
if err := stream.Close(); err != nil {
|
||||
return fmt.Errorf("error closing stream: %v", err)
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getContainerID(ctx context.Context, terminal *liveshare.Terminal) (string, error) {
|
||||
cmd := terminal.NewCommand(
|
||||
"/",
|
||||
"/usr/bin/docker ps -aq --filter label=Type=codespaces --filter status=running",
|
||||
)
|
||||
stream, err := cmd.Run(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("error running command: %v", err)
|
||||
}
|
||||
|
||||
scanner := bufio.NewScanner(stream)
|
||||
scanner.Scan()
|
||||
|
||||
containerID := scanner.Text()
|
||||
if err := scanner.Err(); err != nil {
|
||||
return "", fmt.Errorf("error scanning stream: %v", err)
|
||||
}
|
||||
|
||||
if err := stream.Close(); err != nil {
|
||||
return "", fmt.Errorf("error closing stream: %v", err)
|
||||
}
|
||||
|
||||
return containerID, nil
|
||||
}
|
||||
|
|
@ -33,13 +33,22 @@ func (l *PortForwarder) Start(ctx context.Context) error {
|
|||
return fmt.Errorf("error listening on tcp port: %v", err)
|
||||
}
|
||||
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error accepting incoming connection: %v", err)
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
if err != nil {
|
||||
l.errCh <- fmt.Errorf("error accepting incoming connection: %v", err)
|
||||
}
|
||||
|
||||
go l.handleConnection(ctx, conn)
|
||||
go l.handleConnection(ctx, conn)
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-l.errCh:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ln.Close()
|
||||
}
|
||||
|
||||
return nil
|
||||
|
|
|
|||
105
test/server.go
105
test/server.go
|
|
@ -5,19 +5,43 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEogIBAAKCAQEAp/Jmzy/HaPNx5Bug09FX5Q/KGY4G9c4DfplhWrn31OQCqNiT
|
||||
ZSLd46rdXC75liHzE7e5Ic0RJN61cYN9SNArjvEXx2vvs7szhwO7LonwPOvpYpUf
|
||||
daayrgbr6S46plpx+hEZ1kO/6BqMgFuvnkIVThrEyx5b48ll8zgDABsYrKF8/p1V
|
||||
SjGfb+bLwjn1NtnZF2prBG5P4ZtMR06HaPglLqBJhmc0ZMG5IZGUE7ew/VrPDqdC
|
||||
f1v4XvvGiU4BLoKYy4QOhyrCGh9Uk/9u0Ea56M2bh4RqwhbpR8m7TYJZ0DVMLbGW
|
||||
8C+4lCWp+xRyBNxAQh8qeQVCxYl02hPE4bXLGQIDAQABAoIBAEoVPk6UZ+UexhV2
|
||||
LnphNOFhFqgxI1bYWmhE5lHsCKuLLLUoW9RYDgL4gw6/1e7o6N3AxFRpre9Soj0B
|
||||
YIl28k/qf6/DKAhjQnaDKdV8mVF2Swvmdesi7lyfxv6kGtD4wqApXPlMB2IuG94f
|
||||
E5e+1MEQQ9DJgoU3eNZR1dj9GuRC3PyzPcNNJ2R/MMGFw3sOOVcLOgAukotoicuL
|
||||
0SiL51rHPQu8a5/darH9EltN1GFeceJSDDhgqMP5T8Tp7g/c3//H6szon4H9W+uN
|
||||
Z3UrImJ+teJjFOaVDqN93+J2eQSUk0lCPGQCd4U9I4AGDGyU6ucdcLQ58Aha9gmU
|
||||
uQwkfKUCgYEA0UkuPOSDE9dbXe+yhsbOwMb1kKzJYgFDKjRTSP7D9BOMZu4YyASo
|
||||
J95R4DWjePlDopafG2tNJoWX+CwUl7Uld1R3Ex6xHBa2B7hwZj860GZtr7D4mdWc
|
||||
DTVjczAjp4P0K1MIFYQui1mVJterkjKuePiI6q/27L1c2jIa/39BWBcCgYEAzW8R
|
||||
MFZamVw3eA2JYSpBuqhQgE5gX5IWrmVJZSUhpAQTNG/A4nxf7WGtjy9p99tm0RMb
|
||||
ld05+sOmNLrzw8Pq8SBpFOd+MAca7lPLS1A2CoaAHbOqRqrzVcZ4EZ2jB3WjoLoq
|
||||
yctwslGb9KmrhBCdcwT48aPAYUIJCZdqEen2xE8CgYBoMowvywGrvjwCH9X9njvP
|
||||
5P7cAfrdrY04FQcmP5lmCtmLYZ267/6couaWv33dPBU9fMpIh3rI5BiOebvi8FBw
|
||||
AgCq50v8lR4Z5+0mKvLoUSbpIy4SwTRJqzwRXHVT8LF/ZH6Q39egj4Bf716/kjYl
|
||||
im/4kJVatsjk5a9lZ4EsDwKBgERkJ3rKJNtNggHrr8KzSLKVekdc0GTAw+BHRAny
|
||||
NKLf4Gzij3pXIbBrhlZW2JZ1amNMUzCvN7AuFlUTsDeKL9saiSE2eCIRG3wgVVu7
|
||||
VmJmqJw6xgNEwkHaEvr6Wd4P4euOTtRjcB9NX/gxzDHpPiGelCoN8+vtCgkxaVSR
|
||||
aV+tAoGAO4HtLOfBAVDNbVXa27aJAjQSUq8qfkwUNJNz+rwgpVQahfiVkyqAPCQM
|
||||
IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E
|
||||
Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U=
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
|
||||
type Server struct {
|
||||
password string
|
||||
services map[string]RpcHandleFunc
|
||||
|
|
@ -41,11 +65,7 @@ func NewServer(opts ...ServerOption) (*Server, error) {
|
|||
server.sshConfig = &ssh.ServerConfig{
|
||||
PasswordCallback: sshPasswordCallback(server.password),
|
||||
}
|
||||
b, err := ioutil.ReadFile(filepath.Join("test", "private.key"))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading private.key: %v", err)
|
||||
}
|
||||
privateKey, err := ssh.ParsePrivateKey(b)
|
||||
privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing key: %v", err)
|
||||
}
|
||||
|
|
@ -221,70 +241,3 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr
|
|||
r.server.errCh <- fmt.Errorf("error replying: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
type socketConn struct {
|
||||
*websocket.Conn
|
||||
|
||||
reader io.Reader
|
||||
writeMutex sync.Mutex
|
||||
readMutex sync.Mutex
|
||||
}
|
||||
|
||||
func newSocketConn(conn *websocket.Conn) *socketConn {
|
||||
return &socketConn{Conn: conn}
|
||||
}
|
||||
|
||||
func (s *socketConn) Read(b []byte) (int, error) {
|
||||
s.readMutex.Lock()
|
||||
defer s.readMutex.Unlock()
|
||||
|
||||
if s.reader == nil {
|
||||
msgType, r, err := s.Conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next reader: %v", err)
|
||||
}
|
||||
if msgType != websocket.BinaryMessage {
|
||||
return 0, fmt.Errorf("invalid message type")
|
||||
}
|
||||
s.reader = r
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socketConn) Write(b []byte) (int, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
|
||||
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next writer: %v", err)
|
||||
}
|
||||
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error writing: %v", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return 0, fmt.Errorf("error closing writer: %v", err)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *socketConn) SetDeadline(deadline time.Time) error {
|
||||
if err := s.Conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Conn.SetWriteDeadline(deadline)
|
||||
}
|
||||
|
|
|
|||
77
test/socket.go
Normal file
77
test/socket.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
package livesharetest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type socketConn struct {
|
||||
*websocket.Conn
|
||||
|
||||
reader io.Reader
|
||||
writeMutex sync.Mutex
|
||||
readMutex sync.Mutex
|
||||
}
|
||||
|
||||
func newSocketConn(conn *websocket.Conn) *socketConn {
|
||||
return &socketConn{Conn: conn}
|
||||
}
|
||||
|
||||
func (s *socketConn) Read(b []byte) (int, error) {
|
||||
s.readMutex.Lock()
|
||||
defer s.readMutex.Unlock()
|
||||
|
||||
if s.reader == nil {
|
||||
msgType, r, err := s.Conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next reader: %v", err)
|
||||
}
|
||||
if msgType != websocket.BinaryMessage {
|
||||
return 0, fmt.Errorf("invalid message type")
|
||||
}
|
||||
s.reader = r
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socketConn) Write(b []byte) (int, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
|
||||
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next writer: %v", err)
|
||||
}
|
||||
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error writing: %v", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return 0, fmt.Errorf("error closing writer: %v", err)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *socketConn) SetDeadline(deadline time.Time) error {
|
||||
if err := s.Conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Conn.SetWriteDeadline(deadline)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue