@@ -67,37 +67,44 @@ func (s *SshConfig) ToRPC() *rpc.SshJump {
67
67
}
68
68
}
69
69
70
- func PortMap (ctx context.Context , sshClient * ssh.Client , remoteEndpoint , localEndpoint netip.AddrPort , done chan struct {}) error {
71
- // Listen on remote server port
72
- var lc net.ListenConfig
73
- listen , err := lc .Listen (ctx , "tcp" , localEndpoint .String ())
74
- if err != nil {
75
- return err
76
- }
77
- defer listen .Close ()
70
+ func PortMap (ctx context.Context , sshClient * ssh.Client , remoteEndpoint netip.AddrPort , local net.Listener ) error {
71
+ defer sshClient .Close ()
78
72
79
- select {
80
- case done <- struct {}{}:
81
- default :
82
- }
73
+ errChan := make (chan error , 1 )
83
74
// handle incoming connections on reverse forwarded tunnel
84
- for ctx .Err () == nil {
85
- localConn , err := listen .Accept ()
75
+ for {
76
+ select {
77
+ case <- ctx .Done ():
78
+ return ctx .Err ()
79
+ case err := <- errChan :
80
+ return err
81
+ default :
82
+ }
83
+
84
+ localConn , err := local .Accept ()
86
85
if err != nil {
87
86
return err
88
87
}
89
- go func (localConn net.Conn ) {
90
- defer localConn .Close ()
91
- remoteConn , err := sshClient .Dial ("tcp" , remoteEndpoint .String ())
88
+ go func () {
89
+ err := func () error {
90
+ defer localConn .Close ()
91
+ remoteConn , err := sshClient .Dial ("tcp" , remoteEndpoint .String ())
92
+ if err != nil {
93
+ log .Errorf ("Failed to dial %s: %s" , remoteEndpoint .String (), err )
94
+ return err
95
+ }
96
+ defer remoteConn .Close ()
97
+ copyStream (localConn , remoteConn )
98
+ return nil
99
+ }()
92
100
if err != nil {
93
- log .Errorf ("Failed to dial %s: %s" , remoteEndpoint .String (), err )
94
- return
101
+ select {
102
+ case errChan <- err :
103
+ default :
104
+ }
95
105
}
96
- defer remoteConn .Close ()
97
- copyStream (localConn , remoteConn )
98
- }(localConn )
106
+ }()
99
107
}
100
- return ctx .Err ()
101
108
}
102
109
103
110
// DialSshRemote https://github.com/golang/go/issues/21478
@@ -225,12 +232,12 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
225
232
return ssh .PublicKeys (key ), nil
226
233
}
227
234
228
- func copyStream (client net.Conn , remote net.Conn ) {
235
+ func copyStream (local net.Conn , remote net.Conn ) {
229
236
chDone := make (chan bool , 2 )
230
237
231
238
// start remote -> local data transfer
232
239
go func () {
233
- _ , err := io .Copy (client , remote )
240
+ _ , err := io .Copy (local , remote )
234
241
if err != nil && ! errors .Is (err , net .ErrClosed ) {
235
242
log .Debugf ("error while copy remote->local: %s" , err )
236
243
}
@@ -242,7 +249,7 @@ func copyStream(client net.Conn, remote net.Conn) {
242
249
243
250
// start local -> remote data transfer
244
251
go func () {
245
- _ , err := io .Copy (remote , client )
252
+ _ , err := io .Copy (remote , local )
246
253
if err != nil && ! errors .Is (err , net .ErrClosed ) {
247
254
log .Debugf ("error while copy local->remote: %s" , err )
248
255
}
@@ -388,22 +395,44 @@ func init() {
388
395
})
389
396
}
390
397
391
- func PortMapUntil (ctx context.Context , cli * ssh.Client , remote , local netip.AddrPort ) error {
398
+ func PortMapUntil (ctx context.Context , conf * SshConfig , remote , local netip.AddrPort ) error {
399
+ // Listen on remote server port
400
+ var lc net.ListenConfig
401
+ localListen , err := lc .Listen (ctx , "tcp" , local .String ())
402
+ if err != nil {
403
+ return err
404
+ }
405
+ defer localListen .Close ()
406
+
392
407
errChan := make (chan error , 1 )
393
408
readyChan := make (chan struct {}, 1 )
394
409
go func () {
395
410
for ctx .Err () == nil {
396
- err := PortMap (ctx , cli , remote , local , readyChan )
397
- if err != nil {
398
- if ! errors .Is (err , context .Canceled ) {
399
- log .Errorf ("Ssh forward failed err: %v" , err )
411
+ func () {
412
+ ctx2 , cancelFunc := context .WithCancel (ctx )
413
+ defer cancelFunc ()
414
+ cli , err := DialSshRemote (ctx2 , conf )
415
+ if err != nil {
416
+ time .Sleep (time .Second * 2 )
417
+ log .Errorf ("failed to dial remote ssh server: %v" , err )
418
+ return
400
419
}
401
420
select {
402
- case errChan <- err :
421
+ case readyChan <- struct {}{} :
403
422
default :
404
423
}
405
- }
406
- time .Sleep (time .Second * 2 )
424
+ err = PortMap (ctx , cli , remote , localListen )
425
+ if err != nil {
426
+ if ! errors .Is (err , context .Canceled ) {
427
+ log .Errorf ("Ssh forward failed err: %v" , err )
428
+ }
429
+ select {
430
+ case errChan <- err :
431
+ default :
432
+ }
433
+ }
434
+ time .Sleep (time .Second * 2 )
435
+ }()
407
436
}
408
437
}()
409
438
select {
0 commit comments