Final changes to finish this refactor
This commit is contained in:
parent
892f73221c
commit
0ab67badfa
7 changed files with 206 additions and 25 deletions
|
|
@ -9,6 +9,7 @@ import (
|
|||
"net/http"
|
||||
"net/http/httptest"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
|
|
@ -21,6 +22,7 @@ type Server struct {
|
|||
password string
|
||||
services map[string]RpcHandleFunc
|
||||
relaySAS string
|
||||
streams map[string]io.ReadWriter
|
||||
|
||||
sshConfig *ssh.ServerConfig
|
||||
httptestServer *httptest.Server
|
||||
|
|
@ -50,7 +52,7 @@ func NewServer(opts ...ServerOption) (*Server, error) {
|
|||
server.sshConfig.AddHostKey(privateKey)
|
||||
|
||||
server.errCh = make(chan error)
|
||||
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(newConnection(server)))
|
||||
server.httptestServer = httptest.NewTLSServer(http.HandlerFunc(makeConnection(server)))
|
||||
return server, nil
|
||||
}
|
||||
|
||||
|
|
@ -81,6 +83,16 @@ func WithRelaySAS(sas string) ServerOption {
|
|||
}
|
||||
}
|
||||
|
||||
func WithStream(name string, stream io.ReadWriter) ServerOption {
|
||||
return func(s *Server) error {
|
||||
if s.streams == nil {
|
||||
s.streams = make(map[string]io.ReadWriter)
|
||||
}
|
||||
s.streams[name] = stream
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func sshPasswordCallback(serverPassword string) func(ssh.ConnMetadata, []byte) (*ssh.Permissions, error) {
|
||||
return func(conn ssh.ConnMetadata, password []byte) (*ssh.Permissions, error) {
|
||||
if string(password) == serverPassword {
|
||||
|
|
@ -104,7 +116,7 @@ func (s *Server) Err() <-chan error {
|
|||
|
||||
var upgrader = websocket.Upgrader{}
|
||||
|
||||
func newConnection(server *Server) http.HandlerFunc {
|
||||
func makeConnection(server *Server) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, req *http.Request) {
|
||||
if server.relaySAS != "" {
|
||||
// validate the sas key
|
||||
|
|
@ -135,12 +147,48 @@ func newConnection(server *Server) http.HandlerFunc {
|
|||
server.errCh <- fmt.Errorf("error accepting new channel: %v", err)
|
||||
return
|
||||
}
|
||||
go ssh.DiscardRequests(reqs)
|
||||
go handleNewRequests(server, ch, reqs)
|
||||
go handleNewChannel(server, ch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func handleNewRequests(server *Server, channel ssh.Channel, reqs <-chan *ssh.Request) {
|
||||
for req := range reqs {
|
||||
if req.WantReply {
|
||||
if err := req.Reply(true, nil); err != nil {
|
||||
server.errCh <- fmt.Errorf("error replying to channel request: %v", err)
|
||||
}
|
||||
}
|
||||
if strings.HasPrefix(req.Type, "stream-transport") {
|
||||
forwardStream(server, req.Type, channel)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func forwardStream(server *Server, streamName string, channel ssh.Channel) {
|
||||
simpleStreamName := strings.TrimPrefix(streamName, "stream-transport-")
|
||||
stream, found := server.streams[simpleStreamName]
|
||||
if !found {
|
||||
server.errCh <- fmt.Errorf("stream '%v' not found", simpleStreamName)
|
||||
return
|
||||
}
|
||||
|
||||
copy := func(dst io.Writer, src io.Reader) {
|
||||
if _, err := io.Copy(dst, src); err != nil {
|
||||
fmt.Println(err)
|
||||
server.errCh <- fmt.Errorf("io copy: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
go copy(stream, channel)
|
||||
go copy(channel, stream)
|
||||
|
||||
for {
|
||||
}
|
||||
}
|
||||
|
||||
func handleNewChannel(server *Server, channel ssh.Channel) {
|
||||
stream := jsonrpc2.NewBufferedStream(channel, jsonrpc2.VSCodeObjectCodec{})
|
||||
jsonrpc2.NewConn(context.Background(), stream, newRpcHandler(server))
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue