Skip to content

Commit

Permalink
Updated the code
Browse files Browse the repository at this point in the history
  • Loading branch information
4gust committed Feb 21, 2025
1 parent e6a3b29 commit 76218c5
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 211 deletions.
178 changes: 89 additions & 89 deletions apps/confidential/confidential_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -827,102 +827,102 @@ func TestRefreshInMultipleRequests(t *testing.T) {
refreshIn := 43200
expiresIn := 86400

t.Run("Test for refresh multiple request", func(t *testing.T) {
originalTime := base.GetCurrentTime
defer func() {
base.GetCurrentTime = originalTime
}()
// Create a mock client and append mock responses
mockClient := mock.SyncClient{}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "firstTenant")))
mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))),
)
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "secondTenant")))
mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))),
)
// Create the client instance
client, err := New(fmt.Sprintf(authorityFmt, lmo, "firstTenant"), fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false))
if err != nil {
t.Fatal(err)
}
// Acquire the first token for first tenant
ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("firstTenant"))
if err != nil {
t.Fatal(err)
}
// Assert the first token is returned
if ar.AccessToken != firstToken {
t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken)
}
originalTime := base.GetCurrentTime
defer func() {
base.GetCurrentTime = originalTime
}()
// Create a mock client and append mock responses
mockClient := mock.Client{}
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "firstTenant")))
mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))),
)
mockClient.AppendResponse(mock.WithBody(mock.GetTenantDiscoveryBody(lmo, "secondTenant")))
mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, firstToken, expiresIn, refreshIn))),
)
// Create the client instance
client, err := New(fmt.Sprintf(authorityFmt, lmo, "firstTenant"), fakeClientID, cred, WithHTTPClient(&mockClient), WithInstanceDiscovery(false))
if err != nil {
t.Fatal(err)
}
// Acquire the first token for first tenant
ar, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("firstTenant"))
if err != nil {
t.Fatal(err)
}
// Assert the first token is returned
if ar.AccessToken != firstToken {
t.Fatalf("wanted %q, got %q", firstToken, ar.AccessToken)
}
// Acquire the first token for second tenant
arSecond, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("secondTenant"))
if err != nil {
t.Fatal(err)
}
if arSecond.AccessToken != firstToken {
t.Fatalf("wanted %q, got %q", firstToken, arSecond.AccessToken)
}
fixedTime := time.Now().Add(time.Duration(43400) * time.Second)
base.GetCurrentTime = func() time.Time {
return fixedTime
}
var wg sync.WaitGroup
done := make(chan struct{})

// Acquire the first token for second tenant
arSecond, err := client.AcquireTokenByCredential(context.Background(), tokenScope, WithTenantID("secondTenant"))
if err != nil {
t.Fatal(err)
}
// Assert the first token is returned
if arSecond.AccessToken != firstToken {
t.Fatalf("wanted %q, got %q", firstToken, arSecond.AccessToken)
}
firstTenantChecker := false
secondTenantChecker := false
mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn+44200))),
)
mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"secondTenant", expiresIn, refreshIn+44200))))

fixedTime := time.Now().Add(time.Duration(43400) * time.Second)
base.GetCurrentTime = func() time.Time {
return fixedTime
}
var wg sync.WaitGroup
firstTenantChecker := false
secondTenantChecker := false

mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"firstTenant", expiresIn, refreshIn+44200))),
)
mockClient.AppendResponse(
mock.WithBody([]byte(fmt.Sprintf(`{"access_token":%q,"expires_in":%d,"refresh_in":%d,"token_type":"Bearer"}`, secondToken+"secondTenant", expiresIn, refreshIn+44200))))

for i := 0; i < 10000; i++ {
wg.Add(2)
go func() {
defer wg.Done()
ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("firstTenant"))
if err != nil {
t.Error(err)
return
}
if ar.AccessToken == secondToken+"firstTenant" && ar.Metadata.TokenSource == base.IdentityProvider {
if firstTenantChecker {
t.Error("Error can only call this once")
} else {
firstTenantChecker = true
}
}
}()
go func() {
defer wg.Done()
ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("secondTenant"))
if err != nil {
t.Error(err)
return
}
if ar.AccessToken == secondToken+"secondTenant" && ar.Metadata.TokenSource == base.IdentityProvider {
if secondTenantChecker {
t.Error("Error can only call this once")
} else {
secondTenantChecker = true
}
for i := 0; i < 10000; i++ {
wg.Add(2)
go func() {
defer wg.Done()
ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("firstTenant"))
if err != nil {
t.Error(err)
return
}
if ar.AccessToken == secondToken+"firstTenant" && ar.Metadata.TokenSource == base.IdentityProvider {
if firstTenantChecker {
t.Error("Error can only call this once")
} else {
firstTenantChecker = true
}
}()
}
// Waiting for all goroutines to finish
}
}()
go func() {
wg.Wait()
if !secondTenantChecker && !firstTenantChecker {
t.Error("Error should be called at least once")
defer wg.Done()
ar, err := client.AcquireTokenSilent(context.Background(), tokenScope, WithTenantID("secondTenant"))
if err != nil {
t.Error(err)
return
}
if ar.AccessToken == secondToken+"secondTenant" && ar.Metadata.TokenSource == base.IdentityProvider {
if secondTenantChecker {
t.Error("Error can only call this once")
} else {
secondTenantChecker = true
}
}
}()
}
// Wait for all goroutines in a separate goroutine
go func() {
wg.Wait()
})
close(done)
}()

// Wait for all goroutines to complete
<-done

if !secondTenantChecker && !firstTenantChecker {
t.Error("Error should be called at least once")
}

}

Expand Down
18 changes: 9 additions & 9 deletions apps/internal/base/base.go
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ type Client struct {
cacheAccessor cache.ExportReplace
cacheAccessorMu *sync.RWMutex
canRefresh map[string]*atomic.Value
refreshMu *sync.RWMutex
refreshMu *sync.Mutex
}

// Option is an optional argument to the New constructor.
Expand Down Expand Up @@ -247,7 +247,7 @@ func New(clientID string, authorityURI string, token *oauth.Client, options ...O
manager: storage.New(token),
pmanager: storage.NewPartitionedManager(token),
canRefresh: make(map[string]*atomic.Value),
refreshMu: &sync.RWMutex{},
refreshMu: &sync.Mutex{},
}
for _, o := range options {
if err = o(&client); err != nil {
Expand Down Expand Up @@ -352,14 +352,14 @@ func (b Client) AcquireTokenSilent(ctx context.Context, silent AcquireTokenSilen
if silent.Claims == "" {
ar, err = AuthResultFromStorage(storageTokenResponse)
if err == nil {
if b.shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) {
if shouldRefresh(storageTokenResponse.AccessToken.RefreshOn.T) {
b.refreshMu.Lock()
if _, exists := b.canRefresh[tenant]; !exists {
var empty atomic.Value
empty.Store(false)
b.canRefresh[tenant] = &empty
refreshValue, exists := b.canRefresh[tenant]
if !exists {
refreshValue = &atomic.Value{}
refreshValue.Store(false)
b.canRefresh[tenant] = refreshValue
}
refreshValue := b.canRefresh[tenant]
b.refreshMu.Unlock()
if refreshValue.CompareAndSwap(false, true) {
defer refreshValue.Store(false)
Expand Down Expand Up @@ -483,7 +483,7 @@ func (b Client) AuthResultFromToken(ctx context.Context, authParams authority.Au
var GetCurrentTime = time.Now

// shouldRefresh returns true if the token should be refreshed.
func (b *Client) shouldRefresh(t time.Time) bool {
func shouldRefresh(t time.Time) bool {
return !t.IsZero() && t.Before(GetCurrentTime())
}

Expand Down
3 changes: 1 addition & 2 deletions apps/internal/base/base_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -448,7 +448,6 @@ func TestAuthResultFromStorage(t *testing.T) {
func TestShouldRefresh(t *testing.T) {
// Get the current time to use for comparison
now := time.Now()
client := fakeClient(t)
tests := []struct {
name string
input time.Time
Expand All @@ -473,7 +472,7 @@ func TestShouldRefresh(t *testing.T) {

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := client.shouldRefresh(tt.input)
result := shouldRefresh(tt.input)
if result != tt.expected {
t.Errorf("shouldRefresh(%v) = %v; expected %v", tt.input, result, tt.expected)
}
Expand Down
6 changes: 6 additions & 0 deletions apps/internal/mock/mock.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"io"
"net/http"
"strings"
"sync"
"time"

"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
Expand Down Expand Up @@ -62,10 +63,13 @@ func WithHTTPStatusCode(statusCode int) responseOption {

// Client is a mock HTTP client that returns a sequence of responses. Use AppendResponse to specify the sequence.
type Client struct {
mu sync.Mutex
resp []response
}

func (c *Client) AppendResponse(opts ...responseOption) {
c.mu.Lock()
defer c.mu.Unlock()
r := response{code: http.StatusOK, headers: http.Header{}}
for _, o := range opts {
o.apply(&r)
Expand All @@ -74,6 +78,8 @@ func (c *Client) AppendResponse(opts ...responseOption) {
}

func (c *Client) Do(req *http.Request) (*http.Response, error) {
c.mu.Lock()
defer c.mu.Unlock()
if len(c.resp) == 0 {
panic(fmt.Sprintf(`no response for "%s"`, req.URL.String()))
}
Expand Down
49 changes: 0 additions & 49 deletions apps/internal/mock/syncmock.go

This file was deleted.

2 changes: 1 addition & 1 deletion apps/managedidentity/managedidentity.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ func New(id ID, options ...ClientOption) (Client, error) {
default:
return Client{}, fmt.Errorf("unsupported type %T", id)
}
var zero atomic.Value = atomic.Value{}
zero := atomic.Value{}
zero.Store(false)
client := Client{
miType: id,
Expand Down
Loading

0 comments on commit 76218c5

Please sign in to comment.