diff --git a/apps/confidential/confidential_test.go b/apps/confidential/confidential_test.go index c91731ca..43bea10e 100644 --- a/apps/confidential/confidential_test.go +++ b/apps/confidential/confidential_test.go @@ -10,7 +10,6 @@ import ( "crypto/x509" "encoding/base64" "encoding/json" - "errors" "fmt" "io" "net/http" @@ -25,6 +24,7 @@ import ( "github.com/kylelemons/godebug/pretty" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache" + "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported" internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time" "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock" @@ -413,6 +413,41 @@ func TestAcquireTokenByAuthCode(t *testing.T) { } } +func TestInvalidJsonErrFromResponse(t *testing.T) { + cred, err := NewCredFromSecret(fakeSecret) + if err != nil { + t.Fatal(err) + } + tenant := "A" + lmo := "login.microsoftonline.com" + mockClient := mock.Client{} + mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant))) + client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + ctx := context.Background() + // cache an access token for each tenant. To simplify determining their provenance below, the value of each token is the ID of the tenant that provided it. + if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err == nil { + t.Fatal("silent auth should fail because the cache is empty") + } + mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant))) + body := fmt.Sprintf( + `{"access_token": "%s","expires_in": %d,"expires_on": %d,"token_type": "Bearer"`, + tenant, 3600, time.Now().Add(time.Duration(3600)*time.Second).Unix(), + ) + // body += "}" + mockClient.AppendResponse(mock.WithBody([]byte(body))) + _, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant)) + if err == nil { + t.Fatal("should have failed with InvalidJsonErr Response") + } + var ie errors.InvalidJsonErr + if !errors.As(err, &ie) { + t.Fatal("should have revieved a InvalidJsonErr, but got", err) + } +} + func TestAcquireTokenSilentTenants(t *testing.T) { cred, err := NewCredFromSecret(fakeSecret) if err != nil { diff --git a/apps/errors/errors.go b/apps/errors/errors.go index c9b8dbed..b5cbb572 100644 --- a/apps/errors/errors.go +++ b/apps/errors/errors.go @@ -64,11 +64,20 @@ type CallErr struct { Err error } +type InvalidJsonErr struct { + Err error +} + // Errors implements error.Error(). func (e CallErr) Error() string { return e.Err.Error() } +// Errors implements error.Error(). +func (e InvalidJsonErr) Error() string { + return e.Err.Error() +} + // Verbose prints a versbose error message with the request or response. func (e CallErr) Verbose() string { e.Resp.Request = nil // This brings in a bunch of TLS crap we don't need diff --git a/apps/internal/oauth/ops/accesstokens/accesstokens.go b/apps/internal/oauth/ops/accesstokens/accesstokens.go index a7b7b074..71275b32 100644 --- a/apps/internal/oauth/ops/accesstokens/accesstokens.go +++ b/apps/internal/oauth/ops/accesstokens/accesstokens.go @@ -262,11 +262,7 @@ func (c Client) FromClientSecret(ctx context.Context, authParameters authority.A qv.Set(clientID, authParameters.ClientID) addScopeQueryParam(qv, authParameters) - token, err := c.doTokenResp(ctx, authParameters, qv) - if err != nil { - return token, fmt.Errorf("FromClientSecret(): %w", err) - } - return token, nil + return c.doTokenResp(ctx, authParameters, qv) } func (c Client) FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (TokenResponse, error) { @@ -281,11 +277,7 @@ func (c Client) FromAssertion(ctx context.Context, authParameters authority.Auth qv.Set(clientInfo, clientInfoVal) addScopeQueryParam(qv, authParameters) - token, err := c.doTokenResp(ctx, authParameters, qv) - if err != nil { - return token, fmt.Errorf("FromAssertion(): %w", err) - } - return token, nil + return c.doTokenResp(ctx, authParameters, qv) } func (c Client) FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion string, clientSecret string) (TokenResponse, error) { diff --git a/apps/internal/oauth/ops/internal/comm/comm.go b/apps/internal/oauth/ops/internal/comm/comm.go index d62aac74..79068036 100644 --- a/apps/internal/oauth/ops/internal/comm/comm.go +++ b/apps/internal/oauth/ops/internal/comm/comm.go @@ -98,7 +98,7 @@ func (c *Client) JSONCall(ctx context.Context, endpoint string, headers http.Hea if resp != nil { if err := unmarshal(data, resp); err != nil { - return fmt.Errorf("json decode error: %w\njson message bytes were: %s", err, string(data)) + return errors.InvalidJsonErr{Err: fmt.Errorf("json decode error: %w\njson message bytes were: %s", err, string(data))} } } return nil @@ -221,7 +221,7 @@ func (c *Client) URLFormCall(ctx context.Context, endpoint string, qv url.Values } if resp != nil { if err := unmarshal(data, resp); err != nil { - return fmt.Errorf("json decode error: %w\nraw message was: %s", err, string(data)) + return errors.InvalidJsonErr{Err: fmt.Errorf("json decode error: %w\nraw message was: %s", err, string(data))} } } return nil diff --git a/apps/managedidentity/managedidentity.go b/apps/managedidentity/managedidentity.go index aea4c1a2..4e274869 100644 --- a/apps/managedidentity/managedidentity.go +++ b/apps/managedidentity/managedidentity.go @@ -523,6 +523,11 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto } err = json.Unmarshal(responseBytes, &r) + if err != nil { + return r, errors.InvalidJsonErr{ + Err: fmt.Errorf("error parsing the json error: %s", err), + } + } r.GrantedScopes.Slice = append(r.GrantedScopes.Slice, resource) return r, err diff --git a/apps/managedidentity/managedidentity_test.go b/apps/managedidentity/managedidentity_test.go index 03923190..ce9e8cfc 100644 --- a/apps/managedidentity/managedidentity_test.go +++ b/apps/managedidentity/managedidentity_test.go @@ -468,6 +468,39 @@ func TestIMDSAcquireTokenReturnsTokenSuccess(t *testing.T) { } } +func TestInvalidJsonErrReturnOnAcquireToken(t *testing.T) { + resource := "https://resource/.default" + miType := SystemAssigned() + + setEnvVars(t, DefaultToIMDS) + mockClient := mock.Client{} + responseBody := fmt.Sprintf( + `{"access_token": "%s","expires_in": %d,"expires_on": %d,"token_type": "Bearer"`, + "tetant", 3600, time.Now().Add(time.Duration(3600)*time.Second).Unix(), + ) + + mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody([]byte(responseBody))) + + // Resetting cache + before := cacheManager + defer func() { cacheManager = before }() + cacheManager = storage.New(nil) + + client, err := New(miType, WithHTTPClient(&mockClient)) + if err != nil { + t.Fatal(err) + } + + _, err = client.AcquireToken(context.Background(), resource) + if err == nil { + t.Fatal("should have failed with InvalidJsonErr Response") + } + var ie errors.InvalidJsonErr + if !errors.As(err, &ie) { + t.Fatal("should have revieved a InvalidJsonErr, but got", err) + } +} + func TestCloudShellAcquireTokenReturnsTokenSuccess(t *testing.T) { resource := "https://resource/.default" miType := SystemAssigned()