Final changes to finish this refactor

This commit is contained in:
Jose Garcia 2021-07-27 23:19:55 +00:00 committed by GitHub
parent 892f73221c
commit 0ab67badfa
7 changed files with 206 additions and 25 deletions

33
connection_test.go Normal file
View file

@ -0,0 +1,33 @@
package liveshare
import "testing"
func TestConnectionValid(t *testing.T) {
conn := Connection{"sess-id", "sess-token", "sas", "endpoint"}
if err := conn.validate(); err != nil {
t.Error(err)
}
}
func TestConnectionInvalid(t *testing.T) {
conn := Connection{"", "sess-token", "sas", "endpoint"}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"sess-id", "", "sas", "endpoint"}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"sess-id", "sess-token", "", "endpoint"}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"sess-id", "sess-token", "sas", ""}
if err := conn.validate(); err == nil {
t.Error(err)
}
conn = Connection{"", "", "", ""}
if err := conn.validate(); err == nil {
t.Error(err)
}
}

View file

@ -54,8 +54,12 @@ func (l *PortForwarder) handleConnection(ctx context.Context, conn net.Conn) {
copyConn := func(writer io.Writer, reader io.Reader) {
if _, err := io.Copy(writer, reader); err != nil {
fmt.Println(err)
channel.Close()
conn.Close()
if err != io.EOF {
l.errCh <- fmt.Errorf("tunnel connection: %v", err)
}
}
}

View file

@ -1 +1,103 @@
package liveshare
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"testing"
"time"
livesharetest "github.com/github/go-liveshare/test"
"github.com/sourcegraph/jsonrpc2"
)
func TestNewPortForwarder(t *testing.T) {
testServer, client, err := makeMockJoinedClient()
if err != nil {
t.Errorf("create mock client: %v", err)
}
defer testServer.Close()
server, err := NewServer(client)
if err != nil {
t.Errorf("create new server: %v", err)
}
pf := NewPortForwarder(client, server, 80)
if pf == nil {
t.Error("port forwarder is nil")
}
}
func TestPortForwarderStart(t *testing.T) {
streamName, streamCondition := "stream-name", "stream-condition"
serverSharing := func(req *jsonrpc2.Request) (interface{}, error) {
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
}
getStream := func(req *jsonrpc2.Request) (interface{}, error) {
return "stream-id", nil
}
stream := bytes.NewBufferString("stream-data")
testServer, client, err := makeMockJoinedClient(
livesharetest.WithService("serverSharing.startSharing", serverSharing),
livesharetest.WithService("streamManager.getStream", getStream),
livesharetest.WithStream("stream-id", stream),
)
if err != nil {
t.Errorf("create mock client: %v", err)
}
defer testServer.Close()
server, err := NewServer(client)
if err != nil {
t.Errorf("create new server: %v", err)
}
ctx, _ := context.WithCancel(context.Background())
pf := NewPortForwarder(client, server, 8000)
done := make(chan error)
go func() {
if err := server.StartSharing(ctx, "http", 8000); err != nil {
done <- fmt.Errorf("start sharing: %v", err)
}
if err := pf.Start(ctx); err != nil {
done <- err
}
done <- nil
}()
go func() {
var conn net.Conn
retries := 0
for conn == nil && retries < 2 {
conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
time.Sleep(1 * time.Second)
}
if conn == nil {
done <- errors.New("failed to connect to forwarded port")
}
b := make([]byte, len("stream-data"))
if _, err := conn.Read(b); err != nil && err != io.EOF {
done <- fmt.Errorf("reading stream: %v", err)
}
if string(b) != "stream-data" {
done <- fmt.Errorf("stream data is not expected value, got: %v", string(b))
}
if _, err := conn.Write([]byte("new-data")); err != nil {
done <- fmt.Errorf("writing to stream: %v", err)
}
done <- nil
}()
select {
case err := <-testServer.Err():
t.Errorf("error from server: %v", err)
case err := <-done:
if err != nil {
t.Errorf("error from client: %v", err)
}
}
}

View file

@ -7,12 +7,14 @@ import (
"strconv"
)
// A Server represents the liveshare host and container server
type Server struct {
client *Client
port int
streamName, streamCondition string
}
// NewServer creates a new Server with a given Client
func NewServer(client *Client) (*Server, error) {
if !client.hasJoined() {
return nil, errors.New("client must join before creating server")
@ -21,6 +23,7 @@ func NewServer(client *Client) (*Server, error) {
return &Server{client: client}, nil
}
// Port represents an open port on the container
type Port struct {
SourcePort int `json:"sourcePort"`
DestinationPort int `json:"destinationPort"`
@ -33,6 +36,7 @@ type Port struct {
HasTSLHandshakePassed bool `json:"hasTSLHandshakePassed"`
}
// StartSharing tells the liveshare host to start sharing the port from the container
func (s *Server) StartSharing(ctx context.Context, protocol string, port int) error {
s.port = port
@ -49,8 +53,10 @@ func (s *Server) StartSharing(ctx context.Context, protocol string, port int) er
return nil
}
// Ports is a slice of Port pointers
type Ports []*Port
// GetSharedServers returns a list of available/open ports from the container
func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) {
var response Ports
if err := s.client.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
@ -60,6 +66,8 @@ func (s *Server) GetSharedServers(ctx context.Context) (Ports, error) {
return response, nil
}
// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly
// via the Browse URL
func (s *Server) UpdateSharedVisibility(ctx context.Context, port int, public bool) error {
if err := s.client.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil {
return err

View file

@ -23,7 +23,7 @@ func TestNewServerWithNotJoinedClient(t *testing.T) {
}
}
func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) {
func makeMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Client, error) {
connection := Connection{
SessionID: "session-id",
SessionToken: "session-token",
@ -54,7 +54,7 @@ func newMockJoinedClient(opts ...livesharetest.ServerOption) (*livesharetest.Ser
}
func TestNewServer(t *testing.T) {
testServer, client, err := newMockJoinedClient()
testServer, client, err := makeMockJoinedClient()
defer testServer.Close()
if err != nil {
t.Errorf("error creating mock joined client: %v", err)
@ -95,7 +95,7 @@ func TestServerStartSharing(t *testing.T) {
}
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
}
testServer, client, err := newMockJoinedClient(
testServer, client, err := makeMockJoinedClient(
livesharetest.WithService("serverSharing.startSharing", startSharing),
)
defer testServer.Close()
@ -138,7 +138,7 @@ func TestServerGetSharedServers(t *testing.T) {
getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) {
return Ports{&sharedServer}, nil
}
testServer, client, err := newMockJoinedClient(
testServer, client, err := makeMockJoinedClient(
livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
)
if err != nil {
@ -206,7 +206,7 @@ func TestServerUpdateSharedVisibility(t *testing.T) {
}
return nil, nil
}
testServer, client, err := newMockJoinedClient(
testServer, client, err := makeMockJoinedClient(
livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility),
)
if err != nil {

View file

@ -3,11 +3,9 @@ package liveshare
import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"net/http"
"sync"
"time"
"github.com/gorilla/websocket"
@ -17,10 +15,8 @@ type socket struct {
addr string
tlsConfig *tls.Config
conn *websocket.Conn
readMutex sync.Mutex
writeMutex sync.Mutex
reader io.Reader
conn *websocket.Conn
reader io.Reader
}
func newSocket(clientConn Connection, tlsConfig *tls.Config) *socket {
@ -42,19 +38,12 @@ func (s *socket) connect(ctx context.Context) error {
}
func (s *socket) Read(b []byte) (int, error) {
s.readMutex.Lock()
defer s.readMutex.Unlock()
if s.reader == nil {
messageType, reader, err := s.conn.NextReader()
_, reader, err := s.conn.NextReader()
if err != nil {
return 0, err
}
if messageType != websocket.BinaryMessage {
return 0, errors.New("unexpected websocket message type")
}
s.reader = reader
}
@ -71,9 +60,6 @@ func (s *socket) Read(b []byte) (int, error) {
}
func (s *socket) Write(b []byte) (int, error) {
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage)
if err != nil {
return 0, err

View file

@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"sync"
"time"
@ -21,6 +22,7 @@ type Server struct {
password string
services map[string]RpcHandleFunc
relaySAS string
streams map[string]io.ReadWriter
sshConfig *ssh.ServerConfig
httptestServer *httptest.Server
@ -50,7 +52,7 @@ func NewServer(opts ...ServerOption) (*Server, error) {
server.sshConfig.AddHostKey(privateKey)
server.errCh = make(chan error)
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server)))
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
return server, nil
}
@ -81,6 +83,16 @@ func WithRelaySAS(sas string) ServerOption {
}
}
func WithStream(name string, stream io.ReadWriter) ServerOption {
return func(s *Server) error {
if s.streams == nil {
s.streams = make(map[string]io.ReadWriter)
}
s.streams[name] = stream
return nil
}
}
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if string(password) == serverPassword {
@ -104,7 +116,7 @@ func (s *Server) Err() <-chan error {
var upgrader = websocket.Upgrader{}
func newConnection(server *Server) http.HandlerFunc {
func makeConnection(server *Server) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
if server.relaySAS != "" {
// validate the sas key
@ -135,12 +147,48 @@ func newConnection(server *Server) http.HandlerFunc {
server.errCh <- fmt.Errorf("error accepting new channel: %v", err)
return
}
go ssh.DiscardRequests(reqs)
go handleNewRequests(server, ch, reqs)
go handleNewChannel(server, ch)
}
}
}
func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) {
for req := range reqs {
if req.WantReply {
if err := req.Reply(true, nil); err != nil {
server.errCh <- fmt.Errorf("error replying to channel request: %v", err)
}
}
if strings.HasPrefix(req.Type, "stream-transport") {
forwardStream(server, req.Type, channel)
}
}
}
func forwardStream(server *Server, streamName string, channel ssh.Channel) {
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
stream, found := server.streams[simpleStreamName]
if !found {
server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName)
return
}
copy := func(dst io.Writer, src io.Reader) {
if _, err := io.Copy(dst, src); err != nil {
fmt.Println(err)
server.errCh <- fmt.Errorf("io copy: %v", err)
return
}
}
go copy(stream, channel)
go copy(channel, stream)
for {
}
}
func handleNewChannel(server *Server, channel ssh.Channel) {
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server))