Pass conn to handlers instead of obj stream
This commit is contained in:
parent
ca7e2d386d
commit
ed376f3691
5 changed files with 41 additions and 48 deletions
|
|
@ -154,7 +154,7 @@ type joinWorkspaceResult struct {
|
|||
func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, eventResponses []string, portsData []liveshare.PortNotification) error {
|
||||
t.Helper()
|
||||
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
const sessionToken = "session-token"
|
||||
|
|
@ -162,14 +162,14 @@ func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, ev
|
|||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
ch := make(chan float64, 1)
|
||||
updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
ch := make(chan *jsonrpc2.Conn, 1)
|
||||
updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
var req []interface{}
|
||||
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal req: %w", err)
|
||||
}
|
||||
|
||||
ch <- req[0].(float64)
|
||||
ch <- conn
|
||||
return nil, nil
|
||||
}
|
||||
testServer, err := livesharetest.NewServer(
|
||||
|
|
@ -193,12 +193,9 @@ func runUpdateVisibilityTest(t *testing.T, portVisibilities []portVisibility, ev
|
|||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ch:
|
||||
case conn := <-ch:
|
||||
pd := portsData[i]
|
||||
_ = testServer.WriteToObjectStream(rpcMessage{
|
||||
Method: eventResponses[i],
|
||||
Params: pd.PortUpdate,
|
||||
})
|
||||
_, _ = conn.DispatchCall(context.Background(), eventResponses[i], pd.PortUpdate, nil)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ func TestConnect(t *testing.T) {
|
|||
HostPublicKeys: []string{livesharetest.SSHPublicKey},
|
||||
Logger: newMockLogger(),
|
||||
}
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
var joinWorkspaceReq joinWorkspaceArgs
|
||||
if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling req: %w", err)
|
||||
|
|
|
|||
|
|
@ -26,18 +26,26 @@ func TestNewPortForwarder(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
type portUpdateNotification struct {
|
||||
PortUpdate
|
||||
conn *jsonrpc2.Conn
|
||||
}
|
||||
|
||||
func TestPortForwarderStart(t *testing.T) {
|
||||
streamName, streamCondition := "stream-name", "stream-condition"
|
||||
port := 8000
|
||||
sendNotification := make(chan PortUpdate)
|
||||
serverSharing := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
sendNotification <- PortUpdate{
|
||||
Port: int(port),
|
||||
ChangeKind: PortChangeKindStart,
|
||||
sendNotification := make(chan portUpdateNotification)
|
||||
serverSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
sendNotification <- portUpdateNotification{
|
||||
PortUpdate: PortUpdate{
|
||||
Port: int(port),
|
||||
ChangeKind: PortChangeKindStart,
|
||||
},
|
||||
conn: conn,
|
||||
}
|
||||
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
|
||||
}
|
||||
getStream := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
getStream := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
return "stream-id", nil
|
||||
}
|
||||
|
||||
|
|
@ -62,10 +70,8 @@ func TestPortForwarderStart(t *testing.T) {
|
|||
defer cancel()
|
||||
|
||||
go func() {
|
||||
_ = testServer.WriteToObjectStream(rpcPortTestMessage{
|
||||
Method: "serverSharing.sharingSucceeded",
|
||||
Params: <-sendNotification,
|
||||
})
|
||||
notif := <-sendNotification
|
||||
_, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif.PortUpdate)
|
||||
}()
|
||||
|
||||
done := make(chan error)
|
||||
|
|
|
|||
|
|
@ -19,7 +19,7 @@ import (
|
|||
const mockClientName = "liveshare-client"
|
||||
|
||||
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
const sessionToken = "session-token"
|
||||
|
|
@ -56,8 +56,8 @@ type rpcPortTestMessage struct {
|
|||
|
||||
func TestServerStartSharing(t *testing.T) {
|
||||
serverPort, serverProtocol := 2222, "sshd"
|
||||
sendNotification := make(chan PortUpdate)
|
||||
startSharing := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
sendNotification := make(chan portUpdateNotification)
|
||||
startSharing := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
var args []interface{}
|
||||
if err := json.Unmarshal(*req.Params, &args); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling request: %w", err)
|
||||
|
|
@ -82,9 +82,12 @@ func TestServerStartSharing(t *testing.T) {
|
|||
} else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) {
|
||||
return nil, errors.New("browseURL does not match expected")
|
||||
}
|
||||
sendNotification <- PortUpdate{
|
||||
Port: int(port),
|
||||
ChangeKind: PortChangeKindStart,
|
||||
sendNotification <- portUpdateNotification{
|
||||
PortUpdate: PortUpdate{
|
||||
Port: int(port),
|
||||
ChangeKind: PortChangeKindStart,
|
||||
},
|
||||
conn: conn,
|
||||
}
|
||||
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
|
||||
}
|
||||
|
|
@ -99,10 +102,8 @@ func TestServerStartSharing(t *testing.T) {
|
|||
ctx := context.Background()
|
||||
|
||||
go func() {
|
||||
_ = testServer.WriteToObjectStream(rpcPortTestMessage{
|
||||
Method: "serverSharing.sharingSucceeded",
|
||||
Params: <-sendNotification,
|
||||
})
|
||||
notif := <-sendNotification
|
||||
_, _ = notif.conn.DispatchCall(context.Background(), "serverSharing.sharingSucceeded", notif.PortUpdate)
|
||||
}()
|
||||
|
||||
done := make(chan error)
|
||||
|
|
@ -133,7 +134,7 @@ func TestServerGetSharedServers(t *testing.T) {
|
|||
StreamName: "stream-name",
|
||||
StreamCondition: "stream-condition",
|
||||
}
|
||||
getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
getSharedServers := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
return []*Port{&sharedServer}, nil
|
||||
}
|
||||
testServer, session, err := makeMockSession(
|
||||
|
|
@ -176,7 +177,7 @@ func TestServerGetSharedServers(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestServerUpdateSharedServerPrivacy(t *testing.T) {
|
||||
updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
updateSharedVisibility := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
var req []interface{}
|
||||
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal req: %w", err)
|
||||
|
|
@ -223,7 +224,7 @@ func TestServerUpdateSharedServerPrivacy(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestInvalidHostKey(t *testing.T) {
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
joinWorkspace := func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
const sessionToken = "session-token"
|
||||
|
|
@ -259,7 +260,7 @@ func TestKeepAliveNonBlocking(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestNotifyHostOfActivity(t *testing.T) {
|
||||
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
notifyHostOfActivity := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
var req []interface{}
|
||||
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal req: %w", err)
|
||||
|
|
@ -318,7 +319,7 @@ func TestSessionHeartbeat(t *testing.T) {
|
|||
wg sync.WaitGroup
|
||||
)
|
||||
wg.Add(1)
|
||||
notifyHostOfActivity := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
notifyHostOfActivity := func(conn *jsonrpc2.Conn, rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
defer wg.Done()
|
||||
requestsMu.Lock()
|
||||
requests++
|
||||
|
|
|
|||
|
|
@ -43,8 +43,6 @@ type Server struct {
|
|||
httptestServer *httptest.Server
|
||||
errCh chan error
|
||||
nonSecure bool
|
||||
|
||||
objectStream jsonrpc2.ObjectStream
|
||||
}
|
||||
|
||||
// NewServer creates a new Server. ServerOptions can be passed to configure
|
||||
|
|
@ -149,13 +147,6 @@ func (s *Server) Err() <-chan error {
|
|||
return s.errCh
|
||||
}
|
||||
|
||||
func (s *Server) WriteToObjectStream(obj interface{}) error {
|
||||
if s.objectStream == nil {
|
||||
return errors.New("object stream not set")
|
||||
}
|
||||
return s.objectStream.WriteObject(obj)
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
func makeConnection(server *Server) http.HandlerFunc {
|
||||
|
|
@ -322,12 +313,10 @@ func forwardStream(ctx context.Context, server *Server, streamName string, chann
|
|||
|
||||
func handleChannel(server *Server, channel ssh.Channel) {
|
||||
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
|
||||
server.objectStream = stream
|
||||
|
||||
jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server))
|
||||
}
|
||||
|
||||
type RPCHandleFunc func(req *jsonrpc2.Request) (interface{}, error)
|
||||
type RPCHandleFunc func(conn *jsonrpc2.Conn, req *jsonrpc2.Request) (interface{}, error)
|
||||
|
||||
type rpcHandler struct {
|
||||
server *Server
|
||||
|
|
@ -346,7 +335,7 @@ func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonr
|
|||
return
|
||||
}
|
||||
|
||||
result, err := handler(req)
|
||||
result, err := handler(conn, req)
|
||||
if err != nil {
|
||||
sendError(r.server.errCh, fmt.Errorf("error handling: '%s': %w", req.Method, err))
|
||||
return
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue