Skip to content

Commit 01e3456

Browse files
authored
fix: ssh portmap redo ssh dial (#170)
1 parent 46fcf55 commit 01e3456

File tree

3 files changed

+61
-59
lines changed

3 files changed

+61
-59
lines changed

pkg/daemon/handler/ssh.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (w *wsHandler) handle(ctx context.Context) {
8282
if err != nil {
8383
return
8484
}
85-
err = util.PortMapUntil(ctx, cli, remote, local)
85+
err = util.PortMapUntil(ctx, w.sshConfig, remote, local)
8686
if err != nil {
8787
w.Log("Port map error: %v", err)
8888
return

pkg/handler/connect.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr
893893
if print {
894894
log.Infof("wait jump to bastion host...")
895895
}
896-
err = util.PortMapUntil(ctx, cli, remote, local)
896+
err = util.PortMapUntil(ctx, conf, remote, local)
897897
if err != nil {
898898
log.Errorf("ssh proxy err: %v", err)
899899
return

pkg/util/ssh.go

+59-57
Original file line numberDiff line numberDiff line change
@@ -67,39 +67,6 @@ func (s *SshConfig) ToRPC() *rpc.SshJump {
6767
}
6868
}
6969

70-
func PortMap(ctx context.Context, sshClient *ssh.Client, remoteEndpoint, localEndpoint netip.AddrPort, done chan struct{}) error {
71-
// Listen on remote server port
72-
var lc net.ListenConfig
73-
listen, err := lc.Listen(ctx, "tcp", localEndpoint.String())
74-
if err != nil {
75-
return err
76-
}
77-
defer listen.Close()
78-
79-
select {
80-
case done <- struct{}{}:
81-
default:
82-
}
83-
// handle incoming connections on reverse forwarded tunnel
84-
for ctx.Err() == nil {
85-
localConn, err := listen.Accept()
86-
if err != nil {
87-
return err
88-
}
89-
go func(localConn net.Conn) {
90-
defer localConn.Close()
91-
remoteConn, err := sshClient.Dial("tcp", remoteEndpoint.String())
92-
if err != nil {
93-
log.Errorf("Failed to dial %s: %s", remoteEndpoint.String(), err)
94-
return
95-
}
96-
defer remoteConn.Close()
97-
copyStream(localConn, remoteConn)
98-
}(localConn)
99-
}
100-
return ctx.Err()
101-
}
102-
10370
// DialSshRemote https://github.com/golang/go/issues/21478
10471
func DialSshRemote(ctx context.Context, conf *SshConfig) (*ssh.Client, error) {
10572
var remote *ssh.Client
@@ -170,6 +137,7 @@ func DialSshRemote(ctx context.Context, conf *SshConfig) (*ssh.Client, error) {
170137
_, _, er := remote.SendRequest("keepalive@golang.org", true, nil)
171138
if er != nil {
172139
log.Errorf("failed to send keep alive error: %s", er)
140+
return
173141
}
174142
}
175143
}
@@ -225,12 +193,12 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
225193
return ssh.PublicKeys(key), nil
226194
}
227195

228-
func copyStream(client net.Conn, remote net.Conn) {
196+
func copyStream(local net.Conn, remote net.Conn) {
229197
chDone := make(chan bool, 2)
230198

231199
// start remote -> local data transfer
232200
go func() {
233-
_, err := io.Copy(client, remote)
201+
_, err := io.Copy(local, remote)
234202
if err != nil && !errors.Is(err, net.ErrClosed) {
235203
log.Debugf("error while copy remote->local: %s", err)
236204
}
@@ -242,7 +210,7 @@ func copyStream(client net.Conn, remote net.Conn) {
242210

243211
// start local -> remote data transfer
244212
go func() {
245-
_, err := io.Copy(remote, client)
213+
_, err := io.Copy(remote, local)
246214
if err != nil && !errors.Is(err, net.ErrClosed) {
247215
log.Debugf("error while copy local->remote: %s", err)
248216
}
@@ -388,31 +356,65 @@ func init() {
388356
})
389357
}
390358

391-
func PortMapUntil(ctx context.Context, cli *ssh.Client, remote, local netip.AddrPort) error {
392-
errChan := make(chan error, 1)
393-
readyChan := make(chan struct{}, 1)
359+
func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.AddrPort) error {
360+
// Listen on remote server port
361+
var lc net.ListenConfig
362+
localListen, err := lc.Listen(ctx, "tcp", local.String())
363+
if err != nil {
364+
return err
365+
}
366+
367+
var lock sync.Mutex
368+
var cancelFunc context.CancelFunc
369+
var sshClient *ssh.Client
370+
371+
var getRemoteConnFunc = func() (net.Conn, error) {
372+
lock.Lock()
373+
defer lock.Unlock()
374+
375+
if sshClient != nil {
376+
remoteConn, err := sshClient.Dial("tcp", remote.String())
377+
if err == nil {
378+
return remoteConn, nil
379+
}
380+
sshClient.Close()
381+
if cancelFunc != nil {
382+
cancelFunc()
383+
}
384+
}
385+
var ctx2 context.Context
386+
ctx2, cancelFunc = context.WithCancel(ctx)
387+
sshClient, err = DialSshRemote(ctx2, conf)
388+
if err != nil {
389+
cancelFunc()
390+
cancelFunc = nil
391+
log.Errorf("failed to dial remote ssh server: %v", err)
392+
return nil, err
393+
}
394+
return sshClient.Dial("tcp", remote.String())
395+
}
396+
394397
go func() {
398+
defer localListen.Close()
399+
395400
for ctx.Err() == nil {
396-
err := PortMap(ctx, cli, remote, local, readyChan)
401+
localConn, err := localListen.Accept()
397402
if err != nil {
398-
if !errors.Is(err, context.Canceled) {
399-
log.Errorf("Ssh forward failed err: %v", err)
400-
}
401-
select {
402-
case errChan <- err:
403-
default:
404-
}
403+
log.Errorf("failed to accept conn: %v", err)
404+
return
405405
}
406-
time.Sleep(time.Second * 2)
406+
go func() {
407+
defer localConn.Close()
408+
409+
remoteConn, err := getRemoteConnFunc()
410+
if err != nil {
411+
log.Errorf("Failed to dial %s: %s", remote.String(), err)
412+
return
413+
}
414+
defer remoteConn.Close()
415+
copyStream(localConn, remoteConn)
416+
}()
407417
}
408418
}()
409-
select {
410-
case <-readyChan:
411-
return nil
412-
case err := <-errChan:
413-
log.Errorf("Ssh forward failed err: %v", err)
414-
return err
415-
case <-ctx.Done():
416-
return ctx.Err()
417-
}
419+
return nil
418420
}

0 commit comments

Comments
 (0)