diff --git a/apps/confidential/confidential.go b/apps/confidential/confidential.go index 5b375794..22c17d20 100644 --- a/apps/confidential/confidential.go +++ b/apps/confidential/confidential.go @@ -305,7 +305,9 @@ func WithInstanceDiscovery(enabled bool) Option { // If an invalid region name is provided, the non-regional endpoint MIGHT be used or the token request MIGHT fail. func WithAzureRegion(val string) Option { return func(o *clientOptions) { - o.azureRegion = val + if val != "" { + o.azureRegion = val + } } } diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index 1036c4f7..8d6d61a9 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -195,50 +195,65 @@ func TestRegionAutoEnable_EmptyRegion_EnvRegion(t *testing.T) { } } -func TestRegionAutoEnable_SpecifiedRegion_EnvRegion(t *testing.T) { - cred, err := NewCredFromSecret(fakeSecret) - if err != nil { - t.Fatal(err) - } - - envRegion := "envRegion" - err = os.Setenv("MSAL_FORCE_REGION", envRegion) - if err != nil { - t.Fatal(err) - } - defer os.Unsetenv("MSAL_FORCE_REGION") - - lmo := "login.microsoftonline.com" - tenant := "tenant" - mockClient := mock.Client{} - testRegion := "region" - client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(testRegion)) - if err != nil { - t.Fatal(err) - } - - if client.base.AuthParams.AuthorityInfo.Region != testRegion { - t.Fatalf("wanted %q, got %q", testRegion, client.base.AuthParams.AuthorityInfo.Region) +func TestRegionAutoEnable_SpecifiedEmptyRegion_EnvRegion(t *testing.T) { + tests := []struct { + name string + envRegion string + region string + resultRegion string + }{ + { + name: "Region is empty, envRegion is set", + envRegion: "region", + region: "", + resultRegion: "region", + }, + { + name: "Region is set, envRegion is set", + envRegion: "region", + region: "setRegion", + resultRegion: "setRegion", + }, + { + name: "Region is set, envRegion is empty", + envRegion: "", + region: "setRegion", + resultRegion: "setRegion", + }, + { + name: "Disable region is set, envRegion is set", + envRegion: "region", + region: "DisableMsalForceRegion", + resultRegion: "", + }, } -} -func TestRegionAutoEnable_DisableMsalForceRegion(t *testing.T) { - cred, err := NewCredFromSecret(fakeSecret) - if err != nil { - t.Fatal(err) - } + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + if test.envRegion != "" { + t.Setenv("MSAL_FORCE_REGION", test.envRegion) + } + lmo := "login.microsoftonline.com" + tenant := "tenant" + mockClient := mock.Client{} - lmo := "login.microsoftonline.com" - tenant := "tenant" - mockClient := mock.Client{} - testRegion := "DisableMsalForceRegion" - client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(testRegion)) - if err != nil { - t.Fatal(err) - } + client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(test.region)) + if err != nil { + t.Fatal(err) + } - if client.base.AuthParams.AuthorityInfo.Region != "" { - t.Fatalf("wanted empty, got %q", client.base.AuthParams.AuthorityInfo.Region) + if test.resultRegion == "DisableMsalForceRegion" { + if client.base.AuthParams.AuthorityInfo.Region != "" { + t.Fatalf("wanted %q, got %q", test.resultRegion, client.base.AuthParams.AuthorityInfo.Region) + } + } else if client.base.AuthParams.AuthorityInfo.Region != test.resultRegion { + t.Fatalf("wanted %q, got %q", test.resultRegion, client.base.AuthParams.AuthorityInfo.Region) + } + }) } }