Skip to content

Commit

Permalink
Added a new error InvalidJsonErr
Browse files Browse the repository at this point in the history
  • Loading branch information
4gust committed Feb 14, 2025
1 parent f039f60 commit 9385bab
Show file tree
Hide file tree
Showing 6 changed files with 87 additions and 13 deletions.
37 changes: 36 additions & 1 deletion apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
Expand All @@ -25,6 +24,7 @@ import (
"github.com/kylelemons/godebug/pretty"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/cache"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/mock"
Expand Down Expand Up @@ -413,6 +413,41 @@ func TestAcquireTokenByAuthCode(t *testing.T) {
}
}

func TestInvalidJsonErrFromResponse(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
t.Fatal(err)
}
tenant := "A"
lmo := "login.microsoftonline.com"
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetInstanceDiscoveryBody(lmo, tenant)))
client, err := New(fmt.Sprintf(authorityFmt, lmo, tenant), fakeClientID, cred, WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}
ctx := context.Background()
// cache an access token for each tenant. To simplify determining their provenance below, the value of each token is the ID of the tenant that provided it.
if _, err = client.AcquireTokenSilent(ctx, tokenScope, WithTenantID(tenant)); err == nil {
t.Fatal("silent auth should fail because the cache is empty")
}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, tenant)))
body := fmt.Sprintf(
`{"access_token": "%s","expires_in": %d,"expires_on": %d,"token_type": "Bearer"`,
tenant, 3600, time.Now().Add(time.Duration(3600)*time.Second).Unix(),
)
// body += "}"
mockClient.AppendResponse(mock.WithBody([]byte(body)))
_, err = client.AcquireTokenByCredential(ctx, tokenScope, WithTenantID(tenant))
if err == nil {
t.Fatal("should have failed with InvalidJsonErr Response")
}
var ie errors.InvalidJsonErr
if !errors.As(err, &ie) {
t.Fatal("should have revieved a InvalidJsonErr, but got", err)
}
}

func TestAcquireTokenSilentTenants(t *testing.T) {
cred, err := NewCredFromSecret(fakeSecret)
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions apps/errors/errors.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,11 +64,20 @@ type CallErr struct {
Err error
}

type InvalidJsonErr struct {
Err error
}

// Errors implements error.Error().
func (e CallErr) Error() string {
return e.Err.Error()
}

// Errors implements error.Error().
func (e InvalidJsonErr) Error() string {
return e.Err.Error()
}

// Verbose prints a versbose error message with the request or response.
func (e CallErr) Verbose() string {
e.Resp.Request = nil // This brings in a bunch of TLS crap we don't need
Expand Down
12 changes: 2 additions & 10 deletions apps/internal/oauth/ops/accesstokens/accesstokens.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,11 +262,7 @@ func (c Client) FromClientSecret(ctx context.Context, authParameters authority.A
qv.Set(clientID, authParameters.ClientID)
addScopeQueryParam(qv, authParameters)

token, err := c.doTokenResp(ctx, authParameters, qv)
if err != nil {
return token, fmt.Errorf("FromClientSecret(): %w", err)
}
return token, nil
return c.doTokenResp(ctx, authParameters, qv)
}

func (c Client) FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (TokenResponse, error) {
Expand All @@ -281,11 +277,7 @@ func (c Client) FromAssertion(ctx context.Context, authParameters authority.Auth
qv.Set(clientInfo, clientInfoVal)
addScopeQueryParam(qv, authParameters)

token, err := c.doTokenResp(ctx, authParameters, qv)
if err != nil {
return token, fmt.Errorf("FromAssertion(): %w", err)
}
return token, nil
return c.doTokenResp(ctx, authParameters, qv)
}

func (c Client) FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion string, clientSecret string) (TokenResponse, error) {
Expand Down
4 changes: 2 additions & 2 deletions apps/internal/oauth/ops/internal/comm/comm.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ func (c *Client) JSONCall(ctx context.Context, endpoint string, headers http.Hea

if resp != nil {
if err := unmarshal(data, resp); err != nil {
return fmt.Errorf("json decode error: %w\njson message bytes were: %s", err, string(data))
return errors.InvalidJsonErr{Err: fmt.Errorf("json decode error: %w\njson message bytes were: %s", err, string(data))}
}
}
return nil
Expand Down Expand Up @@ -221,7 +221,7 @@ func (c *Client) URLFormCall(ctx context.Context, endpoint string, qv url.Values
}
if resp != nil {
if err := unmarshal(data, resp); err != nil {
return fmt.Errorf("json decode error: %w\nraw message was: %s", err, string(data))
return errors.InvalidJsonErr{Err: fmt.Errorf("json decode error: %w\nraw message was: %s", err, string(data))}
}
}
return nil
Expand Down
5 changes: 5 additions & 0 deletions apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,11 @@ func (c Client) getTokenForRequest(req *http.Request, resource string) (accessto
}

err = json.Unmarshal(responseBytes, &r)
if err != nil {
return r, errors.InvalidJsonErr{
Err: fmt.Errorf("error parsing the json error: %s", err),
}
}
r.GrantedScopes.Slice = append(r.GrantedScopes.Slice, resource)

return r, err
Expand Down
33 changes: 33 additions & 0 deletions apps/managedidentity/managedidentity_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,39 @@ func TestIMDSAcquireTokenReturnsTokenSuccess(t *testing.T) {
}
}

func TestInvalidJsonErrReturnOnAcquireToken(t *testing.T) {
resource := "https://resource/.default"
miType := SystemAssigned()

setEnvVars(t, DefaultToIMDS)
mockClient := mock.Client{}
responseBody := fmt.Sprintf(
`{"access_token": "%s","expires_in": %d,"expires_on": %d,"token_type": "Bearer"`,
"tetant", 3600, time.Now().Add(time.Duration(3600)*time.Second).Unix(),
)

mockClient.AppendResponse(mock.WithHTTPStatusCode(http.StatusOK), mock.WithBody([]byte(responseBody)))

// Resetting cache
before := cacheManager
defer func() { cacheManager = before }()
cacheManager = storage.New(nil)

client, err := New(miType, WithHTTPClient(&mockClient))
if err != nil {
t.Fatal(err)
}

_, err = client.AcquireToken(context.Background(), resource)
if err == nil {
t.Fatal("should have failed with InvalidJsonErr Response")
}
var ie errors.InvalidJsonErr
if !errors.As(err, &ie) {
t.Fatal("should have revieved a InvalidJsonErr, but got", err)
}
}

func TestCloudShellAcquireTokenReturnsTokenSuccess(t *testing.T) {
resource := "https://resource/.default"
miType := SystemAssigned()
Expand Down

0 comments on commit 9385bab

Please sign in to comment.