Skip to content

Commit b1b95bd

Browse files
committed
fix: ssh portmap redo ssh dial
1 parent 46fcf55 commit b1b95bd

File tree

3 files changed

+65
-36
lines changed

3 files changed

+65
-36
lines changed

pkg/daemon/handler/ssh.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ func (w *wsHandler) handle(ctx context.Context) {
8282
if err != nil {
8383
return
8484
}
85-
err = util.PortMapUntil(ctx, cli, remote, local)
85+
err = util.PortMapUntil(ctx, w.sshConfig, remote, local)
8686
if err != nil {
8787
w.Log("Port map error: %v", err)
8888
return

pkg/handler/connect.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -893,7 +893,7 @@ func SshJump(ctx context.Context, conf *util.SshConfig, flags *pflag.FlagSet, pr
893893
if print {
894894
log.Infof("wait jump to bastion host...")
895895
}
896-
err = util.PortMapUntil(ctx, cli, remote, local)
896+
err = util.PortMapUntil(ctx, conf, remote, local)
897897
if err != nil {
898898
log.Errorf("ssh proxy err: %v", err)
899899
return

pkg/util/ssh.go

+63-34
Original file line numberDiff line numberDiff line change
@@ -67,37 +67,44 @@ func (s *SshConfig) ToRPC() *rpc.SshJump {
6767
}
6868
}
6969

70-
func PortMap(ctx context.Context, sshClient *ssh.Client, remoteEndpoint, localEndpoint netip.AddrPort, done chan struct{}) error {
71-
// Listen on remote server port
72-
var lc net.ListenConfig
73-
listen, err := lc.Listen(ctx, "tcp", localEndpoint.String())
74-
if err != nil {
75-
return err
76-
}
77-
defer listen.Close()
70+
func PortMap(ctx context.Context, sshClient *ssh.Client, remoteEndpoint netip.AddrPort, local net.Listener) error {
71+
defer sshClient.Close()
7872

79-
select {
80-
case done <- struct{}{}:
81-
default:
82-
}
73+
errChan := make(chan error, 1)
8374
// handle incoming connections on reverse forwarded tunnel
84-
for ctx.Err() == nil {
85-
localConn, err := listen.Accept()
75+
for {
76+
select {
77+
case <-ctx.Done():
78+
return ctx.Err()
79+
case err := <-errChan:
80+
return err
81+
default:
82+
}
83+
84+
localConn, err := local.Accept()
8685
if err != nil {
8786
return err
8887
}
89-
go func(localConn net.Conn) {
90-
defer localConn.Close()
91-
remoteConn, err := sshClient.Dial("tcp", remoteEndpoint.String())
88+
go func() {
89+
err := func() error {
90+
defer localConn.Close()
91+
remoteConn, err := sshClient.Dial("tcp", remoteEndpoint.String())
92+
if err != nil {
93+
log.Errorf("Failed to dial %s: %s", remoteEndpoint.String(), err)
94+
return err
95+
}
96+
defer remoteConn.Close()
97+
copyStream(localConn, remoteConn)
98+
return nil
99+
}()
92100
if err != nil {
93-
log.Errorf("Failed to dial %s: %s", remoteEndpoint.String(), err)
94-
return
101+
select {
102+
case errChan <- err:
103+
default:
104+
}
95105
}
96-
defer remoteConn.Close()
97-
copyStream(localConn, remoteConn)
98-
}(localConn)
106+
}()
99107
}
100-
return ctx.Err()
101108
}
102109

103110
// DialSshRemote https://github.com/golang/go/issues/21478
@@ -225,12 +232,12 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
225232
return ssh.PublicKeys(key), nil
226233
}
227234

228-
func copyStream(client net.Conn, remote net.Conn) {
235+
func copyStream(local net.Conn, remote net.Conn) {
229236
chDone := make(chan bool, 2)
230237

231238
// start remote -> local data transfer
232239
go func() {
233-
_, err := io.Copy(client, remote)
240+
_, err := io.Copy(local, remote)
234241
if err != nil && !errors.Is(err, net.ErrClosed) {
235242
log.Debugf("error while copy remote->local: %s", err)
236243
}
@@ -242,7 +249,7 @@ func copyStream(client net.Conn, remote net.Conn) {
242249

243250
// start local -> remote data transfer
244251
go func() {
245-
_, err := io.Copy(remote, client)
252+
_, err := io.Copy(remote, local)
246253
if err != nil && !errors.Is(err, net.ErrClosed) {
247254
log.Debugf("error while copy local->remote: %s", err)
248255
}
@@ -388,22 +395,44 @@ func init() {
388395
})
389396
}
390397

391-
func PortMapUntil(ctx context.Context, cli *ssh.Client, remote, local netip.AddrPort) error {
398+
func PortMapUntil(ctx context.Context, conf *SshConfig, remote, local netip.AddrPort) error {
399+
// Listen on remote server port
400+
var lc net.ListenConfig
401+
localListen, err := lc.Listen(ctx, "tcp", local.String())
402+
if err != nil {
403+
return err
404+
}
405+
defer localListen.Close()
406+
392407
errChan := make(chan error, 1)
393408
readyChan := make(chan struct{}, 1)
394409
go func() {
395410
for ctx.Err() == nil {
396-
err := PortMap(ctx, cli, remote, local, readyChan)
397-
if err != nil {
398-
if !errors.Is(err, context.Canceled) {
399-
log.Errorf("Ssh forward failed err: %v", err)
411+
func() {
412+
ctx2, cancelFunc := context.WithCancel(ctx)
413+
defer cancelFunc()
414+
cli, err := DialSshRemote(ctx2, conf)
415+
if err != nil {
416+
time.Sleep(time.Second * 2)
417+
log.Errorf("failed to dial remote ssh server: %v", err)
418+
return
400419
}
401420
select {
402-
case errChan <- err:
421+
case readyChan <- struct{}{}:
403422
default:
404423
}
405-
}
406-
time.Sleep(time.Second * 2)
424+
err = PortMap(ctx, cli, remote, localListen)
425+
if err != nil {
426+
if !errors.Is(err, context.Canceled) {
427+
log.Errorf("Ssh forward failed err: %v", err)
428+
}
429+
select {
430+
case errChan <- err:
431+
default:
432+
}
433+
}
434+
time.Sleep(time.Second * 2)
435+
}()
407436
}
408437
}()
409438
select {

0 commit comments

Comments
 (0)