Final changes to finish this refactor

This commit is contained in:
Jose Garcia 2021-07-27 23:19:55 +00:00 committed by GitHub
parent 892f73221c
commit 0ab67badfa
7 changed files with 206 additions and 25 deletions

View file

@ -9,6 +9,7 @@ import (
"net/http"
"net/http/httptest"
"path/filepath"
"strings"
"sync"
"time"
@ -21,6 +22,7 @@ type Server struct {
password string
services map[string]RpcHandleFunc
relaySAS string
streams map[string]io.ReadWriter
sshConfig *ssh.ServerConfig
httptestServer *httptest.Server
@ -50,7 +52,7 @@ func NewServer(opts ...ServerOption) (*Server, error) {
server.sshConfig.AddHostKey(privateKey)
server.errCh = make(chan error)
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server)))
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
return server, nil
}
@ -81,6 +83,16 @@ func WithRelaySAS(sas string) ServerOption {
}
}
func WithStream(name string, stream io.ReadWriter) ServerOption {
return func(s *Server) error {
if s.streams == nil {
s.streams = make(map[string]io.ReadWriter)
}
s.streams[name] = stream
return nil
}
}
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
if string(password) == serverPassword {
@ -104,7 +116,7 @@ func (s *Server) Err() <-chan error {
var upgrader = websocket.Upgrader{}
func newConnection(server *Server) http.HandlerFunc {
func makeConnection(server *Server) http.HandlerFunc {
return func(w http.ResponseWriter, req *http.Request) {
if server.relaySAS != "" {
// validate the sas key
@ -135,12 +147,48 @@ func newConnection(server *Server) http.HandlerFunc {
server.errCh <- fmt.Errorf("error accepting new channel: %v", err)
return
}
go ssh.DiscardRequests(reqs)
go handleNewRequests(server, ch, reqs)
go handleNewChannel(server, ch)
}
}
}
func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) {
for req := range reqs {
if req.WantReply {
if err := req.Reply(true, nil); err != nil {
server.errCh <- fmt.Errorf("error replying to channel request: %v", err)
}
}
if strings.HasPrefix(req.Type, "stream-transport") {
forwardStream(server, req.Type, channel)
}
}
}
func forwardStream(server *Server, streamName string, channel ssh.Channel) {
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
stream, found := server.streams[simpleStreamName]
if !found {
server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName)
return
}
copy := func(dst io.Writer, src io.Reader) {
if _, err := io.Copy(dst, src); err != nil {
fmt.Println(err)
server.errCh <- fmt.Errorf("io copy: %v", err)
return
}
}
go copy(stream, channel)
go copy(channel, stream)
for {
}
}
func handleNewChannel(server *Server, channel ssh.Channel) {
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server))