Merge pull request #6838 from dmgardiner25/start-remote-server-grpc

Start SSH server with gRPC client
This commit is contained in:
David Gardiner 2023-01-10 14:54:27 -08:00 committed by GitHub
commit ba27e5bfb8
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 475 additions and 71 deletions

View file

@ -23,5 +23,6 @@ function generate {
generate jupyter/jupyter_server_host_service.v1.proto
generate codespace/codespace_host_service.v1.proto
generate ssh/ssh_server_host_service.v1.proto
echo 'Done!'

View file

@ -7,11 +7,14 @@ import (
"context"
"fmt"
"net"
"os"
"strconv"
"strings"
"time"
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
"github.com/cli/cli/v2/pkg/liveshare"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
@ -28,12 +31,16 @@ const (
codespacesInternalSessionName = "CodespacesInternal"
)
type StartSSHServerOptions struct {
UserPublicKeyFile string
}
type Invoker interface {
Close() error
StartJupyterServer(ctx context.Context) (int, string, error)
RebuildContainer(ctx context.Context, full bool) error
StartSSHServer(ctx context.Context) (int, string, error)
StartSSHServerWithOptions(ctx context.Context, options liveshare.StartSSHServerOptions) (int, string, error)
StartSSHServerWithOptions(ctx context.Context, options StartSSHServerOptions) (int, string, error)
}
type invoker struct {
@ -42,6 +49,7 @@ type invoker struct {
listener net.Listener
jupyterClient jupyter.JupyterServerHostClient
codespaceClient codespace.CodespaceHostClient
sshClient ssh.SshServerHostClient
cancelPF context.CancelFunc
}
@ -118,6 +126,7 @@ func connect(ctx context.Context, session liveshare.LiveshareSession) (Invoker,
invoker.conn = conn
invoker.jupyterClient = jupyter.NewJupyterServerHostClient(conn)
invoker.codespaceClient = codespace.NewCodespaceHostClient(conn)
invoker.sshClient = ssh.NewSshServerHostClient(conn)
return invoker, nil
}
@ -185,10 +194,38 @@ func (i *invoker) RebuildContainer(ctx context.Context, full bool) error {
// Starts a remote SSH server to allow the user to connect to the codespace via SSH
func (i *invoker) StartSSHServer(ctx context.Context) (int, string, error) {
return i.session.StartSSHServer(ctx)
return i.StartSSHServerWithOptions(ctx, StartSSHServerOptions{})
}
// Starts a remote SSH server to allow the user to connect to the codespace via SSH
func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options liveshare.StartSSHServerOptions) (int, string, error) {
return i.session.StartSSHServerWithOptions(ctx, options)
func (i *invoker) StartSSHServerWithOptions(ctx context.Context, options StartSSHServerOptions) (int, string, error) {
ctx = i.appendMetadata(ctx)
ctx, cancel := context.WithTimeout(ctx, requestTimeout)
defer cancel()
userPublicKey := ""
if options.UserPublicKeyFile != "" {
publicKeyBytes, err := os.ReadFile(options.UserPublicKeyFile)
if err != nil {
return 0, "", fmt.Errorf("failed to read public key file: %w", err)
}
userPublicKey = strings.TrimSpace(string(publicKeyBytes))
}
response, err := i.sshClient.StartRemoteServerAsync(ctx, &ssh.StartRemoteServerRequest{UserPublicKey: userPublicKey})
if err != nil {
return 0, "", fmt.Errorf("failed to invoke SSH RPC: %w", err)
}
if !response.Result {
return 0, "", fmt.Errorf("failed to start SSH server: %s", response.Message)
}
port, err := strconv.Atoi(response.ServerPort)
if err != nil {
return 0, "", fmt.Errorf("failed to parse SSH server port: %w", err)
}
return port, response.User, nil
}

View file

@ -113,3 +113,38 @@ func TestRebuildContainerFailure(t *testing.T) {
t.Fatalf("expected %v, got %v", errorMessage, err)
}
}
// Test that the RPC invoker returns the correct port and user when the SSH server starts successfully
func TestStartSSHServerSuccess(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
port, user, err := invoker.StartSSHServer(context.Background())
if err != nil {
t.Fatalf("expected %v, got %v", nil, err)
}
if port != rpctest.SshServerPort {
t.Fatalf("expected %d, got %d", rpctest.SshServerPort, port)
}
if user != rpctest.SshUser {
t.Fatalf("expected %s, got %s", rpctest.SshUser, user)
}
}
// Test that the RPC invoker returns an error when the SSH server fails to start
func TestStartSSHServerFailure(t *testing.T) {
startServer(t)
invoker := createTestInvoker(t)
rpctest.SshMessage = "error message"
rpctest.SshResult = false
errorMessage := fmt.Sprintf("failed to start SSH server: %s", rpctest.SshMessage)
port, user, err := invoker.StartSSHServer(context.Background())
if err.Error() != errorMessage {
t.Fatalf("expected %v, got %v", errorMessage, err)
}
if port != 0 {
t.Fatalf("expected %d, got %d", 0, port)
}
if user != "" {
t.Fatalf("expected %s, got %s", "", user)
}
}

View file

@ -0,0 +1,252 @@
// Code generated by protoc-gen-go. DO NOT EDIT.
// versions:
// protoc-gen-go v1.28.0
// protoc v3.21.12
// source: ssh/ssh_server_host_service.v1.proto
package ssh
import (
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
protoimpl "google.golang.org/protobuf/runtime/protoimpl"
reflect "reflect"
sync "sync"
)
const (
// Verify that this generated code is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion)
// Verify that runtime/protoimpl is sufficiently up-to-date.
_ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20)
)
type StartRemoteServerRequest struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
UserPublicKey string `protobuf:"bytes,1,opt,name=UserPublicKey,proto3" json:"UserPublicKey,omitempty"`
}
func (x *StartRemoteServerRequest) Reset() {
*x = StartRemoteServerRequest{}
if protoimpl.UnsafeEnabled {
mi := &file_ssh_ssh_server_host_service_v1_proto_msgTypes[0]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *StartRemoteServerRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StartRemoteServerRequest) ProtoMessage() {}
func (x *StartRemoteServerRequest) ProtoReflect() protoreflect.Message {
mi := &file_ssh_ssh_server_host_service_v1_proto_msgTypes[0]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StartRemoteServerRequest.ProtoReflect.Descriptor instead.
func (*StartRemoteServerRequest) Descriptor() ([]byte, []int) {
return file_ssh_ssh_server_host_service_v1_proto_rawDescGZIP(), []int{0}
}
func (x *StartRemoteServerRequest) GetUserPublicKey() string {
if x != nil {
return x.UserPublicKey
}
return ""
}
type StartRemoteServerResponse struct {
state protoimpl.MessageState
sizeCache protoimpl.SizeCache
unknownFields protoimpl.UnknownFields
Result bool `protobuf:"varint,1,opt,name=Result,proto3" json:"Result,omitempty"`
ServerPort string `protobuf:"bytes,2,opt,name=ServerPort,proto3" json:"ServerPort,omitempty"`
User string `protobuf:"bytes,3,opt,name=User,proto3" json:"User,omitempty"`
Message string `protobuf:"bytes,4,opt,name=Message,proto3" json:"Message,omitempty"`
}
func (x *StartRemoteServerResponse) Reset() {
*x = StartRemoteServerResponse{}
if protoimpl.UnsafeEnabled {
mi := &file_ssh_ssh_server_host_service_v1_proto_msgTypes[1]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
}
func (x *StartRemoteServerResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*StartRemoteServerResponse) ProtoMessage() {}
func (x *StartRemoteServerResponse) ProtoReflect() protoreflect.Message {
mi := &file_ssh_ssh_server_host_service_v1_proto_msgTypes[1]
if protoimpl.UnsafeEnabled && x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use StartRemoteServerResponse.ProtoReflect.Descriptor instead.
func (*StartRemoteServerResponse) Descriptor() ([]byte, []int) {
return file_ssh_ssh_server_host_service_v1_proto_rawDescGZIP(), []int{1}
}
func (x *StartRemoteServerResponse) GetResult() bool {
if x != nil {
return x.Result
}
return false
}
func (x *StartRemoteServerResponse) GetServerPort() string {
if x != nil {
return x.ServerPort
}
return ""
}
func (x *StartRemoteServerResponse) GetUser() string {
if x != nil {
return x.User
}
return ""
}
func (x *StartRemoteServerResponse) GetMessage() string {
if x != nil {
return x.Message
}
return ""
}
var File_ssh_ssh_server_host_service_v1_proto protoreflect.FileDescriptor
var file_ssh_ssh_server_host_service_v1_proto_rawDesc = []byte{
0x0a, 0x24, 0x73, 0x73, 0x68, 0x2f, 0x73, 0x73, 0x68, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72,
0x5f, 0x68, 0x6f, 0x73, 0x74, 0x5f, 0x73, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31,
0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x27, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63,
0x65, 0x73, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x2e, 0x53, 0x73, 0x68, 0x53, 0x65, 0x72, 0x76, 0x65,
0x72, 0x48, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, 0x22,
0x40, 0x0a, 0x18, 0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x53, 0x65,
0x72, 0x76, 0x65, 0x72, 0x52, 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x12, 0x24, 0x0a, 0x0d, 0x55,
0x73, 0x65, 0x72, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65, 0x79, 0x18, 0x01, 0x20, 0x01,
0x28, 0x09, 0x52, 0x0d, 0x55, 0x73, 0x65, 0x72, 0x50, 0x75, 0x62, 0x6c, 0x69, 0x63, 0x4b, 0x65,
0x79, 0x22, 0x81, 0x01, 0x0a, 0x19, 0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x6d, 0x6f, 0x74,
0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x12,
0x16, 0x0a, 0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x18, 0x01, 0x20, 0x01, 0x28, 0x08, 0x52,
0x06, 0x52, 0x65, 0x73, 0x75, 0x6c, 0x74, 0x12, 0x1e, 0x0a, 0x0a, 0x53, 0x65, 0x72, 0x76, 0x65,
0x72, 0x50, 0x6f, 0x72, 0x74, 0x18, 0x02, 0x20, 0x01, 0x28, 0x09, 0x52, 0x0a, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x50, 0x6f, 0x72, 0x74, 0x12, 0x12, 0x0a, 0x04, 0x55, 0x73, 0x65, 0x72, 0x18,
0x03, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, 0x55, 0x73, 0x65, 0x72, 0x12, 0x18, 0x0a, 0x07, 0x4d,
0x65, 0x73, 0x73, 0x61, 0x67, 0x65, 0x18, 0x04, 0x20, 0x01, 0x28, 0x09, 0x52, 0x07, 0x4d, 0x65,
0x73, 0x73, 0x61, 0x67, 0x65, 0x32, 0xb1, 0x01, 0x0a, 0x0d, 0x53, 0x73, 0x68, 0x53, 0x65, 0x72,
0x76, 0x65, 0x72, 0x48, 0x6f, 0x73, 0x74, 0x12, 0x9f, 0x01, 0x0a, 0x16, 0x53, 0x74, 0x61, 0x72,
0x74, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x41, 0x73, 0x79,
0x6e, 0x63, 0x12, 0x41, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63, 0x65, 0x73, 0x2e,
0x47, 0x72, 0x70, 0x63, 0x2e, 0x53, 0x73, 0x68, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x48, 0x6f,
0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x74, 0x61,
0x72, 0x74, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65, 0x72, 0x52, 0x65,
0x71, 0x75, 0x65, 0x73, 0x74, 0x1a, 0x42, 0x2e, 0x43, 0x6f, 0x64, 0x65, 0x73, 0x70, 0x61, 0x63,
0x65, 0x73, 0x2e, 0x47, 0x72, 0x70, 0x63, 0x2e, 0x53, 0x73, 0x68, 0x53, 0x65, 0x72, 0x76, 0x65,
0x72, 0x48, 0x6f, 0x73, 0x74, 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x2e, 0x76, 0x31, 0x2e,
0x53, 0x74, 0x61, 0x72, 0x74, 0x52, 0x65, 0x6d, 0x6f, 0x74, 0x65, 0x53, 0x65, 0x72, 0x76, 0x65,
0x72, 0x52, 0x65, 0x73, 0x70, 0x6f, 0x6e, 0x73, 0x65, 0x42, 0x07, 0x5a, 0x05, 0x2e, 0x2f, 0x73,
0x73, 0x68, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33,
}
var (
file_ssh_ssh_server_host_service_v1_proto_rawDescOnce sync.Once
file_ssh_ssh_server_host_service_v1_proto_rawDescData = file_ssh_ssh_server_host_service_v1_proto_rawDesc
)
func file_ssh_ssh_server_host_service_v1_proto_rawDescGZIP() []byte {
file_ssh_ssh_server_host_service_v1_proto_rawDescOnce.Do(func() {
file_ssh_ssh_server_host_service_v1_proto_rawDescData = protoimpl.X.CompressGZIP(file_ssh_ssh_server_host_service_v1_proto_rawDescData)
})
return file_ssh_ssh_server_host_service_v1_proto_rawDescData
}
var file_ssh_ssh_server_host_service_v1_proto_msgTypes = make([]protoimpl.MessageInfo, 2)
var file_ssh_ssh_server_host_service_v1_proto_goTypes = []interface{}{
(*StartRemoteServerRequest)(nil), // 0: Codespaces.Grpc.SshServerHostService.v1.StartRemoteServerRequest
(*StartRemoteServerResponse)(nil), // 1: Codespaces.Grpc.SshServerHostService.v1.StartRemoteServerResponse
}
var file_ssh_ssh_server_host_service_v1_proto_depIdxs = []int32{
0, // 0: Codespaces.Grpc.SshServerHostService.v1.SshServerHost.StartRemoteServerAsync:input_type -> Codespaces.Grpc.SshServerHostService.v1.StartRemoteServerRequest
1, // 1: Codespaces.Grpc.SshServerHostService.v1.SshServerHost.StartRemoteServerAsync:output_type -> Codespaces.Grpc.SshServerHostService.v1.StartRemoteServerResponse
1, // [1:2] is the sub-list for method output_type
0, // [0:1] is the sub-list for method input_type
0, // [0:0] is the sub-list for extension type_name
0, // [0:0] is the sub-list for extension extendee
0, // [0:0] is the sub-list for field type_name
}
func init() { file_ssh_ssh_server_host_service_v1_proto_init() }
func file_ssh_ssh_server_host_service_v1_proto_init() {
if File_ssh_ssh_server_host_service_v1_proto != nil {
return
}
if !protoimpl.UnsafeEnabled {
file_ssh_ssh_server_host_service_v1_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*StartRemoteServerRequest); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
file_ssh_ssh_server_host_service_v1_proto_msgTypes[1].Exporter = func(v interface{}, i int) interface{} {
switch v := v.(*StartRemoteServerResponse); i {
case 0:
return &v.state
case 1:
return &v.sizeCache
case 2:
return &v.unknownFields
default:
return nil
}
}
}
type x struct{}
out := protoimpl.TypeBuilder{
File: protoimpl.DescBuilder{
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: file_ssh_ssh_server_host_service_v1_proto_rawDesc,
NumEnums: 0,
NumMessages: 2,
NumExtensions: 0,
NumServices: 1,
},
GoTypes: file_ssh_ssh_server_host_service_v1_proto_goTypes,
DependencyIndexes: file_ssh_ssh_server_host_service_v1_proto_depIdxs,
MessageInfos: file_ssh_ssh_server_host_service_v1_proto_msgTypes,
}.Build()
File_ssh_ssh_server_host_service_v1_proto = out.File
file_ssh_ssh_server_host_service_v1_proto_rawDesc = nil
file_ssh_ssh_server_host_service_v1_proto_goTypes = nil
file_ssh_ssh_server_host_service_v1_proto_depIdxs = nil
}

View file

@ -0,0 +1,20 @@
syntax = "proto3";
option go_package = "./ssh";
package Codespaces.Grpc.SshServerHostService.v1;
service SshServerHost {
rpc StartRemoteServerAsync (StartRemoteServerRequest) returns (StartRemoteServerResponse);
}
message StartRemoteServerRequest {
string UserPublicKey = 1;
}
message StartRemoteServerResponse {
bool Result = 1;
string ServerPort = 2;
string User = 3;
string Message = 4;
}

View file

@ -0,0 +1,105 @@
// Code generated by protoc-gen-go-grpc. DO NOT EDIT.
// versions:
// - protoc-gen-go-grpc v1.2.0
// - protoc v3.21.12
// source: ssh/ssh_server_host_service.v1.proto
package ssh
import (
context "context"
grpc "google.golang.org/grpc"
codes "google.golang.org/grpc/codes"
status "google.golang.org/grpc/status"
)
// This is a compile-time assertion to ensure that this generated file
// is compatible with the grpc package it is being compiled against.
// Requires gRPC-Go v1.32.0 or later.
const _ = grpc.SupportPackageIsVersion7
// SshServerHostClient is the client API for SshServerHost service.
//
// For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream.
type SshServerHostClient interface {
StartRemoteServerAsync(ctx context.Context, in *StartRemoteServerRequest, opts ...grpc.CallOption) (*StartRemoteServerResponse, error)
}
type sshServerHostClient struct {
cc grpc.ClientConnInterface
}
func NewSshServerHostClient(cc grpc.ClientConnInterface) SshServerHostClient {
return &sshServerHostClient{cc}
}
func (c *sshServerHostClient) StartRemoteServerAsync(ctx context.Context, in *StartRemoteServerRequest, opts ...grpc.CallOption) (*StartRemoteServerResponse, error) {
out := new(StartRemoteServerResponse)
err := c.cc.Invoke(ctx, "/Codespaces.Grpc.SshServerHostService.v1.SshServerHost/StartRemoteServerAsync", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// SshServerHostServer is the server API for SshServerHost service.
// All implementations must embed UnimplementedSshServerHostServer
// for forward compatibility
type SshServerHostServer interface {
StartRemoteServerAsync(context.Context, *StartRemoteServerRequest) (*StartRemoteServerResponse, error)
mustEmbedUnimplementedSshServerHostServer()
}
// UnimplementedSshServerHostServer must be embedded to have forward compatible implementations.
type UnimplementedSshServerHostServer struct {
}
func (UnimplementedSshServerHostServer) StartRemoteServerAsync(context.Context, *StartRemoteServerRequest) (*StartRemoteServerResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method StartRemoteServerAsync not implemented")
}
func (UnimplementedSshServerHostServer) mustEmbedUnimplementedSshServerHostServer() {}
// UnsafeSshServerHostServer may be embedded to opt out of forward compatibility for this service.
// Use of this interface is not recommended, as added methods to SshServerHostServer will
// result in compilation errors.
type UnsafeSshServerHostServer interface {
mustEmbedUnimplementedSshServerHostServer()
}
func RegisterSshServerHostServer(s grpc.ServiceRegistrar, srv SshServerHostServer) {
s.RegisterService(&SshServerHost_ServiceDesc, srv)
}
func _SshServerHost_StartRemoteServerAsync_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(StartRemoteServerRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(SshServerHostServer).StartRemoteServerAsync(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/Codespaces.Grpc.SshServerHostService.v1.SshServerHost/StartRemoteServerAsync",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(SshServerHostServer).StartRemoteServerAsync(ctx, req.(*StartRemoteServerRequest))
}
return interceptor(ctx, in, info, handler)
}
// SshServerHost_ServiceDesc is the grpc.ServiceDesc for SshServerHost service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
var SshServerHost_ServiceDesc = grpc.ServiceDesc{
ServiceName: "Codespaces.Grpc.SshServerHostService.v1.SshServerHost",
HandlerType: (*SshServerHostServer)(nil),
Methods: []grpc.MethodDesc{
{
MethodName: "StartRemoteServerAsync",
Handler: _SshServerHost_StartRemoteServerAsync_Handler,
},
},
Streams: []grpc.StreamDesc{},
Metadata: "ssh/ssh_server_host_service.v1.proto",
}

View file

@ -8,6 +8,7 @@ import (
"github.com/cli/cli/v2/internal/codespaces/rpc/codespace"
"github.com/cli/cli/v2/internal/codespaces/rpc/jupyter"
"github.com/cli/cli/v2/internal/codespaces/rpc/ssh"
"google.golang.org/grpc"
)
@ -28,9 +29,18 @@ var (
RebuildContainer = true
)
// Mock responses for the `StartRemoteServerAsync` RPC method
var (
SshServerPort = 1234
SshUser = "test"
SshMessage = ""
SshResult = true
)
type server struct {
jupyter.UnimplementedJupyterServerHostServer
codespace.CodespaceHostServer
ssh.SshServerHostServer
}
func (s *server) GetRunningServer(ctx context.Context, in *jupyter.GetRunningServerRequest) (*jupyter.GetRunningServerResponse, error) {
@ -48,6 +58,15 @@ func (s *server) RebuildContainerAsync(ctx context.Context, in *codespace.Rebuil
}, nil
}
func (s *server) StartRemoteServerAsync(ctx context.Context, in *ssh.StartRemoteServerRequest) (*ssh.StartRemoteServerResponse, error) {
return &ssh.StartRemoteServerResponse{
ServerPort: strconv.Itoa(SshServerPort),
User: SshUser,
Message: SshMessage,
Result: SshResult,
}, nil
}
// Starts the mock gRPC server listening on port 50051
func StartServer(ctx context.Context) error {
listener, err := net.Listen("tcp", fmt.Sprintf("127.0.0.1:%d", ServerPort))
@ -59,6 +78,7 @@ func StartServer(ctx context.Context) error {
s := grpc.NewServer()
jupyter.RegisterJupyterServerHostServer(s, &server{})
codespace.RegisterCodespaceHostServer(s, &server{})
ssh.RegisterSshServerHostServer(s, &server{})
ch := make(chan error, 1)
go func() {

View file

@ -21,18 +21,6 @@ func (*Session) GetSharedServers(context.Context) ([]*liveshare.Port, error) {
panic("unimplemented")
}
func (*Session) RebuildContainer(context.Context, bool) error {
panic("unimplemented")
}
func (*Session) StartSSHServer(context.Context) (int, string, error) {
panic("unimplemented")
}
func (*Session) StartSSHServerWithOptions(context.Context, liveshare.StartSSHServerOptions) (int, string, error) {
panic("unimplemented")
}
func (s *Session) KeepAlive(reason string) {
}

View file

@ -148,7 +148,7 @@ func (a *App) SSH(ctx context.Context, sshArgs []string, opts sshOptions) (err e
}
sshContext := ssh.Context{}
startSSHOptions := liveshare.StartSSHServerOptions{}
startSSHOptions := rpc.StartSSHServerOptions{}
keyPair, shouldAddArg, err := selectSSHKeys(ctx, sshContext, args, opts)
if err != nil {

View file

@ -3,9 +3,6 @@ package liveshare
import (
"context"
"fmt"
"os"
"strconv"
"strings"
"time"
"github.com/opentracing/opentracing-go"
@ -26,8 +23,6 @@ type LiveshareSession interface {
KeepAlive(string)
OpenStreamingChannel(context.Context, ChannelID) (ssh.Channel, error)
StartSharing(context.Context, string, int) (ChannelID, error)
StartSSHServer(context.Context) (int, string, error)
StartSSHServerWithOptions(context.Context, StartSSHServerOptions) (int, string, error)
}
// A Session represents the session between a connected Live Share client and server.
@ -40,10 +35,6 @@ type Session struct {
logger logger
}
type StartSSHServerOptions struct {
UserPublicKeyFile string
}
// Close should be called by users to clean up RPC and SSH resources whenever the session
// is no longer active.
func (s *Session) Close() error {
@ -63,51 +54,6 @@ func (s *Session) registerRequestHandler(requestType string, h handler) func() {
return s.rpc.register(requestType, h)
}
// StartSSHServer starts an SSH server in the container, installing sshd if necessary, applies specified
// options, and returns the port on which it listens and the user name clients should provide.
func (s *Session) StartSSHServer(ctx context.Context) (int, string, error) {
return s.StartSSHServerWithOptions(ctx, StartSSHServerOptions{})
}
// StartSSHServerWithOptions starts an SSH server in the container, installing sshd if necessary, applies specified
// options, and returns the port on which it listens and the user name clients should provide.
func (s *Session) StartSSHServerWithOptions(ctx context.Context, options StartSSHServerOptions) (int, string, error) {
var params struct {
UserPublicKey string `json:"userPublicKey"`
}
var response struct {
Result bool `json:"result"`
ServerPort string `json:"serverPort"`
User string `json:"user"`
Message string `json:"message"`
}
if options.UserPublicKeyFile != "" {
publicKeyBytes, err := os.ReadFile(options.UserPublicKeyFile)
if err != nil {
return 0, "", fmt.Errorf("failed to read public key file: %w", err)
}
params.UserPublicKey = strings.TrimSpace(string(publicKeyBytes))
}
if err := s.rpc.do(ctx, "ISshServerHostService.startRemoteServerWithOptions", params, &response); err != nil {
return 0, "", err
}
if !response.Result {
return 0, "", fmt.Errorf("failed to start server: %s", response.Message)
}
port, err := strconv.Atoi(response.ServerPort)
if err != nil {
return 0, "", fmt.Errorf("failed to parse port: %w", err)
}
return port, response.User, nil
}
// heartbeat runs until context cancellation, periodically checking whether there is a
// reason to keep the connection alive, and if so, notifying the Live Share host to do so.
// Heartbeat ensures it does not send more than one request every "interval" to ratelimit