@@ -67,39 +67,6 @@ func (s *SshConfig) ToRPC() *rpc.SshJump {
67
67
}
68
68
}
69
69
70
- func PortMap (ctx context.Context , sshClient * ssh.Client , remoteEndpoint , localEndpoint netip.AddrPort , done chan struct {}) error {
71
- // Listen on remote server port
72
- var lc net.ListenConfig
73
- listen , err := lc .Listen (ctx , "tcp" , localEndpoint .String ())
74
- if err != nil {
75
- return err
76
- }
77
- defer listen .Close ()
78
-
79
- select {
80
- case done <- struct {}{}:
81
- default :
82
- }
83
- // handle incoming connections on reverse forwarded tunnel
84
- for ctx .Err () == nil {
85
- localConn , err := listen .Accept ()
86
- if err != nil {
87
- return err
88
- }
89
- go func (localConn net.Conn ) {
90
- defer localConn .Close ()
91
- remoteConn , err := sshClient .Dial ("tcp" , remoteEndpoint .String ())
92
- if err != nil {
93
- log .Errorf ("Failed to dial %s: %s" , remoteEndpoint .String (), err )
94
- return
95
- }
96
- defer remoteConn .Close ()
97
- copyStream (localConn , remoteConn )
98
- }(localConn )
99
- }
100
- return ctx .Err ()
101
- }
102
-
103
70
// DialSshRemote https://github.com/golang/go/issues/21478
104
71
func DialSshRemote (ctx context.Context , conf * SshConfig ) (* ssh.Client , error ) {
105
72
var remote * ssh.Client
@@ -170,6 +137,7 @@ func DialSshRemote(ctx context.Context, conf *SshConfig) (*ssh.Client, error) {
170
137
_ , _ , er := remote .SendRequest ("keepalive@golang.org" , true , nil )
171
138
if er != nil {
172
139
log .Errorf ("failed to send keep alive error: %s" , er )
140
+ return
173
141
}
174
142
}
175
143
}
@@ -225,12 +193,12 @@ func publicKeyFile(file string) (ssh.AuthMethod, error) {
225
193
return ssh .PublicKeys (key ), nil
226
194
}
227
195
228
- func copyStream (client net.Conn , remote net.Conn ) {
196
+ func copyStream (local net.Conn , remote net.Conn ) {
229
197
chDone := make (chan bool , 2 )
230
198
231
199
// start remote -> local data transfer
232
200
go func () {
233
- _ , err := io .Copy (client , remote )
201
+ _ , err := io .Copy (local , remote )
234
202
if err != nil && ! errors .Is (err , net .ErrClosed ) {
235
203
log .Debugf ("error while copy remote->local: %s" , err )
236
204
}
@@ -242,7 +210,7 @@ func copyStream(client net.Conn, remote net.Conn) {
242
210
243
211
// start local -> remote data transfer
244
212
go func () {
245
- _ , err := io .Copy (remote , client )
213
+ _ , err := io .Copy (remote , local )
246
214
if err != nil && ! errors .Is (err , net .ErrClosed ) {
247
215
log .Debugf ("error while copy local->remote: %s" , err )
248
216
}
@@ -388,31 +356,65 @@ func init() {
388
356
})
389
357
}
390
358
391
- func PortMapUntil (ctx context.Context , cli * ssh.Client , remote , local netip.AddrPort ) error {
392
- errChan := make (chan error , 1 )
393
- readyChan := make (chan struct {}, 1 )
359
+ func PortMapUntil (ctx context.Context , conf * SshConfig , remote , local netip.AddrPort ) error {
360
+ // Listen on remote server port
361
+ var lc net.ListenConfig
362
+ localListen , err := lc .Listen (ctx , "tcp" , local .String ())
363
+ if err != nil {
364
+ return err
365
+ }
366
+
367
+ var lock sync.Mutex
368
+ var cancelFunc context.CancelFunc
369
+ var sshClient * ssh.Client
370
+
371
+ var getRemoteConnFunc = func () (net.Conn , error ) {
372
+ lock .Lock ()
373
+ defer lock .Unlock ()
374
+
375
+ if sshClient != nil {
376
+ remoteConn , err := sshClient .Dial ("tcp" , remote .String ())
377
+ if err == nil {
378
+ return remoteConn , nil
379
+ }
380
+ sshClient .Close ()
381
+ if cancelFunc != nil {
382
+ cancelFunc ()
383
+ }
384
+ }
385
+ var ctx2 context.Context
386
+ ctx2 , cancelFunc = context .WithCancel (ctx )
387
+ sshClient , err = DialSshRemote (ctx2 , conf )
388
+ if err != nil {
389
+ cancelFunc ()
390
+ cancelFunc = nil
391
+ log .Errorf ("failed to dial remote ssh server: %v" , err )
392
+ return nil , err
393
+ }
394
+ return sshClient .Dial ("tcp" , remote .String ())
395
+ }
396
+
394
397
go func () {
398
+ defer localListen .Close ()
399
+
395
400
for ctx .Err () == nil {
396
- err := PortMap ( ctx , cli , remote , local , readyChan )
401
+ localConn , err := localListen . Accept ( )
397
402
if err != nil {
398
- if ! errors .Is (err , context .Canceled ) {
399
- log .Errorf ("Ssh forward failed err: %v" , err )
400
- }
401
- select {
402
- case errChan <- err :
403
- default :
404
- }
403
+ log .Errorf ("failed to accept conn: %v" , err )
404
+ return
405
405
}
406
- time .Sleep (time .Second * 2 )
406
+ go func () {
407
+ defer localConn .Close ()
408
+
409
+ remoteConn , err := getRemoteConnFunc ()
410
+ if err != nil {
411
+ log .Errorf ("Failed to dial %s: %s" , remote .String (), err )
412
+ return
413
+ }
414
+ defer remoteConn .Close ()
415
+ copyStream (localConn , remoteConn )
416
+ }()
407
417
}
408
418
}()
409
- select {
410
- case <- readyChan :
411
- return nil
412
- case err := <- errChan :
413
- log .Errorf ("Ssh forward failed err: %v" , err )
414
- return err
415
- case <- ctx .Done ():
416
- return ctx .Err ()
417
- }
419
+ return nil
418
420
}
0 commit comments