Skip to content

Commit d6923bb

Browse files
committed
Merge tag 'v1.80.3' into sunos-1.80
Release 1.80.3
2 parents 10ccacf + bd762b8 commit d6923bb

File tree

6 files changed

+197
-35
lines changed

6 files changed

+197
-35
lines changed

VERSION.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
1.80.2
1+
1.80.3

appc/appconnector.go

+13-11
Original file line numberDiff line numberDiff line change
@@ -289,9 +289,11 @@ func (e *AppConnector) updateDomains(domains []string) {
289289
toRemove = append(toRemove, netip.PrefixFrom(a, a.BitLen()))
290290
}
291291
}
292-
if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil {
293-
e.logf("failed to unadvertise routes on domain removal: %v: %v: %v", slicesx.MapKeys(oldDomains), toRemove, err)
294-
}
292+
e.queue.Add(func() {
293+
if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil {
294+
e.logf("failed to unadvertise routes on domain removal: %v: %v: %v", slicesx.MapKeys(oldDomains), toRemove, err)
295+
}
296+
})
295297
}
296298

297299
e.logf("handling domains: %v and wildcards: %v", slicesx.MapKeys(e.domains), e.wildcards)
@@ -310,11 +312,6 @@ func (e *AppConnector) updateRoutes(routes []netip.Prefix) {
310312
return
311313
}
312314

313-
if err := e.routeAdvertiser.AdvertiseRoute(routes...); err != nil {
314-
e.logf("failed to advertise routes: %v: %v", routes, err)
315-
return
316-
}
317-
318315
var toRemove []netip.Prefix
319316

320317
// If we're storing routes and know e.controlRoutes is a good
@@ -338,9 +335,14 @@ nextRoute:
338335
}
339336
}
340337

341-
if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil {
342-
e.logf("failed to unadvertise routes: %v: %v", toRemove, err)
343-
}
338+
e.queue.Add(func() {
339+
if err := e.routeAdvertiser.AdvertiseRoute(routes...); err != nil {
340+
e.logf("failed to advertise routes: %v: %v", routes, err)
341+
}
342+
if err := e.routeAdvertiser.UnadvertiseRoute(toRemove...); err != nil {
343+
e.logf("failed to unadvertise routes: %v: %v", toRemove, err)
344+
}
345+
})
344346

345347
e.controlRoutes = routes
346348
if err := e.storeRoutesLocked(); err != nil {

appc/appconnector_test.go

+58
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"net/netip"
99
"reflect"
1010
"slices"
11+
"sync/atomic"
1112
"testing"
1213
"time"
1314

@@ -86,6 +87,7 @@ func TestUpdateRoutes(t *testing.T) {
8687

8788
routes := []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24"), netip.MustParsePrefix("192.0.0.1/32")}
8889
a.updateRoutes(routes)
90+
a.Wait(ctx)
8991

9092
slices.SortFunc(rc.Routes(), prefixCompare)
9193
rc.SetRoutes(slices.Compact(rc.Routes()))
@@ -105,6 +107,7 @@ func TestUpdateRoutes(t *testing.T) {
105107
}
106108

107109
func TestUpdateRoutesUnadvertisesContainedRoutes(t *testing.T) {
110+
ctx := context.Background()
108111
for _, shouldStore := range []bool{false, true} {
109112
rc := &appctest.RouteCollector{}
110113
var a *AppConnector
@@ -117,6 +120,7 @@ func TestUpdateRoutesUnadvertisesContainedRoutes(t *testing.T) {
117120
rc.SetRoutes([]netip.Prefix{netip.MustParsePrefix("192.0.2.1/32")})
118121
routes := []netip.Prefix{netip.MustParsePrefix("192.0.2.0/24")}
119122
a.updateRoutes(routes)
123+
a.Wait(ctx)
120124

121125
if !slices.EqualFunc(routes, rc.Routes(), prefixEqual) {
122126
t.Fatalf("got %v, want %v", rc.Routes(), routes)
@@ -636,3 +640,57 @@ func TestMetricBucketsAreSorted(t *testing.T) {
636640
t.Errorf("metricStoreRoutesNBuckets must be in order")
637641
}
638642
}
643+
644+
// TestUpdateRoutesDeadlock is a regression test for a deadlock in
645+
// LocalBackend<->AppConnector interaction. When using real LocalBackend as the
646+
// routeAdvertiser, calls to Advertise/UnadvertiseRoutes can end up calling
647+
// back into AppConnector via authReconfig. If everything is called
648+
// synchronously, this results in a deadlock on AppConnector.mu.
649+
func TestUpdateRoutesDeadlock(t *testing.T) {
650+
ctx := context.Background()
651+
rc := &appctest.RouteCollector{}
652+
a := NewAppConnector(t.Logf, rc, &RouteInfo{}, fakeStoreRoutes)
653+
654+
advertiseCalled := new(atomic.Bool)
655+
unadvertiseCalled := new(atomic.Bool)
656+
rc.AdvertiseCallback = func() {
657+
// Call something that requires a.mu to be held.
658+
a.DomainRoutes()
659+
advertiseCalled.Store(true)
660+
}
661+
rc.UnadvertiseCallback = func() {
662+
// Call something that requires a.mu to be held.
663+
a.DomainRoutes()
664+
unadvertiseCalled.Store(true)
665+
}
666+
667+
a.updateDomains([]string{"example.com"})
668+
a.Wait(ctx)
669+
670+
// Trigger rc.AdveriseRoute.
671+
a.updateRoutes(
672+
[]netip.Prefix{
673+
netip.MustParsePrefix("127.0.0.1/32"),
674+
netip.MustParsePrefix("127.0.0.2/32"),
675+
},
676+
)
677+
a.Wait(ctx)
678+
// Trigger rc.UnadveriseRoute.
679+
a.updateRoutes(
680+
[]netip.Prefix{
681+
netip.MustParsePrefix("127.0.0.1/32"),
682+
},
683+
)
684+
a.Wait(ctx)
685+
686+
if !advertiseCalled.Load() {
687+
t.Error("AdvertiseRoute was not called")
688+
}
689+
if !unadvertiseCalled.Load() {
690+
t.Error("UnadvertiseRoute was not called")
691+
}
692+
693+
if want := []netip.Prefix{netip.MustParsePrefix("127.0.0.1/32")}; !slices.Equal(slices.Compact(rc.Routes()), want) {
694+
t.Fatalf("got %v, want %v", rc.Routes(), want)
695+
}
696+
}

appc/appctest/appctest.go

+13
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,22 @@ import (
1111

1212
// RouteCollector is a test helper that collects the list of routes advertised
1313
type RouteCollector struct {
14+
// AdvertiseCallback (optional) is called synchronously from
15+
// AdvertiseRoute.
16+
AdvertiseCallback func()
17+
// UnadvertiseCallback (optional) is called synchronously from
18+
// UnadvertiseRoute.
19+
UnadvertiseCallback func()
20+
1421
routes []netip.Prefix
1522
removedRoutes []netip.Prefix
1623
}
1724

1825
func (rc *RouteCollector) AdvertiseRoute(pfx ...netip.Prefix) error {
1926
rc.routes = append(rc.routes, pfx...)
27+
if rc.AdvertiseCallback != nil {
28+
rc.AdvertiseCallback()
29+
}
2030
return nil
2131
}
2232

@@ -30,6 +40,9 @@ func (rc *RouteCollector) UnadvertiseRoute(toRemove ...netip.Prefix) error {
3040
rc.removedRoutes = append(rc.removedRoutes, r)
3141
}
3242
}
43+
if rc.UnadvertiseCallback != nil {
44+
rc.UnadvertiseCallback()
45+
}
3346
return nil
3447
}
3548

client/web/web.go

+30-23
Original file line numberDiff line numberDiff line change
@@ -203,15 +203,25 @@ func NewServer(opts ServerOpts) (s *Server, err error) {
203203
}
204204
s.assetsHandler, s.assetsCleanup = assetsHandler(s.devMode)
205205

206-
var metric string // clientmetric to report on startup
206+
var metric string
207+
s.apiHandler, metric = s.modeAPIHandler(s.mode)
208+
s.apiHandler = s.withCSRF(s.apiHandler)
207209

208-
// Create handler for "/api" requests with CSRF protection.
209-
// We don't require secure cookies, since the web client is regularly used
210-
// on network appliances that are served on local non-https URLs.
211-
// The client is secured by limiting the interface it listens on,
212-
// or by authenticating requests before they reach the web client.
210+
// Don't block startup on reporting metric.
211+
// Report in separate go routine with 5 second timeout.
212+
go func() {
213+
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
214+
defer cancel()
215+
s.lc.IncrementCounter(ctx, metric, 1)
216+
}()
217+
218+
return s, nil
219+
}
220+
221+
func (s *Server) withCSRF(h http.Handler) http.Handler {
213222
csrfProtect := csrf.Protect(s.csrfKey(), csrf.Secure(false))
214223

224+
// ref https://github.com/tailscale/tailscale/pull/14822
215225
// signal to the CSRF middleware that the request is being served over
216226
// plaintext HTTP to skip TLS-only header checks.
217227
withSetPlaintext := func(h http.Handler) http.Handler {
@@ -221,27 +231,24 @@ func NewServer(opts ServerOpts) (s *Server, err error) {
221231
})
222232
}
223233

224-
switch s.mode {
234+
// NB: the order of the withSetPlaintext and csrfProtect calls is important
235+
// to ensure that we signal to the CSRF middleware that the request is being
236+
// served over plaintext HTTP and not over TLS as it presumes by default.
237+
return withSetPlaintext(csrfProtect(h))
238+
}
239+
240+
func (s *Server) modeAPIHandler(mode ServerMode) (http.Handler, string) {
241+
switch mode {
225242
case LoginServerMode:
226-
s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI)))
227-
metric = "web_login_client_initialization"
243+
return http.HandlerFunc(s.serveLoginAPI), "web_login_client_initialization"
228244
case ReadOnlyServerMode:
229-
s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveLoginAPI)))
230-
metric = "web_readonly_client_initialization"
245+
return http.HandlerFunc(s.serveLoginAPI), "web_readonly_client_initialization"
231246
case ManageServerMode:
232-
s.apiHandler = csrfProtect(withSetPlaintext(http.HandlerFunc(s.serveAPI)))
233-
metric = "web_client_initialization"
247+
return http.HandlerFunc(s.serveAPI), "web_client_initialization"
248+
default: // invalid mode
249+
log.Fatalf("invalid mode: %v", mode)
234250
}
235-
236-
// Don't block startup on reporting metric.
237-
// Report in separate go routine with 5 second timeout.
238-
go func() {
239-
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
240-
defer cancel()
241-
s.lc.IncrementCounter(ctx, metric, 1)
242-
}()
243-
244-
return s, nil
251+
return nil, ""
245252
}
246253

247254
func (s *Server) Shutdown() {

client/web/web_test.go

+82
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"fmt"
1212
"io"
1313
"net/http"
14+
"net/http/cookiejar"
1415
"net/http/httptest"
1516
"net/netip"
1617
"net/url"
@@ -20,6 +21,7 @@ import (
2021
"time"
2122

2223
"github.com/google/go-cmp/cmp"
24+
"github.com/gorilla/csrf"
2325
"tailscale.com/client/tailscale"
2426
"tailscale.com/client/tailscale/apitype"
2527
"tailscale.com/ipn"
@@ -1477,3 +1479,83 @@ func mockWaitAuthURL(_ context.Context, id string, src tailcfg.NodeID) (*tailcfg
14771479
return nil, errors.New("unknown id")
14781480
}
14791481
}
1482+
1483+
func TestCSRFProtect(t *testing.T) {
1484+
s := &Server{}
1485+
1486+
mux := http.NewServeMux()
1487+
mux.HandleFunc("GET /test/csrf-token", func(w http.ResponseWriter, r *http.Request) {
1488+
token := csrf.Token(r)
1489+
_, err := io.WriteString(w, token)
1490+
if err != nil {
1491+
t.Fatal(err)
1492+
}
1493+
})
1494+
mux.HandleFunc("POST /test/csrf-protected", func(w http.ResponseWriter, r *http.Request) {
1495+
_, err := io.WriteString(w, "ok")
1496+
if err != nil {
1497+
t.Fatal(err)
1498+
}
1499+
})
1500+
h := s.withCSRF(mux)
1501+
ser := httptest.NewServer(h)
1502+
defer ser.Close()
1503+
1504+
jar, err := cookiejar.New(nil)
1505+
if err != nil {
1506+
t.Fatalf("unable to construct cookie jar: %v", err)
1507+
}
1508+
1509+
client := ser.Client()
1510+
client.Jar = jar
1511+
1512+
// make GET request to populate cookie jar
1513+
resp, err := client.Get(ser.URL + "/test/csrf-token")
1514+
if err != nil {
1515+
t.Fatalf("unable to make request: %v", err)
1516+
}
1517+
defer resp.Body.Close()
1518+
if resp.StatusCode != http.StatusOK {
1519+
t.Fatalf("unexpected status: %v", resp.Status)
1520+
}
1521+
tokenBytes, err := io.ReadAll(resp.Body)
1522+
if err != nil {
1523+
t.Fatalf("unable to read body: %v", err)
1524+
}
1525+
1526+
csrfToken := strings.TrimSpace(string(tokenBytes))
1527+
if csrfToken == "" {
1528+
t.Fatal("empty csrf token")
1529+
}
1530+
1531+
// make a POST request without the CSRF header; ensure it fails
1532+
resp, err = client.Post(ser.URL+"/test/csrf-protected", "text/plain", nil)
1533+
if err != nil {
1534+
t.Fatalf("unable to make request: %v", err)
1535+
}
1536+
if resp.StatusCode != http.StatusForbidden {
1537+
t.Fatalf("unexpected status: %v", resp.Status)
1538+
}
1539+
1540+
// make a POST request with the CSRF header; ensure it succeeds
1541+
req, err := http.NewRequest("POST", ser.URL+"/test/csrf-protected", nil)
1542+
if err != nil {
1543+
t.Fatalf("error building request: %v", err)
1544+
}
1545+
req.Header.Set("X-CSRF-Token", csrfToken)
1546+
resp, err = client.Do(req)
1547+
if err != nil {
1548+
t.Fatalf("unable to make request: %v", err)
1549+
}
1550+
if resp.StatusCode != http.StatusOK {
1551+
t.Fatalf("unexpected status: %v", resp.Status)
1552+
}
1553+
defer resp.Body.Close()
1554+
out, err := io.ReadAll(resp.Body)
1555+
if err != nil {
1556+
t.Fatalf("unable to read body: %v", err)
1557+
}
1558+
if string(out) != "ok" {
1559+
t.Fatalf("unexpected body: %q", out)
1560+
}
1561+
}

0 commit comments

Comments
 (0)