cli/pkg/liveshare/session.go
David Gardiner 73dfb87680
Use gRPC client to notify codespace of activity (#6870)
* Use gRPC client to notify codespace of activity

* Add test

* Update comment
2023-01-18 16:31:16 -06:00

133 lines
4 KiB
Go

package liveshare
import (
"context"
"fmt"
"github.com/opentracing/opentracing-go"
"golang.org/x/crypto/ssh"
"golang.org/x/sync/errgroup"
)
// ChannelID identifies a port exposed by a remote container. It is the handle
// returned by StartSharing and later passed to OpenStreamingChannel to open an
// SSH channel to that port.
type ChannelID struct {
	// name is the stream name returned by the serverSharing.startSharing RPC.
	name string
	// condition is the stream condition returned alongside name.
	condition string
}
// Interface to allow the mocking of the liveshare session
type LiveshareSession interface {
Close() error
GetSharedServers(context.Context) ([]*Port, error)
KeepAlive(string)
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
StartSharing(context.Context, string, int) (ChannelID, error)
GetKeepAliveReason() string
}
// Session is the live connection between a Live Share client and server. It
// holds both the SSH transport and the RPC client that runs over it.
type Session struct {
	// ssh is the underlying SSH session; its conn is used to open channels.
	ssh *sshSession

	// rpc issues requests to the Live Share host. Closing it also closes the
	// SSH stream beneath it.
	rpc *rpcClient

	// keepAliveReason carries reasons from KeepAlive to GetKeepAliveReason.
	// NOTE(review): KeepAlive's non-blocking send suggests this channel is
	// created with a buffer (likely size 1) — confirm at the construction site.
	keepAliveReason chan string

	// logger is the session's logger.
	logger logger
}
// Close releases the session's RPC and SSH resources. Call it once the session
// is no longer needed.
//
// Closing the RPC connection also closes the SSH stream beneath it, so on
// success a single close is enough. If the RPC close fails, the SSH session is
// closed directly as a fallback (its result is ignored), and the RPC error is
// returned wrapped.
func (s *Session) Close() error {
	rpcErr := s.rpc.Close()
	if rpcErr == nil {
		return nil
	}

	s.ssh.Close() // best-effort fallback; any error is intentionally ignored
	return fmt.Errorf("error while closing Live Share session: %w", rpcErr)
}
// GetKeepAliveReason waits for the next reason queued by KeepAlive and returns
// it. It blocks until a reason is available. If the channel has been closed,
// it returns the empty string.
func (s *Session) GetKeepAliveReason() string {
	reason := <-s.keepAliveReason
	return reason
}
// registerRequestHandler installs h as the RPC handler for requestType.
// Calling the returned function removes that registration again.
func (s *Session) registerRequestHandler(requestType string, h handler) func() {
	deregister := s.rpc.register(requestType, h)
	return deregister
}
// KeepAlive offers reason to the session without blocking. The reason is
// queued only if the keep-alive channel can accept it immediately. Otherwise,
// for example when an earlier reason is still pending, it is silently dropped.
func (s *Session) KeepAlive(reason string) {
	select {
	default:
		// A reason is already waiting to be sent; this one is redundant.
	case s.keepAliveReason <- reason:
		// Queued for the next GetKeepAliveReason call.
	}
}
// StartSharing asks the Live Share host to begin sharing the given container
// port. sessionName describes the purpose of the remote port or service.
//
// It runs two operations concurrently under a shared context derived from ctx:
//   - waiting for the host's port-start notification
//   - issuing the serverSharing.startSharing RPC
//
// If either returns an error, the shared context is canceled and the first
// error is returned. On success the returned ChannelID can be used to open an
// SSH channel to the remote port.
func (s *Session) StartSharing(ctx context.Context, sessionName string, port int) (ChannelID, error) {
	localURL := fmt.Sprintf("http://localhost:%d", port)
	args := []interface{}{port, sessionName, localURL}

	grp, grpCtx := errgroup.WithContext(ctx)

	// Wait for the host to confirm (or reject) the newly shared port.
	// NOTE(review): this runs concurrently with the RPC below; confirm that
	// WaitForPortNotification cannot miss a notification that arrives before
	// its listener is registered.
	awaitStart := func() error {
		note, err := s.WaitForPortNotification(grpCtx, port, PortChangeKindStart)
		if err != nil {
			return fmt.Errorf("error while waiting for port notification: %w", err)
		}
		if note.Success {
			return nil
		}
		return fmt.Errorf("error while starting port sharing: %s", note.ErrorDetail)
	}

	var shared Port
	requestShare := func() error {
		return s.rpc.do(grpCtx, "serverSharing.startSharing", args, &shared)
	}

	grp.Go(awaitStart)
	grp.Go(requestShare)

	if err := grp.Wait(); err != nil {
		return ChannelID{}, err
	}
	return ChannelID{name: shared.StreamName, condition: shared.StreamCondition}, nil
}
// OpenStreamingChannel opens an SSH channel to the remote port identified by
// id, as returned by StartSharing.
//
// It first resolves id to a stream ID via the streamManager.getStream RPC. It
// then opens a "session" channel on the SSH connection and sends a
// "stream-transport-<streamID>" request on it. Out-of-band requests arriving
// on the new channel are discarded. If sending the request fails, the channel
// is closed before the error is returned, so it is not leaked.
func (s *Session) OpenStreamingChannel(ctx context.Context, id ChannelID) (ssh.Channel, error) {
	type getStreamArgs struct {
		StreamName string `json:"streamName"`
		Condition  string `json:"condition"`
	}
	args := getStreamArgs{
		StreamName: id.name,
		Condition:  id.condition,
	}

	var streamID string
	if err := s.rpc.do(ctx, "streamManager.getStream", args, &streamID); err != nil {
		return nil, fmt.Errorf("error getting stream id: %w", err)
	}

	span, ctx := opentracing.StartSpanFromContext(ctx, "Session.OpenChannel+SendRequest")
	defer span.Finish()
	_ = ctx // ctx is not currently used

	channel, reqs, err := s.ssh.conn.OpenChannel("session", nil)
	if err != nil {
		return nil, fmt.Errorf("error opening ssh channel for transport: %w", err)
	}
	// The x/crypto/ssh docs require the request channel to be serviced or the
	// connection hangs; DiscardRequests drains it until reqs is closed.
	go ssh.DiscardRequests(reqs)

	requestType := fmt.Sprintf("stream-transport-%s", streamID)
	// NOTE(review): the reply flag from SendRequest is ignored, as before;
	// confirm whether a false reply should be treated as a failure.
	if _, err := channel.SendRequest(requestType, true, nil); err != nil {
		// The caller never receives this channel, so close it here to avoid a
		// leak. The close error is dropped in favor of the more informative
		// SendRequest error.
		_ = channel.Close()
		return nil, fmt.Errorf("error sending channel request: %w", err)
	}
	return channel, nil
}