Pass conn to handlers instead of obj stream

This commit is contained in:
Jose Garcia 2022-03-14 10:29:31 -04:00
parent ca7e2d386d
commit ed376f3691
5 changed files with 41 additions and 48 deletions

View file

@ -154,7 +154,7 @@ type joinWorkspaceResult struct {
func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, eventResponses []string, portsData []liveshare.PortNotification) error {
t.Helper()
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
}
const sessionToken = "session-token"
@ -162,14 +162,14 @@ func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, ev
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
ch := make(chan float64, 1)
updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
ch := make(chan *jsonrpc2.Conn, 1)
updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
var req []interface{}
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
return nil, fmt.Errorf("unmarshal req: %w", err)
}
ch <- req[0].(float64)
ch <- conn
return nil, nil
}
testServer, err := livesharetest.NewServer(
@ -193,12 +193,9 @@ func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, ev
select {
case <-ctx.Done():
return
case <-ch:
case conn := <-ch:
pd := portsData[i]
_ = testServer.WriteToObjectStream(rpcMessage{
Method: eventResponses[i],
Params: pd.PortUpdate,
})
_, _ = conn.DispatchCall(context.Background(), eventResponses[i], pd.PortUpdate, nil)
}
}
}()

View file

@ -22,7 +22,7 @@ func TestConnect(t *testing.T) {
HostPublicKeys: []string{livesharetest.SSHPublicKey},
Logger: newMockLogger(),
}
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
var joinWorkspaceReq joinWorkspaceArgs
if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil {
return nil, fmt.Errorf("error unmarshaling req: %w", err)

View file

@ -26,18 +26,26 @@ func TestNewPortForwarder(t *testing.T) {
}
}
type portUpdateNotification struct {
PortUpdate
conn *jsonrpc2.Conn
}
func TestPortForwarderStart(t *testing.T) {
streamName, streamCondition := "stream-name", "stream-condition"
port := 8000
sendNotification := make(chan PortUpdate)
serverSharing := func(req *jsonrpc2.Request) (interface{}, error) {
sendNotification <- PortUpdate{
Port: int(port),
ChangeKind: PortChangeKindStart,
sendNotification := make(chan portUpdateNotification)
serverSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
sendNotification <- portUpdateNotification{
PortUpdate: PortUpdate{
Port: int(port),
ChangeKind: PortChangeKindStart,
},
conn: conn,
}
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
}
getStream := func(req *jsonrpc2.Request) (interface{}, error) {
getStream := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
return "stream-id", nil
}
@ -62,10 +70,8 @@ func TestPortForwarderStart(t *testing.T) {
defer cancel()
go func() {
_ = testServer.WriteToObjectStream(rpcPortTestMessage{
Method: "serverSharing.sharingSucceeded",
Params: <-sendNotification,
})
notif := <-sendNotification
_, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif.PortUpdate)
}()
done := make(chan error)

View file

@ -19,7 +19,7 @@ import (
const mockClientName = "liveshare-client"
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
}
const sessionToken = "session-token"
@ -56,8 +56,8 @@ type rpcPortTestMessage struct {
func TestServerStartSharing(t *testing.T) {
serverPort, serverProtocol := 2222, "sshd"
sendNotification := make(chan PortUpdate)
startSharing := func(req *jsonrpc2.Request) (interface{}, error) {
sendNotification := make(chan portUpdateNotification)
startSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
var args []interface{}
if err := json.Unmarshal(*req.Params, &args); err != nil {
return nil, fmt.Errorf("error unmarshaling request: %w", err)
@ -82,9 +82,12 @@ func TestServerStartSharing(t *testing.T) {
} else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) {
return nil, errors.New("browseURL does not match expected")
}
sendNotification <- PortUpdate{
Port: int(port),
ChangeKind: PortChangeKindStart,
sendNotification <- portUpdateNotification{
PortUpdate: PortUpdate{
Port: int(port),
ChangeKind: PortChangeKindStart,
},
conn: conn,
}
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
}
@ -99,10 +102,8 @@ func TestServerStartSharing(t *testing.T) {
ctx := context.Background()
go func() {
_ = testServer.WriteToObjectStream(rpcPortTestMessage{
Method: "serverSharing.sharingSucceeded",
Params: <-sendNotification,
})
notif := <-sendNotification
_, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif.PortUpdate)
}()
done := make(chan error)
@ -133,7 +134,7 @@ func TestServerGetSharedServers(t *testing.T) {
StreamName: "stream-name",
StreamCondition: "stream-condition",
}
getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) {
getSharedServers := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
return []*Port{&sharedServer}, nil
}
testServer, session, err := makeMockSession(
@ -176,7 +177,7 @@ func TestServerGetSharedServers(t *testing.T) {
}
func TestServerUpdateSharedServerPrivacy(t *testing.T) {
updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
var req []interface{}
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
return nil, fmt.Errorf("unmarshal req: %w", err)
@ -223,7 +224,7 @@ func TestServerUpdateSharedServerPrivacy(t *testing.T) {
}
func TestInvalidHostKey(t *testing.T) {
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
return joinWorkspaceResult{1}, nil
}
const sessionToken = "session-token"
@ -259,7 +260,7 @@ func TestKeepAliveNonBlocking(t *testing.T) {
}
func TestNotifyHostOfActivity(t *testing.T) {
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
notifyHostOfActivity := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
var req []interface{}
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
return nil, fmt.Errorf("unmarshal req: %w", err)
@ -318,7 +319,7 @@ func TestSessionHeartbeat(t *testing.T) {
wg sync.WaitGroup
)
wg.Add(1)
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
notifyHostOfActivity := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
defer wg.Done()
requestsMu.Lock()
requests++

View file

@ -43,8 +43,6 @@ type Server struct {
httptestServer *httptest.Server
errCh chan error
nonSecure bool
objectStream jsonrpc2.ObjectStream
}
// NewServer creates a new Server. ServerOptions can be passed to configure
@ -149,13 +147,6 @@ func (s *Server) Err() <-chan error {
return s.errCh
}
func (s *Server) WriteToObjectStream(obj interface{}) error {
if s.objectStream == nil {
return errors.New("object stream not set")
}
return s.objectStream.WriteObject(obj)
}
var upgrader = websocket.Upgrader{}
func makeConnection(server *Server) http.HandlerFunc {
@ -322,12 +313,10 @@ func forwardStream(ctx context.Context, server *Server, streamName string, chann
func handleChannel(server *Server, channel ssh.Channel) {
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
server.objectStream = stream
jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server))
}
type RPCHandleFunc func(req *jsonrpc2.Request) (interface{}, error)
type RPCHandleFunc func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error)
type rpcHandler struct {
server *Server
@ -346,7 +335,7 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr
return
}
result, err := handler(req)
result, err := handler(conn, req)
if err != nil {
sendError(r.server.errCh, fmt.Errorf("error handling: '%s': %w", req.Method, err))
return