package livesharetest import ( "context" "errors" "fmt" "io" "net/http" "net/http/httptest" "strings" "github.com/gorilla/websocket" "github.com/sourcegraph/jsonrpc2" "golang.org/x/crypto/ssh" ) const sshPrivateKey = `-----BEGIN RSA PRIVATE KEY----- MIIEogIBAAKCAQEAp/Jmzy/HaPNx5Bug09FX5Q/KGY4G9c4DfplhWrn31OQCqNiT ZSLd46rdXC75liHzE7e5Ic0RJN61cYN9SNArjvEXx2vvs7szhwO7LonwPOvpYpUf daayrgbr6S46plpx+hEZ1kO/6BqMgFuvnkIVThrEyx5b48ll8zgDABsYrKF8/p1V SjGfb+bLwjn1NtnZF2prBG5P4ZtMR06HaPglLqBJhmc0ZMG5IZGUE7ew/VrPDqdC f1v4XvvGiU4BLoKYy4QOhyrCGh9Uk/9u0Ea56M2bh4RqwhbpR8m7TYJZ0DVMLbGW 8C+4lCWp+xRyBNxAQh8qeQVCxYl02hPE4bXLGQIDAQABAoIBAEoVPk6UZ+UexhV2 LnphNOFhFqgxI1bYWmhE5lHsCKuLLLUoW9RYDgL4gw6/1e7o6N3AxFRpre9Soj0B YIl28k/qf6/DKAhjQnaDKdV8mVF2Swvmdesi7lyfxv6kGtD4wqApXPlMB2IuG94f E5e+1MEQQ9DJgoU3eNZR1dj9GuRC3PyzPcNNJ2R/MMGFw3sOOVcLOgAukotoicuL 0SiL51rHPQu8a5/darH9EltN1GFeceJSDDhgqMP5T8Tp7g/c3//H6szon4H9W+uN Z3UrImJ+teJjFOaVDqN93+J2eQSUk0lCPGQCd4U9I4AGDGyU6ucdcLQ58Aha9gmU uQwkfKUCgYEA0UkuPOSDE9dbXe+yhsbOwMb1kKzJYgFDKjRTSP7D9BOMZu4YyASo J95R4DWjePlDopafG2tNJoWX+CwUl7Uld1R3Ex6xHBa2B7hwZj860GZtr7D4mdWc DTVjczAjp4P0K1MIFYQui1mVJterkjKuePiI6q/27L1c2jIa/39BWBcCgYEAzW8R MFZamVw3eA2JYSpBuqhQgE5gX5IWrmVJZSUhpAQTNG/A4nxf7WGtjy9p99tm0RMb ld05+sOmNLrzw8Pq8SBpFOd+MAca7lPLS1A2CoaAHbOqRqrzVcZ4EZ2jB3WjoLoq yctwslGb9KmrhBCdcwT48aPAYUIJCZdqEen2xE8CgYBoMowvywGrvjwCH9X9njvP 5P7cAfrdrY04FQcmP5lmCtmLYZ267/6couaWv33dPBU9fMpIh3rI5BiOebvi8FBw AgCq50v8lR4Z5+0mKvLoUSbpIy4SwTRJqzwRXHVT8LF/ZH6Q39egj4Bf716/kjYl im/4kJVatsjk5a9lZ4EsDwKBgERkJ3rKJNtNggHrr8KzSLKVekdc0GTAw+BHRAny NKLf4Gzij3pXIbBrhlZW2JZ1amNMUzCvN7AuFlUTsDeKL9saiSE2eCIRG3wgVVu7 VmJmqJw6xgNEwkHaEvr6Wd4P4euOTtRjcB9NX/gxzDHpPiGelCoN8+vtCgkxaVSR aV+tAoGAO4HtLOfBAVDNbVXa27aJAjQSUq8qfkwUNJNz+rwgpVQahfiVkyqAPCQM IfRJxKWb0Wbt9ojw3AowK/k0d3LZA7FS41JSiiGKIllSGb+i7JKqKW7RHLA3VJ/E Bq5TLNIbUzPVNVwRcGjUYpOhKU6EIw8phTJOvxnUC+g6MVqBP8U= -----END RSA PRIVATE KEY-----` type Server struct { password string services map[string]RpcHandleFunc relaySAS string streams map[string]io.ReadWriter sshConfig *ssh.ServerConfig httptestServer *httptest.Server errCh chan error } func NewServer(opts ...ServerOption) (*Server, error) { server := new(Server) for _, o := range opts { if err := o(server); err != nil { return nil, err } } server.sshConfig = &ssh.ServerConfig{ PasswordCallback: sshPasswordCallback(server.password), } privateKey, err := ssh.ParsePrivateKey([]byte(sshPrivateKey)) if err != nil { return nil, fmt.Errorf("error parsing key: %v", err) } server.sshConfig.AddHostKey(privateKey) server.errCh = make(chan error) server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server))) return server, nil } type ServerOption func(*Server) error func WithPassword(password string) ServerOption { return func(s *Server) error { s.password = password return nil } } func WithService(serviceName string, handler RpcHandleFunc) ServerOption { return func(s *Server) error { if s.services == nil { s.services = make(map[string]RpcHandleFunc) } s.services[serviceName] = handler return nil } } func WithRelaySAS(sas string) ServerOption { return func(s *Server) error { s.relaySAS = sas return nil } } func WithStream(name string, stream io.ReadWriter) ServerOption { return func(s *Server) error { if s.streams == nil { s.streams = make(map[string]io.ReadWriter) } s.streams[name] = stream return nil } } func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) { return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) { if string(password) == serverPassword { return nil, nil } return nil, errors.New("password rejected") } } func (s *Server) Close() { s.httptestServer.Close() } func (s *Server) URL() string { return s.httptestServer.URL } func (s *Server) Err() <-chan error { return s.errCh } var upgrader = websocket.Upgrader{} func makeConnection(server *Server) http.HandlerFunc { return func(w http.ResponseWriter, req *http.Request) { if server.relaySAS != "" { // validate the sas key sasParam := req.URL.Query().Get("sb-hc-token") if sasParam != server.relaySAS { server.errCh <- errors.New("error validating sas") return } } c, err := upgrader.Upgrade(w, req, nil) if err != nil { server.errCh <- fmt.Errorf("error upgrading connection: %v", err) return } defer c.Close() socketConn := newSocketConn(c) _, chans, reqs, err := ssh.NewServerConn(socketConn, server.sshConfig) if err != nil { server.errCh <- fmt.Errorf("error creating new ssh conn: %v", err) return } go ssh.DiscardRequests(reqs) for newChannel := range chans { ch, reqs, err := newChannel.Accept() if err != nil { server.errCh <- fmt.Errorf("error accepting new channel: %v", err) return } go handleNewRequests(server, ch, reqs) go handleNewChannel(server, ch) } } } func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) { for req := range reqs { if req.WantReply { if err := req.Reply(true, nil); err != nil { server.errCh <- fmt.Errorf("error replying to channel request: %v", err) } } if strings.HasPrefix(req.Type, "stream-transport") { forwardStream(server, req.Type, channel) } } } func forwardStream(server *Server, streamName string, channel ssh.Channel) { simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-") stream, found := server.streams[simpleStreamName] if !found { server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName) return } copy := func(dst io.Writer, src io.Reader) { if _, err := io.Copy(dst, src); err != nil { fmt.Println(err) server.errCh <- fmt.Errorf("io copy: %v", err) return } } go copy(stream, channel) go copy(channel, stream) for { } } func handleNewChannel(server *Server, channel ssh.Channel) { stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{}) jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server)) } type RpcHandleFunc func(req *jsonrpc2.Request) (interface{}, error) type rpcHandler struct { server *Server } func newRpcHandler(server *Server) *rpcHandler { return &rpcHandler{server} } func (r *rpcHandler) Handle(ctx context.Context, conn *jsonrpc2.Conn, req *jsonrpc2.Request) { handler, found := r.server.services[req.Method] if !found { r.server.errCh <- fmt.Errorf("RPC Method: '%v' not serviced", req.Method) return } result, err := handler(req) if err != nil { r.server.errCh <- fmt.Errorf("error handling: '%v': %v", req.Method, err) return } if err := conn.Reply(ctx, req.ID, result); err != nil { r.server.errCh <- fmt.Errorf("error replying: %v", err) } }