Merge pull request #189 from github/jg/inline-go-liveshare
Inline go-liveshare v0.20.0
This commit is contained in:
commit
4eb15134a4
17 changed files with 1364 additions and 5 deletions
|
|
@ -9,7 +9,7 @@ import (
|
|||
"github.com/github/ghcs/cmd/ghcs/output"
|
||||
"github.com/github/ghcs/internal/api"
|
||||
"github.com/github/ghcs/internal/codespaces"
|
||||
"github.com/github/go-liveshare"
|
||||
"github.com/github/ghcs/internal/liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -14,7 +14,7 @@ import (
|
|||
"github.com/github/ghcs/cmd/ghcs/output"
|
||||
"github.com/github/ghcs/internal/api"
|
||||
"github.com/github/ghcs/internal/codespaces"
|
||||
"github.com/github/go-liveshare"
|
||||
"github.com/github/ghcs/internal/liveshare"
|
||||
"github.com/muhammadmuzzammil1998/jsonc"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/sync/errgroup"
|
||||
|
|
|
|||
|
|
@ -9,7 +9,7 @@ import (
|
|||
"github.com/github/ghcs/cmd/ghcs/output"
|
||||
"github.com/github/ghcs/internal/api"
|
||||
"github.com/github/ghcs/internal/codespaces"
|
||||
"github.com/github/go-liveshare"
|
||||
"github.com/github/ghcs/internal/liveshare"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/github/ghcs/internal/api"
|
||||
"github.com/github/go-liveshare"
|
||||
"github.com/github/ghcs/internal/liveshare"
|
||||
)
|
||||
|
||||
type logger interface {
|
||||
|
|
|
|||
|
|
@ -10,7 +10,7 @@ import (
|
|||
"time"
|
||||
|
||||
"github.com/github/ghcs/internal/api"
|
||||
"github.com/github/go-liveshare"
|
||||
"github.com/github/ghcs/internal/liveshare"
|
||||
)
|
||||
|
||||
// PostCreateStateStatus is a string value representing the different statuses a state can have.
|
||||
|
|
|
|||
149
internal/liveshare/client.go
Normal file
149
internal/liveshare/client.go
Normal file
|
|
@ -0,0 +1,149 @@
|
|||
// Package liveshare is a Go client library for the Visual Studio Live Share
|
||||
// service, which provides collaborative, distibuted editing and debugging.
|
||||
// See https://docs.microsoft.com/en-us/visualstudio/liveshare for an overview.
|
||||
//
|
||||
// It provides the ability for a Go program to connect to a Live Share
|
||||
// workspace (Connect), to expose a TCP port on a remote host
|
||||
// (UpdateSharedVisibility), to start an SSH server listening on an
|
||||
// exposed port (StartSSHServer), and to forward connections between
|
||||
// the remote port and a local listening TCP port (ForwardToListener)
|
||||
// or a local Go reader/writer (Forward).
|
||||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// An Options specifies Live Share connection parameters.
|
||||
type Options struct {
|
||||
SessionID string
|
||||
SessionToken string // token for SSH session
|
||||
RelaySAS string
|
||||
RelayEndpoint string
|
||||
TLSConfig *tls.Config // (optional)
|
||||
}
|
||||
|
||||
// uri returns a websocket URL for the specified options.
|
||||
func (opts *Options) uri(action string) (string, error) {
|
||||
if opts.SessionID == "" {
|
||||
return "", errors.New("SessionID is required")
|
||||
}
|
||||
if opts.RelaySAS == "" {
|
||||
return "", errors.New("RelaySAS is required")
|
||||
}
|
||||
if opts.RelayEndpoint == "" {
|
||||
return "", errors.New("RelayEndpoint is required")
|
||||
}
|
||||
|
||||
sas := url.QueryEscape(opts.RelaySAS)
|
||||
uri := opts.RelayEndpoint
|
||||
uri = strings.Replace(uri, "sb:", "wss:", -1)
|
||||
uri = strings.Replace(uri, ".net/", ".net:443/$hc/", 1)
|
||||
uri = uri + "?sb-hc-action=" + action + "&sb-hc-token=" + sas
|
||||
return uri, nil
|
||||
}
|
||||
|
||||
// Connect connects to a Live Share workspace specified by the
|
||||
// options, and returns a session representing the connection.
|
||||
// The caller must call the session's Close method to end the session.
|
||||
func Connect(ctx context.Context, opts Options) (*Session, error) {
|
||||
uri, err := opts.uri("connect")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Connect")
|
||||
defer span.Finish()
|
||||
|
||||
sock := newSocket(uri, opts.TLSConfig)
|
||||
if err := sock.connect(ctx); err != nil {
|
||||
return nil, fmt.Errorf("error connecting websocket: %w", err)
|
||||
}
|
||||
|
||||
if opts.SessionToken == "" {
|
||||
return nil, errors.New("SessionToken is required")
|
||||
}
|
||||
ssh := newSSHSession(opts.SessionToken, sock)
|
||||
if err := ssh.connect(ctx); err != nil {
|
||||
return nil, fmt.Errorf("error connecting to ssh session: %w", err)
|
||||
}
|
||||
|
||||
rpc := newRPCClient(ssh)
|
||||
rpc.connect(ctx)
|
||||
|
||||
args := joinWorkspaceArgs{
|
||||
ID: opts.SessionID,
|
||||
ConnectionMode: "local",
|
||||
JoiningUserSessionToken: opts.SessionToken,
|
||||
ClientCapabilities: clientCapabilities{
|
||||
IsNonInteractive: false,
|
||||
},
|
||||
}
|
||||
var result joinWorkspaceResult
|
||||
if err := rpc.do(ctx, "workspace.joinWorkspace", &args, &result); err != nil {
|
||||
return nil, fmt.Errorf("error joining Live Share workspace: %w", err)
|
||||
}
|
||||
|
||||
return &Session{ssh: ssh, rpc: rpc}, nil
|
||||
}
|
||||
|
||||
type clientCapabilities struct {
|
||||
IsNonInteractive bool `json:"isNonInteractive"`
|
||||
}
|
||||
|
||||
type joinWorkspaceArgs struct {
|
||||
ID string `json:"id"`
|
||||
ConnectionMode string `json:"connectionMode"`
|
||||
JoiningUserSessionToken string `json:"joiningUserSessionToken"`
|
||||
ClientCapabilities clientCapabilities `json:"clientCapabilities"`
|
||||
}
|
||||
|
||||
type joinWorkspaceResult struct {
|
||||
SessionNumber int `json:"sessionNumber"`
|
||||
}
|
||||
|
||||
// A channelID is an identifier for an exposed port on a remote
|
||||
// container that may be used to open an SSH channel to it.
|
||||
type channelID struct {
|
||||
name, condition string
|
||||
}
|
||||
|
||||
func (s *Session) openStreamingChannel(ctx context.Context, id channelID) (ssh.Channel, error) {
|
||||
type getStreamArgs struct {
|
||||
StreamName string `json:"streamName"`
|
||||
Condition string `json:"condition"`
|
||||
}
|
||||
args := getStreamArgs{
|
||||
StreamName: id.name,
|
||||
Condition: id.condition,
|
||||
}
|
||||
var streamID string
|
||||
if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
|
||||
return nil, fmt.Errorf("error getting stream id: %w", err)
|
||||
}
|
||||
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
|
||||
defer span.Finish()
|
||||
_ = ctx // ctx is not currently used
|
||||
|
||||
channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
requestType := fmt.Sprintf("stream-transport-%s", streamID)
|
||||
if _, err = channel.SendRequest(requestType, true, nil); err != nil {
|
||||
return nil, fmt.Errorf("error sending channel request: %w", err)
|
||||
}
|
||||
|
||||
return channel, nil
|
||||
}
|
||||
71
internal/liveshare/client_test.go
Normal file
71
internal/liveshare/client_test.go
Normal file
|
|
@ -0,0 +1,71 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
livesharetest "github.com/github/ghcs/internal/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestConnect(t *testing.T) {
|
||||
opts := Options{
|
||||
SessionID: "session-id",
|
||||
SessionToken: "session-token",
|
||||
RelaySAS: "relay-sas",
|
||||
}
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
var joinWorkspaceReq joinWorkspaceArgs
|
||||
if err := json.Unmarshal(*req.Params, &joinWorkspaceReq); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling req: %w", err)
|
||||
}
|
||||
if joinWorkspaceReq.ID != opts.SessionID {
|
||||
return nil, errors.New("connection session id does not match")
|
||||
}
|
||||
if joinWorkspaceReq.ConnectionMode != "local" {
|
||||
return nil, errors.New("connection mode is not local")
|
||||
}
|
||||
if joinWorkspaceReq.JoiningUserSessionToken != opts.SessionToken {
|
||||
return nil, errors.New("connection user token does not match")
|
||||
}
|
||||
if joinWorkspaceReq.ClientCapabilities.IsNonInteractive != false {
|
||||
return nil, errors.New("non interactive is not false")
|
||||
}
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
|
||||
server, err := livesharetest.NewServer(
|
||||
livesharetest.WithPassword(opts.SessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
livesharetest.WithRelaySAS(opts.RelaySAS),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("error creating Live Share server: %w", err)
|
||||
}
|
||||
defer server.Close()
|
||||
opts.RelayEndpoint = "sb" + strings.TrimPrefix(server.URL(), "https")
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
opts.TLSConfig = &tls.Config{InsecureSkipVerify: true}
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
_, err := Connect(ctx, opts) // ignore session
|
||||
done <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-server.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
56
internal/liveshare/options_test.go
Normal file
56
internal/liveshare/options_test.go
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBadOptions(t *testing.T) {
|
||||
goodOptions := Options{
|
||||
SessionID: "sess-id",
|
||||
SessionToken: "sess-token",
|
||||
RelaySAS: "sas",
|
||||
RelayEndpoint: "endpoint",
|
||||
}
|
||||
|
||||
opts := goodOptions
|
||||
opts.SessionID = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = goodOptions
|
||||
opts.SessionToken = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = goodOptions
|
||||
opts.RelaySAS = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = goodOptions
|
||||
opts.RelayEndpoint = ""
|
||||
checkBadOptions(t, opts)
|
||||
|
||||
opts = Options{}
|
||||
checkBadOptions(t, opts)
|
||||
}
|
||||
|
||||
func checkBadOptions(t *testing.T, opts Options) {
|
||||
if _, err := Connect(context.Background(), opts); err == nil {
|
||||
t.Errorf("Connect(%+v): no error", opts)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOptionsURI(t *testing.T) {
|
||||
opts := Options{
|
||||
SessionID: "sess-id",
|
||||
SessionToken: "sess-token",
|
||||
RelaySAS: "sas",
|
||||
RelayEndpoint: "sb://endpoint/.net/liveshare",
|
||||
}
|
||||
uri, err := opts.uri("connect")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if uri != "wss://endpoint/.net:443/$hc/liveshare?sb-hc-action=connect&sb-hc-token=sas" {
|
||||
t.Errorf("uri is not correct, got: '%v'", uri)
|
||||
}
|
||||
}
|
||||
162
internal/liveshare/port_forwarder.go
Normal file
162
internal/liveshare/port_forwarder.go
Normal file
|
|
@ -0,0 +1,162 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
)
|
||||
|
||||
// A PortForwarder forwards TCP traffic over a Live Share session from a port on a remote
|
||||
// container to a local destination such as a network port or Go reader/writer.
|
||||
type PortForwarder struct {
|
||||
session *Session
|
||||
name string
|
||||
remotePort int
|
||||
}
|
||||
|
||||
// NewPortForwarder returns a new PortForwarder for the specified
|
||||
// remote port and Live Share session. The name describes the purpose
|
||||
// of the remote port or service.
|
||||
func NewPortForwarder(session *Session, name string, remotePort int) *PortForwarder {
|
||||
return &PortForwarder{
|
||||
session: session,
|
||||
name: name,
|
||||
remotePort: remotePort,
|
||||
}
|
||||
}
|
||||
|
||||
// ForwardToListener forwards traffic between the container's remote
|
||||
// port and a local port, which must already be listening for
|
||||
// connections. (Accepting a listener rather than a port number avoids
|
||||
// races against other processes opening ports, and against a client
|
||||
// connecting to the socket prematurely.)
|
||||
//
|
||||
// ForwardToListener accepts and handles connections on the local port
|
||||
// until it encounters the first error, which may include context
|
||||
// cancellation. Its error result is always non-nil. The caller is
|
||||
// responsible for closing the listening port.
|
||||
func (fwd *PortForwarder) ForwardToListener(ctx context.Context, listen net.Listener) (err error) {
|
||||
id, err := fwd.shareRemotePort(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
errc := make(chan error, 1)
|
||||
sendError := func(err error) {
|
||||
// Use non-blocking send, to avoid goroutines getting
|
||||
// stuck in case of concurrent or sequential errors.
|
||||
select {
|
||||
case errc <- err:
|
||||
default:
|
||||
}
|
||||
}
|
||||
go func() {
|
||||
for {
|
||||
conn, err := listen.Accept()
|
||||
if err != nil {
|
||||
sendError(err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
if err := fwd.handleConnection(ctx, id, conn); err != nil {
|
||||
sendError(err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}()
|
||||
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
// Forward forwards traffic between the container's remote port and
|
||||
// the specified read/write stream. On return, the stream is closed.
|
||||
func (fwd *PortForwarder) Forward(ctx context.Context, conn io.ReadWriteCloser) error {
|
||||
id, err := fwd.shareRemotePort(ctx)
|
||||
if err != nil {
|
||||
conn.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
// Create buffered channel so that send doesn't get stuck after context cancellation.
|
||||
errc := make(chan error, 1)
|
||||
go func() {
|
||||
errc <- fwd.handleConnection(ctx, id, conn)
|
||||
}()
|
||||
return awaitError(ctx, errc)
|
||||
}
|
||||
|
||||
func (fwd *PortForwarder) shareRemotePort(ctx context.Context) (channelID, error) {
|
||||
id, err := fwd.session.startSharing(ctx, fwd.name, fwd.remotePort)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to share remote port %d: %w", fwd.remotePort, err)
|
||||
}
|
||||
return id, err
|
||||
}
|
||||
|
||||
func awaitError(ctx context.Context, errc <-chan error) error {
|
||||
select {
|
||||
case err := <-errc:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return ctx.Err() // canceled
|
||||
}
|
||||
}
|
||||
|
||||
// handleConnection handles forwarding for a single accepted connection, then closes it.
|
||||
func (fwd *PortForwarder) handleConnection(ctx context.Context, id channelID, conn io.ReadWriteCloser) (err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "PortForwarder.handleConnection")
|
||||
defer span.Finish()
|
||||
|
||||
defer safeClose(conn, &err)
|
||||
|
||||
channel, err := fwd.session.openStreamingChannel(ctx, id)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error opening streaming channel for new connection: %w", err)
|
||||
}
|
||||
// Ideally we would call safeClose again, but (*ssh.channel).Close
|
||||
// appears to have a bug that causes it return io.EOF spuriously
|
||||
// if its peer closed first; see github.com/golang/go/issues/38115.
|
||||
defer func() {
|
||||
closeErr := channel.Close()
|
||||
if err == nil && closeErr != io.EOF {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
|
||||
// bi-directional copy of data.
|
||||
errs := make(chan error, 2)
|
||||
copyConn := func(w io.Writer, r io.Reader) {
|
||||
_, err := io.Copy(w, r)
|
||||
errs <- err
|
||||
}
|
||||
go copyConn(conn, channel)
|
||||
go copyConn(channel, conn)
|
||||
|
||||
// Wait until context is cancelled or both copies are done.
|
||||
// Discard errors from io.Copy; they should not cause (e.g.) ForwardToListener to fail.
|
||||
// TODO: how can we proxy errors from Copy so that each peer can distinguish an error from a short file?
|
||||
for i := 0; ; {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-errs:
|
||||
i++
|
||||
if i == 2 {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// safeClose reports the error (to *err) from closing the stream only
|
||||
// if no other error was previously reported.
|
||||
func safeClose(closer io.Closer, err *error) {
|
||||
closeErr := closer.Close()
|
||||
if *err == nil {
|
||||
*err = closeErr
|
||||
}
|
||||
}
|
||||
95
internal/liveshare/port_forwarder_test.go
Normal file
95
internal/liveshare/port_forwarder_test.go
Normal file
|
|
@ -0,0 +1,95 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
livesharetest "github.com/github/ghcs/internal/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func TestNewPortForwarder(t *testing.T) {
|
||||
testServer, session, err := makeMockSession()
|
||||
if err != nil {
|
||||
t.Errorf("create mock client: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
pf := NewPortForwarder(session, "ssh", 80)
|
||||
if pf == nil {
|
||||
t.Error("port forwarder is nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestPortForwarderStart(t *testing.T) {
|
||||
streamName, streamCondition := "stream-name", "stream-condition"
|
||||
serverSharing := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return Port{StreamName: streamName, StreamCondition: streamCondition}, nil
|
||||
}
|
||||
getStream := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return "stream-id", nil
|
||||
}
|
||||
|
||||
stream := bytes.NewBufferString("stream-data")
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.startSharing", serverSharing),
|
||||
livesharetest.WithService("streamManager.getStream", getStream),
|
||||
livesharetest.WithStream("stream-id", stream),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("create mock session: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
|
||||
listen, err := net.Listen("tcp", ":8000")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer listen.Close()
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
const name, remote = "ssh", 8000
|
||||
done <- NewPortForwarder(session, name, remote).ForwardToListener(ctx, listen)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
var conn net.Conn
|
||||
retries := 0
|
||||
for conn == nil && retries < 2 {
|
||||
conn, err = net.DialTimeout("tcp", ":8000", 2*time.Second)
|
||||
time.Sleep(1 * time.Second)
|
||||
}
|
||||
if conn == nil {
|
||||
done <- errors.New("failed to connect to forwarded port")
|
||||
}
|
||||
b := make([]byte, len("stream-data"))
|
||||
if _, err := conn.Read(b); err != nil && err != io.EOF {
|
||||
done <- fmt.Errorf("reading stream: %w", err)
|
||||
}
|
||||
if string(b) != "stream-data" {
|
||||
done <- fmt.Errorf("stream data is not expected value, got: %s", string(b))
|
||||
}
|
||||
if _, err := conn.Write([]byte("new-data")); err != nil {
|
||||
done <- fmt.Errorf("writing to stream: %w", err)
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
41
internal/liveshare/rpc.go
Normal file
41
internal/liveshare/rpc.go
Normal file
|
|
@ -0,0 +1,41 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/opentracing/opentracing-go"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
type rpcClient struct {
|
||||
*jsonrpc2.Conn
|
||||
conn io.ReadWriteCloser
|
||||
}
|
||||
|
||||
func newRPCClient(conn io.ReadWriteCloser) *rpcClient {
|
||||
return &rpcClient{conn: conn}
|
||||
}
|
||||
|
||||
func (r *rpcClient) connect(ctx context.Context) {
|
||||
stream := jsonrpc2.NewBufferedStream(r.conn, jsonrpc2.VSCodeObjectCodec{})
|
||||
r.Conn = jsonrpc2.NewConn(ctx, stream, nullHandler{})
|
||||
}
|
||||
|
||||
func (r *rpcClient) do(ctx context.Context, method string, args, result interface{}) error {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, method)
|
||||
defer span.Finish()
|
||||
|
||||
waiter, err := r.Conn.DispatchCall(ctx, method, args)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error dispatching %q call: %w", method, err)
|
||||
}
|
||||
|
||||
return waiter.Wait(ctx, result)
|
||||
}
|
||||
|
||||
type nullHandler struct{}
|
||||
|
||||
func (nullHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
|
||||
}
|
||||
99
internal/liveshare/session.go
Normal file
99
internal/liveshare/session.go
Normal file
|
|
@ -0,0 +1,99 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strconv"
|
||||
)
|
||||
|
||||
// A Session represents the session between a connected Live Share client and server.
|
||||
type Session struct {
|
||||
ssh *sshSession
|
||||
rpc *rpcClient
|
||||
}
|
||||
|
||||
// Close should be called by users to clean up RPC and SSH resources whenever the session
|
||||
// is no longer active.
|
||||
func (s *Session) Close() error {
|
||||
// Closing the RPC conn closes the underlying stream (SSH)
|
||||
// So we only need to close once
|
||||
if err := s.rpc.Close(); err != nil {
|
||||
s.ssh.Close() // close SSH and ignore error
|
||||
return fmt.Errorf("error while closing Live Share session: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Port describes a port exposed by the container.
|
||||
type Port struct {
|
||||
SourcePort int `json:"sourcePort"`
|
||||
DestinationPort int `json:"destinationPort"`
|
||||
SessionName string `json:"sessionName"`
|
||||
StreamName string `json:"streamName"`
|
||||
StreamCondition string `json:"streamCondition"`
|
||||
BrowseURL string `json:"browseUrl"`
|
||||
IsPublic bool `json:"isPublic"`
|
||||
IsTCPServerConnectionEstablished bool `json:"isTCPServerConnectionEstablished"`
|
||||
HasTLSHandshakePassed bool `json:"hasTLSHandshakePassed"`
|
||||
}
|
||||
|
||||
// startSharing tells the Live Share host to start sharing the specified port from the container.
|
||||
// The sessionName describes the purpose of the remote port or service.
|
||||
// It returns an identifier that can be used to open an SSH channel to the remote port.
|
||||
func (s *Session) startSharing(ctx context.Context, sessionName string, port int) (channelID, error) {
|
||||
args := []interface{}{port, sessionName, fmt.Sprintf("http://localhost:%d", port)}
|
||||
var response Port
|
||||
if err := s.rpc.do(ctx, "serverSharing.startSharing", args, &response); err != nil {
|
||||
return channelID{}, err
|
||||
}
|
||||
|
||||
return channelID{response.StreamName, response.StreamCondition}, nil
|
||||
}
|
||||
|
||||
// GetSharedServers returns a description of each container port
|
||||
// shared by a prior call to StartSharing by some client.
|
||||
func (s *Session) GetSharedServers(ctx context.Context) ([]*Port, error) {
|
||||
var response []*Port
|
||||
if err := s.rpc.do(ctx, "serverSharing.getSharedServers", []string{}, &response); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return response, nil
|
||||
}
|
||||
|
||||
// UpdateSharedVisibility controls port permissions and whether it can be accessed publicly
|
||||
// via the Browse URL
|
||||
func (s *Session) UpdateSharedVisibility(ctx context.Context, port int, public bool) error {
|
||||
if err := s.rpc.do(ctx, "serverSharing.updateSharedServerVisibility", []interface{}{port, public}, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// StartsSSHServer starts an SSH server in the container, installing sshd if necessary,
|
||||
// and returns the port on which it listens and the user name clients should provide.
|
||||
func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) {
|
||||
var response struct {
|
||||
Result bool `json:"result"`
|
||||
ServerPort string `json:"serverPort"`
|
||||
User string `json:"user"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServer", []string{}, &response); err != nil {
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
if !response.Result {
|
||||
return 0, "", fmt.Errorf("failed to start server: %s", response.Message)
|
||||
}
|
||||
|
||||
port, err := strconv.Atoi(response.ServerPort)
|
||||
if err != nil {
|
||||
return 0, "", fmt.Errorf("failed to parse port: %w", err)
|
||||
}
|
||||
|
||||
return port, response.User, nil
|
||||
}
|
||||
196
internal/liveshare/session_test.go
Normal file
196
internal/liveshare/session_test.go
Normal file
|
|
@ -0,0 +1,196 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
livesharetest "github.com/github/ghcs/internal/liveshare/test"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
)
|
||||
|
||||
func makeMockSession(opts ...livesharetest.ServerOption) (*livesharetest.Server, *Session, error) {
|
||||
joinWorkspace := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return joinWorkspaceResult{1}, nil
|
||||
}
|
||||
const sessionToken = "session-token"
|
||||
opts = append(
|
||||
opts,
|
||||
livesharetest.WithPassword(sessionToken),
|
||||
livesharetest.WithService("workspace.joinWorkspace", joinWorkspace),
|
||||
)
|
||||
testServer, err := livesharetest.NewServer(opts...)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error creating server: %w", err)
|
||||
}
|
||||
|
||||
session, err := Connect(context.Background(), Options{
|
||||
SessionID: "session-id",
|
||||
SessionToken: sessionToken,
|
||||
RelayEndpoint: "sb" + strings.TrimPrefix(testServer.URL(), "https"),
|
||||
RelaySAS: "relay-sas",
|
||||
TLSConfig: &tls.Config{InsecureSkipVerify: true},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("error connecting to Live Share: %w", err)
|
||||
}
|
||||
return testServer, session, nil
|
||||
}
|
||||
|
||||
func TestServerStartSharing(t *testing.T) {
|
||||
serverPort, serverProtocol := 2222, "sshd"
|
||||
startSharing := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
var args []interface{}
|
||||
if err := json.Unmarshal(*req.Params, &args); err != nil {
|
||||
return nil, fmt.Errorf("error unmarshaling request: %w", err)
|
||||
}
|
||||
if len(args) < 3 {
|
||||
return nil, errors.New("not enough arguments to start sharing")
|
||||
}
|
||||
if port, ok := args[0].(float64); !ok {
|
||||
return nil, errors.New("port argument is not an int")
|
||||
} else if port != float64(serverPort) {
|
||||
return nil, errors.New("port does not match serverPort")
|
||||
}
|
||||
if protocol, ok := args[1].(string); !ok {
|
||||
return nil, errors.New("protocol argument is not a string")
|
||||
} else if protocol != serverProtocol {
|
||||
return nil, errors.New("protocol does not match serverProtocol")
|
||||
}
|
||||
if browseURL, ok := args[2].(string); !ok {
|
||||
return nil, errors.New("browse url is not a string")
|
||||
} else if browseURL != fmt.Sprintf("http://localhost:%d", serverPort) {
|
||||
return nil, errors.New("browseURL does not match expected")
|
||||
}
|
||||
return Port{StreamName: "stream-name", StreamCondition: "stream-condition"}, nil
|
||||
}
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.startSharing", startSharing),
|
||||
)
|
||||
defer testServer.Close() //nolint:staticcheck // httptest.Server does not return errors on Close()
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("error creating mock session: %w", err)
|
||||
}
|
||||
ctx := context.Background()
|
||||
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
streamID, err := session.startSharing(ctx, serverProtocol, serverPort)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("error sharing server: %w", err)
|
||||
}
|
||||
if streamID.name == "" || streamID.condition == "" {
|
||||
done <- errors.New("stream name or condition is blank")
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerGetSharedServers(t *testing.T) {
|
||||
sharedServer := Port{
|
||||
SourcePort: 2222,
|
||||
StreamName: "stream-name",
|
||||
StreamCondition: "stream-condition",
|
||||
}
|
||||
getSharedServers := func(req *jsonrpc2.Request) (interface{}, error) {
|
||||
return []*Port{&sharedServer}, nil
|
||||
}
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.getSharedServers", getSharedServers),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("error creating mock session: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
ctx := context.Background()
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
ports, err := session.GetSharedServers(ctx)
|
||||
if err != nil {
|
||||
done <- fmt.Errorf("error getting shared servers: %w", err)
|
||||
}
|
||||
if len(ports) < 1 {
|
||||
done <- errors.New("not enough ports returned")
|
||||
}
|
||||
if ports[0].SourcePort != sharedServer.SourcePort {
|
||||
done <- errors.New("source port does not match")
|
||||
}
|
||||
if ports[0].StreamName != sharedServer.StreamName {
|
||||
done <- errors.New("stream name does not match")
|
||||
}
|
||||
if ports[0].StreamCondition != sharedServer.StreamCondition {
|
||||
done <- errors.New("stream condiion does not match")
|
||||
}
|
||||
done <- nil
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestServerUpdateSharedVisibility(t *testing.T) {
|
||||
updateSharedVisibility := func(rpcReq *jsonrpc2.Request) (interface{}, error) {
|
||||
var req []interface{}
|
||||
if err := json.Unmarshal(*rpcReq.Params, &req); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal req: %w", err)
|
||||
}
|
||||
if len(req) < 2 {
|
||||
return nil, errors.New("request arguments is less than 2")
|
||||
}
|
||||
if port, ok := req[0].(float64); ok {
|
||||
if port != 80.0 {
|
||||
return nil, errors.New("port param is not expected value")
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("port param is not a float64")
|
||||
}
|
||||
if public, ok := req[1].(bool); ok {
|
||||
if public != true {
|
||||
return nil, errors.New("pulic param is not expected value")
|
||||
}
|
||||
} else {
|
||||
return nil, errors.New("public param is not a bool")
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
testServer, session, err := makeMockSession(
|
||||
livesharetest.WithService("serverSharing.updateSharedServerVisibility", updateSharedVisibility),
|
||||
)
|
||||
if err != nil {
|
||||
t.Errorf("creating mock session: %w", err)
|
||||
}
|
||||
defer testServer.Close()
|
||||
ctx := context.Background()
|
||||
done := make(chan error)
|
||||
go func() {
|
||||
done <- session.UpdateSharedVisibility(ctx, 80, true)
|
||||
}()
|
||||
select {
|
||||
case err := <-testServer.Err():
|
||||
t.Errorf("error from server: %w", err)
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
t.Errorf("error from client: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
100
internal/liveshare/socket.go
Normal file
100
internal/liveshare/socket.go
Normal file
|
|
@ -0,0 +1,100 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type socket struct {
|
||||
addr string
|
||||
tlsConfig *tls.Config
|
||||
|
||||
conn *websocket.Conn
|
||||
reader io.Reader
|
||||
}
|
||||
|
||||
func newSocket(uri string, tlsConfig *tls.Config) *socket {
|
||||
return &socket{addr: uri, tlsConfig: tlsConfig}
|
||||
}
|
||||
|
||||
func (s *socket) connect(ctx context.Context) error {
|
||||
dialer := websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: 45 * time.Second,
|
||||
TLSClientConfig: s.tlsConfig,
|
||||
}
|
||||
ws, _, err := dialer.Dial(s.addr, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.conn = ws
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *socket) Read(b []byte) (int, error) {
|
||||
if s.reader == nil {
|
||||
_, reader, err := s.conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
s.reader = reader
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socket) Write(b []byte) (int, error) {
|
||||
nextWriter, err := s.conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
bytesWritten, err := nextWriter.Write(b)
|
||||
nextWriter.Close()
|
||||
|
||||
return bytesWritten, err
|
||||
}
|
||||
|
||||
func (s *socket) Close() error {
|
||||
return s.conn.Close()
|
||||
}
|
||||
|
||||
func (s *socket) LocalAddr() net.Addr {
|
||||
return s.conn.LocalAddr()
|
||||
}
|
||||
|
||||
func (s *socket) RemoteAddr() net.Addr {
|
||||
return s.conn.RemoteAddr()
|
||||
}
|
||||
|
||||
func (s *socket) SetDeadline(t time.Time) error {
|
||||
if err := s.SetReadDeadline(t); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return s.SetWriteDeadline(t)
|
||||
}
|
||||
|
||||
func (s *socket) SetReadDeadline(t time.Time) error {
|
||||
return s.conn.SetReadDeadline(t)
|
||||
}
|
||||
|
||||
func (s *socket) SetWriteDeadline(t time.Time) error {
|
||||
return s.conn.SetWriteDeadline(t)
|
||||
}
|
||||
68
internal/liveshare/ssh.go
Normal file
68
internal/liveshare/ssh.go
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
package liveshare
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type sshSession struct {
|
||||
*ssh.Session
|
||||
token string
|
||||
socket net.Conn
|
||||
conn ssh.Conn
|
||||
reader io.Reader
|
||||
writer io.Writer
|
||||
}
|
||||
|
||||
func newSSHSession(token string, socket net.Conn) *sshSession {
|
||||
return &sshSession{token: token, socket: socket}
|
||||
}
|
||||
|
||||
func (s *sshSession) connect(ctx context.Context) error {
|
||||
clientConfig := ssh.ClientConfig{
|
||||
User: "",
|
||||
Auth: []ssh.AuthMethod{
|
||||
ssh.Password(s.token),
|
||||
},
|
||||
HostKeyAlgorithms: []string{"rsa-sha2-512", "rsa-sha2-256"},
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
Timeout: 10 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := ssh.NewClientConn(s.socket, "", &clientConfig)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh client connection: %w", err)
|
||||
}
|
||||
s.conn = sshClientConn
|
||||
|
||||
sshClient := ssh.NewClient(sshClientConn, chans, reqs)
|
||||
s.Session, err = sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh client session: %w", err)
|
||||
}
|
||||
|
||||
s.reader, err = s.Session.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh session reader: %w", err)
|
||||
}
|
||||
|
||||
s.writer, err = s.Session.StdinPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating ssh session writer: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *sshSession) Read(p []byte) (n int, err error) {
|
||||
return s.reader.Read(p)
|
||||
}
|
||||
|
||||
func (s *sshSession) Write(p []byte) (n int, err error) {
|
||||
return s.writer.Write(p)
|
||||
}
|
||||
245
internal/liveshare/test/server.go
Normal file
245
internal/liveshare/test/server.go
Normal file
|
|
@ -0,0 +1,245 @@
|
|||
package livesharetest
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEogIBAAKCAQEAp/Jmzy/HaPNx5Bug09FX5Q/KGY4G9c4DfplhWrn31OQCqNiT
|
||||
ZSLd46rdXC75liHzE7e5Ic0RJN61cYN9SNArjvEXx2vvs7szhwO7LonwPOvpYpUf
|
||||
daayrgbr6S46plpx+hEZ1kO/6BqMgFuvnkIVThrEyx5b48ll8zgDABsYrKF8/p1V
|
||||
SjGfb+bLwjn1NtnZF2prBG5P4ZtMR06HaPglLqBJhmc0ZMG5IZGUE7ew/VrPDqdC
|
||||
f1v4XvvGiU4BLoKYy4QOhyrCGh9Uk/9u0Ea56M2bh4RqwhbpR8m7TYJZ0DVMLbGW
|
||||
8C+4lCWp+xRyBNxAQh8qeQVCxYl02hPE4bXLGQIDAQABAoIBAEoVPk6UZ+UexhV2
|
||||
LnphNOFhFqgxI1bYWmhE5lHsCKuLLLUoW9RYDgL4gw6/1e7o6N3AxFRpre9Soj0B
|
||||
YIl28k/qf6/DKAhjQnaDKdV8mVF2Swvmdesi7lyfxv6kGtD4wqApXPlMB2IuG94f
|
||||
E5e+1MEQQ9DJgoU3eNZR1dj9GuRC3PyzPcNNJ2R/MMGFw3sOOVcLOgAukotoicuL
|
||||
0SiL51rHPQu8a5/darH9EltN1GFeceJSDDhgqMP5T8Tp7g/c3//H6szon4H9W+uN
|
||||
Z3UrImJ+teJjFOaVDqN93+J2eQSUk0lCPGQCd4U9I4AGDGyU6ucdcLQ58Aha9gmU
|
||||
uQwkfKUCgYEA0UkuPOSDE9dbXe+yhsbOwMb1kKzJYgFDKjRTSP7D9BOMZu4YyASo
|
||||
J95R4DWjePlDopafG2tNJoWX+CwUl7Uld1R3Ex6xHBa2B7hwZj860GZtr7D4mdWc
|
||||
DTVjczAjp4P0K1MIFYQui1mVJterkjKuePiI6q/27L1c2jIa/39BWBcCgYEAzW8R
|
||||
MFZamVw3eA2JYSpBuqhQgE5gX5IWrmVJZSUhpAQTNG/A4nxf7WGtjy9p99tm0RMb
|
||||
ld05+sOmNLrzw8Pq8SBpFOd+MAca7lPLS1A2CoaAHbOqRqrzVcZ4EZ2jB3WjoLoq
|
||||
yctwslGb9KmrhBCdcwT48aPAYUIJCZdqEen2xE8CgYBoMowvywGrvjwCH9X9njvP
|
||||
5P7cAfrdrY04FQcmP5lmCtmLYZ267/6couaWv33dPBU9fMpIh3rI5BiOebvi8FBw
|
||||
AgCq50v8lR4Z5+0mKvLoUSbpIy4SwTRJqzwRXHVT8LF/ZH6Q39egj4Bf716/kjYl
|
||||
im/4kJVatsjk5a9lZ4EsDwKBgERkJ3rKJNtNggHrr8KzSLKVekdc0GTAw+BHRAny
|
||||
NKLf4Gzij3pXIbBrhlZW2JZ1amNMUzCvN7AuFlUTsDeKL9saiSE2eCIRG3wgVVu7
|
||||
VmJmqJw6xgNEwkHaEvr6Wd4P4euOTtRjcB9NX/gxzDHpPiGelCoN8+vtCgkxaVSR
|
||||
aV+tAoGAO4HtLOfBAVDNbVXa27aJAjQSUq8qfkwUNJNz+rwgpVQahfiVkyqAPCQM
|
||||
IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E
|
||||
Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U=
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
|
||||
type Server struct {
|
||||
password string
|
||||
services map[string]RPCHandleFunc
|
||||
relaySAS string
|
||||
streams map[string]io.ReadWriter
|
||||
|
||||
sshConfig *ssh.ServerConfig
|
||||
httptestServer *httptest.Server
|
||||
errCh chan error
|
||||
}
|
||||
|
||||
func NewServer(opts ...ServerOption) (*Server, error) {
|
||||
server := new(Server)
|
||||
|
||||
for _, o := range opts {
|
||||
if err := o(server); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
server.sshConfig = &ssh.ServerConfig{
|
||||
PasswordCallback: sshPasswordCallback(server.password),
|
||||
}
|
||||
privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error parsing key: %w", err)
|
||||
}
|
||||
server.sshConfig.AddHostKey(privateKey)
|
||||
|
||||
server.errCh = make(chan error)
|
||||
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
|
||||
return server, nil
|
||||
}
|
||||
|
||||
type ServerOption func(*Server) error
|
||||
|
||||
func WithPassword(password string) ServerOption {
|
||||
return func(s *Server) error {
|
||||
s.password = password
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithService(serviceName string, handler RPCHandleFunc) ServerOption {
|
||||
return func(s *Server) error {
|
||||
if s.services == nil {
|
||||
s.services = make(map[string]RPCHandleFunc)
|
||||
}
|
||||
|
||||
s.services[serviceName] = handler
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithRelaySAS(sas string) ServerOption {
|
||||
return func(s *Server) error {
|
||||
s.relaySAS = sas
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func WithStream(name string, stream io.ReadWriter) ServerOption {
|
||||
return func(s *Server) error {
|
||||
if s.streams == nil {
|
||||
s.streams = make(map[string]io.ReadWriter)
|
||||
}
|
||||
s.streams[name] = stream
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
|
||||
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if string(password) == serverPassword {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, errors.New("password rejected")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) Close() {
|
||||
s.httptestServer.Close()
|
||||
}
|
||||
|
||||
func (s *Server) URL() string {
|
||||
return s.httptestServer.URL
|
||||
}
|
||||
|
||||
func (s *Server) Err() <-chan error {
|
||||
return s.errCh
|
||||
}
|
||||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
func makeConnection(server *Server) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
if server.relaySAS != "" {
|
||||
// validate the sas key
|
||||
sasParam := req.URL.Query().Get("sb-hc-token")
|
||||
if sasParam != server.relaySAS {
|
||||
server.errCh <- errors.New("error validating sas")
|
||||
return
|
||||
}
|
||||
}
|
||||
c, err := upgrader.Upgrade(w, req, nil)
|
||||
if err != nil {
|
||||
server.errCh <- fmt.Errorf("error upgrading connection: %w", err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
|
||||
socketConn := newSocketConn(c)
|
||||
_, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig)
|
||||
if err != nil {
|
||||
server.errCh <- fmt.Errorf("error creating new ssh conn: %w", err)
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
|
||||
for newChannel := range chans {
|
||||
ch, reqs, err := newChannel.Accept()
|
||||
if err != nil {
|
||||
server.errCh <- fmt.Errorf("error accepting new channel: %w", err)
|
||||
return
|
||||
}
|
||||
go handleNewRequests(ctx, server, ch, reqs)
|
||||
go handleNewChannel(server, ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleNewRequests(ctx context.Context, server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) {
|
||||
for req := range reqs {
|
||||
if req.WantReply {
|
||||
if err := req.Reply(true, nil); err != nil {
|
||||
server.errCh <- fmt.Errorf("error replying to channel request: %w", err)
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(req.Type, "stream-transport") {
|
||||
forwardStream(ctx, server, req.Type, channel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func forwardStream(ctx context.Context, server *Server, streamName string, channel ssh.Channel) {
|
||||
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
|
||||
stream, found := server.streams[simpleStreamName]
|
||||
if !found {
|
||||
server.errCh <- fmt.Errorf("stream '%s' not found", simpleStreamName)
|
||||
return
|
||||
}
|
||||
|
||||
copy := func(dst io.Writer, src io.Reader) {
|
||||
if _, err := io.Copy(dst, src); err != nil {
|
||||
fmt.Println(err)
|
||||
server.errCh <- fmt.Errorf("io copy: %w", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go copy(stream, channel)
|
||||
go copy(channel, stream)
|
||||
|
||||
<-ctx.Done() // TODO(josebalius): improve this
|
||||
}
|
||||
|
||||
func handleNewChannel(server *Server, channel ssh.Channel) {
|
||||
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
|
||||
jsonrpc2.NewConn(context.Background(), stream, newRPCHandler(server))
|
||||
}
|
||||
|
||||
type RPCHandleFunc func(req *jsonrpc2.Request) (interface{}, error)
|
||||
|
||||
type rpcHandler struct {
|
||||
server *Server
|
||||
}
|
||||
|
||||
func newRPCHandler(server *Server) *rpcHandler {
|
||||
return &rpcHandler{server}
|
||||
}
|
||||
|
||||
func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) {
|
||||
handler, found := r.server.services[req.Method]
|
||||
if !found {
|
||||
r.server.errCh <- fmt.Errorf("RPC Method: '%s' not serviced", req.Method)
|
||||
return
|
||||
}
|
||||
|
||||
result, err := handler(req)
|
||||
if err != nil {
|
||||
r.server.errCh <- fmt.Errorf("error handling: '%s': %w", req.Method, err)
|
||||
return
|
||||
}
|
||||
|
||||
if err := conn.Reply(ctx, req.ID, result); err != nil {
|
||||
r.server.errCh <- fmt.Errorf("error replying: %w", err)
|
||||
}
|
||||
}
|
||||
77
internal/liveshare/test/socket.go
Normal file
77
internal/liveshare/test/socket.go
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
package livesharetest
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
type socketConn struct {
|
||||
*websocket.Conn
|
||||
|
||||
reader io.Reader
|
||||
writeMutex sync.Mutex
|
||||
readMutex sync.Mutex
|
||||
}
|
||||
|
||||
func newSocketConn(conn *websocket.Conn) *socketConn {
|
||||
return &socketConn{Conn: conn}
|
||||
}
|
||||
|
||||
func (s *socketConn) Read(b []byte) (int, error) {
|
||||
s.readMutex.Lock()
|
||||
defer s.readMutex.Unlock()
|
||||
|
||||
if s.reader == nil {
|
||||
msgType, r, err := s.Conn.NextReader()
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next reader: %w", err)
|
||||
}
|
||||
if msgType != websocket.BinaryMessage {
|
||||
return 0, fmt.Errorf("invalid message type")
|
||||
}
|
||||
s.reader = r
|
||||
}
|
||||
|
||||
bytesRead, err := s.reader.Read(b)
|
||||
if err != nil {
|
||||
s.reader = nil
|
||||
|
||||
if err == io.EOF {
|
||||
err = nil
|
||||
}
|
||||
}
|
||||
|
||||
return bytesRead, err
|
||||
}
|
||||
|
||||
func (s *socketConn) Write(b []byte) (int, error) {
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
|
||||
w, err := s.Conn.NextWriter(websocket.BinaryMessage)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error getting next writer: %w", err)
|
||||
}
|
||||
|
||||
n, err := w.Write(b)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("error writing: %w", err)
|
||||
}
|
||||
|
||||
if err := w.Close(); err != nil {
|
||||
return 0, fmt.Errorf("error closing writer: %w", err)
|
||||
}
|
||||
|
||||
return n, nil
|
||||
}
|
||||
|
||||
func (s *socketConn) SetDeadline(deadline time.Time) error {
|
||||
if err := s.Conn.SetReadDeadline(deadline); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.Conn.SetWriteDeadline(deadline)
|
||||
}
|
||||
Loading…
Add table
Add a link
Reference in a new issue