Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
…ation-library-for-go into 4gust/force-token-refresh
  • Loading branch information
4gust committed Feb 17, 2025
2 parents 2700b64 + e6d9244 commit 87e17fa
Show file tree
Hide file tree
Showing 35 changed files with 2,627 additions and 105 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/go.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ jobs:
run: go build ./apps/...

- name: Unit Tests
run: go test -race -short ./apps/cache/... ./apps/confidential/... ./apps/public/... ./apps/internal/...
run: go test -race -short ./apps/cache/... ./apps/confidential/... ./apps/public/... ./apps/internal/... ./apps/managedidentity/...
# Intergration tests runs on ADO
# - name: Integration Tests
# run: go test -race ./apps/tests/integration/...
Expand Down
32 changes: 32 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,28 @@ Acquiring tokens with MSAL Go follows this general pattern. There might be some
}
confidentialClient, err := confidential.New("https://login.microsoftonline.com/your_tenant", "client_id", cred)
```
* Initializing a Managed Identity client for SystemAssigned:

```go
import mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
// Managed identity client have a type of ID required, SystemAssigned or UserAssigned
miSystemAssigned, err := mi.New(mi.SystemAssigned())
if err != nil {
// TODO: handle error
}
```
* Initializing a Managed Identity client for UserAssigned:

```go
import mi "github.com/AzureAD/microsoft-authentication-library-for-go/apps/managedidentity"
// Managed identity client have a type of ID required, SystemAssigned or UserAssigned
miSystemAssigned, err := mi.New(mi.UserAssignedClientID("YOUR_CLIENT_ID"))
if err != nil {
// TODO: handle error
}
```

1. Call `AcquireTokenSilent()` to look for a cached token. If `AcquireTokenSilent()` returns an error, call another `AcquireToken...` method to authenticate.

Expand Down Expand Up @@ -96,6 +118,16 @@ Acquiring tokens with MSAL Go follows this general pattern. There might be some
accessToken := result.AccessToken
```

* ManagedIdentity clietn can simply call `AcquireToken()`:
```go
resource := "<Your resource>"
result, err := miSystemAssigned.AcquireToken(context.TODO(), resource)
if err != nil {
// TODO: handle error
}
accessToken := result.AccessToken
```

## Community Help and Support

We use [Stack Overflow](http://stackoverflow.com/questions/tagged/msal) to work with the community on supporting Azure Active Directory and its SDKs, including this one! We highly recommend you ask your questions on Stack Overflow (we're all on there!) Also browse existing issues to see if someone has had your question before. Please use the "msal" tag when asking your questions.
Expand Down
4 changes: 3 additions & 1 deletion apps/confidential/confidential.go
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,9 @@ func WithInstanceDiscovery(enabled bool) Option {
// If an invalid region name is provided, the non-regional endpoint MIGHT be used or the token request MIGHT fail.
func WithAzureRegion(val string) Option {
return func(o *clientOptions) {
o.azureRegion = val
if val != "" {
o.azureRegion = val
}
}
}

Expand Down
150 changes: 100 additions & 50 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -26,6 +25,7 @@ import (
"github.com/kylelemons/godebug/pretty"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
Expand All @@ -37,6 +37,7 @@ import (

// errorClient is an HTTP client for tests that should fail when confidential.Client sends a request
type errorClient struct{}
type contextKey struct{}

func (*errorClient) Do(req *http.Request) (*http.Response, error) {
return nil, fmt.Errorf("expected no requests but received one for %s", req.URL.String())
Expand Down Expand Up @@ -141,7 +142,7 @@ func TestAcquireTokenByCredential(t *testing.T) {
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
RefreshOn: internalTime.DurationTime{T: time.Now().Add(6 * time.Hour)},
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(12 * time.Hour)},
ExpiresOn: time.Now().Add(12 * time.Hour),
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(12 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "Bearer",
Expand Down Expand Up @@ -198,50 +199,65 @@ func TestRegionAutoEnable_EmptyRegion_EnvRegion(t *testing.T) {
}
}

func TestRegionAutoEnable_SpecifiedRegion_EnvRegion(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}

envRegion := "envRegion"
err = os.Setenv("MSAL_FORCE_REGION", envRegion)
if err != nil {
t.Fatal(err)
}
defer os.Unsetenv("MSAL_FORCE_REGION")

lmo := "login.microsoftonline.com"
tenant := "tenant"
mockClient := mock.Client{}
testRegion := "region"
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(testRegion))
if err != nil {
t.Fatal(err)
}

if client.base.AuthParams.AuthorityInfo.Region != testRegion {
t.Fatalf("wanted %q, got %q", testRegion, client.base.AuthParams.AuthorityInfo.Region)
func TestRegionAutoEnable_SpecifiedEmptyRegion_EnvRegion(t *testing.T) {
tests := []struct {
name string
envRegion string
region string
resultRegion string
}{
{
name: "Region is empty, envRegion is set",
envRegion: "region",
region: "",
resultRegion: "region",
},
{
name: "Region is set, envRegion is set",
envRegion: "region",
region: "setRegion",
resultRegion: "setRegion",
},
{
name: "Region is set, envRegion is empty",
envRegion: "",
region: "setRegion",
resultRegion: "setRegion",
},
{
name: "Disable region is set, envRegion is set",
envRegion: "region",
region: "DisableMsalForceRegion",
resultRegion: "",
},
}
}

func TestRegionAutoEnable_DisableMsalForceRegion(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
if test.envRegion != "" {
t.Setenv("MSAL_FORCE_REGION", test.envRegion)
}
lmo := "login.microsoftonline.com"
tenant := "tenant"
mockClient := mock.Client{}

lmo := "login.microsoftonline.com"
tenant := "tenant"
mockClient := mock.Client{}
testRegion := "DisableMsalForceRegion"
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(testRegion))
if err != nil {
t.Fatal(err)
}
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient), WithAzureRegion(test.region))
if err != nil {
t.Fatal(err)
}

if client.base.AuthParams.AuthorityInfo.Region != "" {
t.Fatalf("wanted empty, got %q", client.base.AuthParams.AuthorityInfo.Region)
if test.resultRegion == "DisableMsalForceRegion" {
if client.base.AuthParams.AuthorityInfo.Region != "" {
t.Fatalf("wanted %q, got %q", test.resultRegion, client.base.AuthParams.AuthorityInfo.Region)
}
} else if client.base.AuthParams.AuthorityInfo.Region != test.resultRegion {
t.Fatalf("wanted %q, got %q", test.resultRegion, client.base.AuthParams.AuthorityInfo.Region)
}
})
}
}

Expand Down Expand Up @@ -293,7 +309,7 @@ func TestAcquireTokenOnBehalfOf(t *testing.T) {

func TestAcquireTokenByAssertionCallback(t *testing.T) {
calls := 0
key := struct{}{}
key := contextKey{}
ctx := context.WithValue(context.Background(), key, true)
getAssertion := func(c context.Context, o AssertionRequestOptions) (string, error) {
if v := c.Value(key); v == nil || !v.(bool) {
Expand Down Expand Up @@ -346,7 +362,7 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
tr := accesstokens.TokenResponse{
AccessToken: token,
RefreshToken: refresh,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExpiresOn: time.Now().Add(1 * time.Hour),
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
IDToken: accesstokens.IDToken{
Expand Down Expand Up @@ -415,6 +431,40 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
}
}

func TestInvalidJsonErrFromResponse(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
tenant := "A"
lmo := "login.microsoftonline.com"
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
// cache an access token for each tenant. To simplify determining their provenance below, the value of each token is the ID of the tenant that provided it.
if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err == nil {
t.Fatal("silent auth should fail because the cache is empty")
}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
body := fmt.Sprintf(
`{"access_token": "%s","expires_in": %d,"expires_on": %d,"token_type": "Bearer"`,
tenant, 3600, time.Now().Add(time.Duration(3600)*time.Second).Unix(),
)
mockClient.AppendResponse(mock.WithBody([]byte(body)))
_, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant))
if err == nil {
t.Fatal("should have failed with InvalidJsonErr Response")
}
var ie errors.InvalidJsonErr
if !errors.As(err, &ie) {
t.Fatal("should have revieved a InvalidJsonErr, but got", err)
}
}

func TestAcquireTokenSilentTenants(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
Expand Down Expand Up @@ -466,7 +516,7 @@ func TestADFSTokenCaching(t *testing.T) {
AccessToken: "at1",
RefreshToken: "rt",
TokenType: "bearer",
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExpiresOn: time.Now().Add(time.Hour),
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
IDToken: accesstokens.IDToken{
Expand Down Expand Up @@ -596,7 +646,7 @@ func TestNewCredFromCert(t *testing.T) {
t.Run(fmt.Sprintf("%s/%v", filepath.Base(file.path), sendX5c), func(t *testing.T) {
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(time.Hour)},
ExpiresOn: time.Now().Add(time.Hour),
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
}, cred, fakeAuthority, opts...)
if err != nil {
Expand Down Expand Up @@ -1220,7 +1270,7 @@ func TestWithClaims(t *testing.T) {
case "password":
ar, err = client.AcquireTokenByUsernamePassword(ctx, tokenScope, "username", "password", WithClaims(test.claims))
default:
t.Fatalf("test bug: no test for " + method)
t.Fatalf("test bug: no test for %s", method)
}
if err != nil {
t.Fatal(err)
Expand Down Expand Up @@ -1330,7 +1380,7 @@ func TestWithTenantID(t *testing.T) {
case "obo":
ar, err = client.AcquireTokenOnBehalfOf(ctx, "assertion", tokenScope, WithTenantID(test.tenant))
default:
t.Fatalf("test bug: no test for " + method)
t.Fatalf("test bug: no test for %s", method)
}
if err != nil {
if test.expectError {
Expand Down Expand Up @@ -1640,7 +1690,7 @@ func TestWithAuthenticationScheme(t *testing.T) {
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExpiresOn: time.Now().Add(1 * time.Hour),
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "TokenType",
Expand Down Expand Up @@ -1680,7 +1730,7 @@ func TestAcquireTokenByCredentialFromDSTS(t *testing.T) {
}
client, err := fakeClient(accesstokens.TokenResponse{
AccessToken: token,
ExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
ExpiresOn: time.Now().Add(1 * time.Hour),
ExtExpiresOn: internalTime.DurationTime{T: time.Now().Add(1 * time.Hour)},
GrantedScopes: accesstokens.Scopes{Slice: tokenScope},
TokenType: "Bearer",
Expand Down
9 changes: 9 additions & 0 deletions apps/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,20 @@ type CallErr struct {
Err error
}

type InvalidJsonErr struct {
Err error
}

// Errors implements error.Error().
func (e CallErr) Error() string {
return e.Err.Error()
}

// Errors implements error.Error().
func (e InvalidJsonErr) Error() string {
return e.Err.Error()
}

// Verbose prints a versbose error message with the request or response.
func (e CallErr) Verbose() string {
e.Resp.Request = nil // This brings in a bunch of TLS crap we don't need
Expand Down
5 changes: 2 additions & 3 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ import (

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/internal/storage"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/base/storage"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
Expand Down Expand Up @@ -113,7 +113,6 @@ func AuthResultFromStorage(storageTokenResponse storage.TokenResponse) (AuthResu
if err := storageTokenResponse.AccessToken.Validate(); err != nil {
return AuthResult{}, fmt.Errorf("problem with access token in StorageTokenResponse: %w", err)
}

account := storageTokenResponse.Account
accessToken := storageTokenResponse.AccessToken.Secret
grantedScopes := strings.Split(storageTokenResponse.AccessToken.Scopes, scopeSeparator)
Expand Down Expand Up @@ -149,7 +148,7 @@ func NewAuthResult(tokenResponse accesstokens.TokenResponse, account shared.Acco
Account: account,
IDToken: tokenResponse.IDToken,
AccessToken: tokenResponse.AccessToken,
ExpiresOn: tokenResponse.ExpiresOn.T,
ExpiresOn: tokenResponse.ExpiresOn,
GrantedScopes: tokenResponse.GrantedScopes.Slice,
Metadata: AuthResultMetadata{
TokenSource: IdentityProvider,
Expand Down
Loading

0 comments on commit 87e17fa

Please sign in to comment.