Final changes to finish this refactor
This commit is contained in:
parent
892f73221c
commit
0ab67badfa
7 changed files with 206 additions and 25 deletions
33
connection_test.go
Normal file
33
connection_test.go
Normal file
|
|
@ -0,0 +1,33 @@
|
|||
package liveshare
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestConnectionValid(t *testing.T) {
|
||||
conn := Connection{"sess-id", "sess-token", "sas", "endpoint"}
|
||||
if err := conn.validate(); err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConnectionInvalid(t *testing.T) {
|
||||
conn := Connection{"", "sess-token", "sas", "endpoint"}
|
||||
if err := conn.validate(); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
conn = Connection{"sess-id", "", "sas", "endpoint"}
|
||||
if err := conn.validate(); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
conn = Connection{"sess-id", "sess-token", "", "endpoint"}
|
||||
if err := conn.validate(); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
conn = Connection{"sess-id", "sess-token", "sas", ""}
|
||||
if err := conn.validate(); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
conn = Connection{"", "", "", ""}
|
||||
if err := conn.validate(); err == nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}
|
||||
|
|
@ -54,8 +54,12 @@ func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) {
|
|||
|
||||
copyConn := func(writer io.Writer, reader io.Reader) {
|
||||
if _, err := io.Copy(writer, reader); err != nil {
|
||||
fmt.Println(err)
|
||||
channel.Close()
|
||||
conn.Close()
|
||||
if err != io.EOF {
|
||||
l.errCh <- fmt.Errorf("tunnel connection: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -1 +1,103 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
livesharetest "github.com/github/go-liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestNewPortForwarder(t *testing.T) {
|
||||
testServer, client, err := makeMockJoinedClient()
|
||||
if err != nil {
|
||||
t.Errorf("create mock client: %v", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
server, err := NewServer(client)
|
||||
if err != nil {
|
||||
t.Errorf("create new server: %v", err)
|
||||
}
|
||||
pf := NewPortForwarder(client, server, 80)
|
||||
if pf == nil {
|
||||
t.Error("port forwarder is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForwarderStart(t *testing.T) {
|
||||
streamName, streamCondition := "stream-name", "stream-condition"
|
||||
serverSharing := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
|
||||
}
|
||||
getStream := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return "stream-id", nil
|
||||
}
|
||||
|
||||
stream := bytes.NewBufferString("stream-data")
|
||||
testServer, client, err := makeMockJoinedClient(
|
||||
livesharetest.WithService("serverSharing.startSharing", serverSharing),
|
||||
livesharetest.WithService("streamManager.getStream", getStream),
|
||||
livesharetest.WithStream("stream-id", stream),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("create mock client: %v", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
|
||||
server, err := NewServer(client)
|
||||
if err != nil {
|
||||
t.Errorf("create new server: %v", err)
|
||||
}
|
||||
|
||||
ctx, _ := context.WithCancel(context.Background())
|
||||
pf := NewPortForwarder(client, server, 8000)
|
||||
done := make(chan error)
|
||||
|
||||
go func() {
|
||||
if err := server.StartSharing(ctx, "http", 8000); err != nil {
|
||||
done <- fmt.Errorf("start sharing: %v", err)
|
||||
}
|
||||
if err := pf.Start(ctx); err != nil {
|
||||
done <- err
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
go func() {
|
||||
var conn net.Conn
|
||||
retries := 0
|
||||
for conn == nil && retries < 2 {
|
||||
conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
if conn == nil {
|
||||
done <- errors.New("failed to connect to forwarded port")
|
||||
}
|
||||
b := make([]byte, len("stream-data"))
|
||||
if _, err := conn.Read(b); err != nil && err != io.EOF {
|
||||
done <- fmt.Errorf("reading stream: %v", err)
|
||||
}
|
||||
if string(b) != "stream-data" {
|
||||
done <- fmt.Errorf("stream data is not expected value, got: %v", string(b))
|
||||
}
|
||||
if _, err := conn.Write([]byte("new-data")); err != nil {
|
||||
done <- fmt.Errorf("writing to stream: %v", err)
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %v", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -7,12 +7,14 @@ import (
|
|||
"strconv"
|
||||
)
|
||||
|
||||
// A Server represents the liveshare host and container server
|
||||
type Server struct {
|
||||
client *Client
|
||||
port int
|
||||
streamName, streamCondition string
|
||||
}
|
||||
|
||||
// NewServer creates a new Server with a given Client
|
||||
func NewServer(client *Client) (*Server, error) {
|
||||
if !client.hasJoined() {
|
||||
return nil, errors.New("client must join before creating server")
|
||||
|
|
@ -21,6 +23,7 @@ func NewServer(client *Client) (*Server, error) {
|
|||
return &Server{client: client}, nil
|
||||
}
|
||||
|
||||
// Port represents an open port on the container
|
||||
type Port struct {
|
||||
SourcePort int `json:"sourcePort"`
|
||||
DestinationPort int `json:"destinationPort"`
|
||||
|
|
@ -33,6 +36,7 @@ type Port struct {
|
|||
HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"`
|
||||
}
|
||||
|
||||
// StartSharing tells the liveshare host to start sharing the port from the container
|
||||
func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error {
|
||||
s.port = port
|
||||
|
||||
|
|
@ -49,8 +53,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er
|
|||
return nil
|
||||
}
|
||||
|
||||
// Ports is a slice of Port pointers
|
||||
type Ports []*Port
|
||||
|
||||
// GetSharedServers returns a list of available/open ports from the container
|
||||
func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) {
|
||||
var response Ports
|
||||
if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
|
||||
|
|
@ -60,6 +66,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) {
|
|||
return response, nil
|
||||
}
|
||||
|
||||
// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly
|
||||
// via the Browse URL
|
||||
func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error {
|
||||
if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil {
|
||||
return err
|
||||
|
|
|
|||
|
|
@ -23,7 +23,7 @@ func TestNewServerWithNotJoinedClient(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) {
|
||||
func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) {
|
||||
connection := Connection{
|
||||
SessionID: "session-id",
|
||||
SessionToken: "session-token",
|
||||
|
|
@ -54,7 +54,7 @@ func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Ser
|
|||
}
|
||||
|
||||
func TestNewServer(t *testing.T) {
|
||||
testServer, client, err := newMockJoinedClient()
|
||||
testServer, client, err := makeMockJoinedClient()
|
||||
defer testServer.Close()
|
||||
if err != nil {
|
||||
t.Errorf("error creating mock joined client: %v", err)
|
||||
|
|
@ -95,7 +95,7 @@ func TestServerStartSharing(t *testing.T) {
|
|||
}
|
||||
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
|
||||
}
|
||||
testServer, client, err := newMockJoinedClient(
|
||||
testServer, client, err := makeMockJoinedClient(
|
||||
livesharetest.WithService("serverSharing.startSharing", startSharing),
|
||||
)
|
||||
defer testServer.Close()
|
||||
|
|
@ -138,7 +138,7 @@ func TestServerGetSharedServers(t *testing.T) {
|
|||
getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return Ports{&sharedServer}, nil
|
||||
}
|
||||
testServer, client, err := newMockJoinedClient(
|
||||
testServer, client, err := makeMockJoinedClient(
|
||||
livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
|
||||
)
|
||||
if err != nil {
|
||||
|
|
@ -206,7 +206,7 @@ func TestServerUpdateSharedVisibility(t *testing.T) {
|
|||
}
|
||||
return nil, nil
|
||||
}
|
||||
testServer, client, err := newMockJoinedClient(
|
||||
testServer, client, err := makeMockJoinedClient(
|
||||
livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility),
|
||||
)
|
||||
if err != nil {
|
||||
|
|
|
|||
20
socket.go
20
socket.go
|
|
@ -3,11 +3,9 @@ package liveshare
|
|||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
|
|
@ -17,10 +15,8 @@ type socket struct {
|
|||
addr string
|
||||
tlsConfig *tls.Config
|
||||
|
||||
conn *websocket.Conn
|
||||
readMutex sync.Mutex
|
||||
writeMutex sync.Mutex
|
||||
reader io.Reader
|
||||
conn *websocket.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket {
|
||||
|
|
@ -42,19 +38,12 @@ func (s *socket) connect(ctx context.Context) error {
|
|||
}
|
||||
|
||||
func (s *socket) Read(b []byte) (int, error) {
|
||||
s.readMutex.Lock()
|
||||
defer s.readMutex.Unlock()
|
||||
|
||||
if s.reader == nil {
|
||||
messageType, reader, err := s.conn.NextReader()
|
||||
_, reader, err := s.conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if messageType != websocket.BinaryMessage {
|
||||
return 0, errors.New("unexpected websocket message type")
|
||||
}
|
||||
|
||||
s.reader = reader
|
||||
}
|
||||
|
||||
|
|
@ -71,9 +60,6 @@ func (s *socket) Read(b []byte) (int, error) {
|
|||
}
|
||||
|
||||
func (s *socket) Write(b []byte) (int, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
|
||||
nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -21,6 +22,7 @@ type Server struct {
|
|||
password string
|
||||
services map[string]RpcHandleFunc
|
||||
relaySAS string
|
||||
streams map[string]io.ReadWriter
|
||||
|
||||
sshConfig *ssh.ServerConfig
|
||||
httptestServer *httptest.Server
|
||||
|
|
@ -50,7 +52,7 @@ func NewServer(opts ...ServerOption) (*Server, error) {
|
|||
server.sshConfig.AddHostKey(privateKey)
|
||||
|
||||
server.errCh = make(chan error)
|
||||
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server)))
|
||||
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
|
||||
return server, nil
|
||||
}
|
||||
|
||||
|
|
@ -81,6 +83,16 @@ func WithRelaySAS(sas string) ServerOption {
|
|||
}
|
||||
}
|
||||
|
||||
func WithStream(name string, stream io.ReadWriter) ServerOption {
|
||||
return func(s *Server) error {
|
||||
if s.streams == nil {
|
||||
s.streams = make(map[string]io.ReadWriter)
|
||||
}
|
||||
s.streams[name] = stream
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
|
||||
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if string(password) == serverPassword {
|
||||
|
|
@ -104,7 +116,7 @@ func (s *Server) Err() <-chan error {
|
|||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
func newConnection(server *Server) http.HandlerFunc {
|
||||
func makeConnection(server *Server) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
if server.relaySAS != "" {
|
||||
// validate the sas key
|
||||
|
|
@ -135,12 +147,48 @@ func newConnection(server *Server) http.HandlerFunc {
|
|||
server.errCh <- fmt.Errorf("error accepting new channel: %v", err)
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
go handleNewRequests(server, ch, reqs)
|
||||
go handleNewChannel(server, ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) {
|
||||
for req := range reqs {
|
||||
if req.WantReply {
|
||||
if err := req.Reply(true, nil); err != nil {
|
||||
server.errCh <- fmt.Errorf("error replying to channel request: %v", err)
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(req.Type, "stream-transport") {
|
||||
forwardStream(server, req.Type, channel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func forwardStream(server *Server, streamName string, channel ssh.Channel) {
|
||||
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
|
||||
stream, found := server.streams[simpleStreamName]
|
||||
if !found {
|
||||
server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName)
|
||||
return
|
||||
}
|
||||
|
||||
copy := func(dst io.Writer, src io.Reader) {
|
||||
if _, err := io.Copy(dst, src); err != nil {
|
||||
fmt.Println(err)
|
||||
server.errCh <- fmt.Errorf("io copy: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go copy(stream, channel)
|
||||
go copy(channel, stream)
|
||||
|
||||
for {
|
||||
}
|
||||
}
|
||||
|
||||
func handleNewChannel(server *Server, channel ssh.Channel) {
|
||||
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
|
||||
jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue