diff --git a/.github/workflows/codeql.yml b/.github/workflows/codeql.yml index 80ee3a988..f19e18d73 100644 --- a/.github/workflows/codeql.yml +++ b/.github/workflows/codeql.yml @@ -57,6 +57,12 @@ jobs: # Details on CodeQL's query packs refer to : https://docs.github.com/en/code-security/code-scanning/automatically-scanning-your-code-for-vulnerabilities-and-errors/configuring-code-scanning#using-queries-in-ql-packs # queries: security-extended,security-and-quality + - name: Setup dotnet + uses: actions/setup-dotnet@v4 + with: + dotnet-version: | + 8 + 9 # Autobuild attempts to build any compiled languages (C/C++, C#, Go, or Java). # If this step fails, then you should remove it and run the build manually (see below) diff --git a/Microsoft.GA4GH.TES.sln b/Microsoft.GA4GH.TES.sln index d51f43054..e56cfc43b 100644 --- a/Microsoft.GA4GH.TES.sln +++ b/Microsoft.GA4GH.TES.sln @@ -49,6 +49,8 @@ Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tes.SDK", "src\Tes.SDK\Tes. EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tes.SDK.Tests", "src\Tes.SDK.Tests\Tes.SDK.Tests.csproj", "{AE7ADB92-BEC6-4030-B62F-BDBB6AC53CB4}" EndProject +Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tes.Repository", "src\Tes.Repository\Tes.Repository.csproj", "{515A4905-0522-4C72-BC18-41BE6A3BE880}" +EndProject Project("{9A19103F-16F7-4668-BE54-9A1E7A4F7556}") = "Tes.SDK.Examples", "src\Tes.SDK.Examples\Tes.SDK.Examples.csproj", "{08A30572-2C5A-4F61-AF77-36F624A6020B}" EndProject Project("{FAE04EC0-301F-11D3-BF4B-00C04F79EFBC}") = "build-push-acr", "src\build-push-acr\build-push-acr.csproj", "{9DD148A9-86DA-4BB4-886C-1D29E11A0BB3}" @@ -119,6 +121,10 @@ Global {AE7ADB92-BEC6-4030-B62F-BDBB6AC53CB4}.Debug|Any CPU.Build.0 = Debug|Any CPU {AE7ADB92-BEC6-4030-B62F-BDBB6AC53CB4}.Release|Any CPU.ActiveCfg = Release|Any CPU {AE7ADB92-BEC6-4030-B62F-BDBB6AC53CB4}.Release|Any CPU.Build.0 = Release|Any CPU + {515A4905-0522-4C72-BC18-41BE6A3BE880}.Debug|Any CPU.ActiveCfg = Debug|Any CPU + {515A4905-0522-4C72-BC18-41BE6A3BE880}.Debug|Any CPU.Build.0 = Debug|Any CPU + {515A4905-0522-4C72-BC18-41BE6A3BE880}.Release|Any CPU.ActiveCfg = Release|Any CPU + {515A4905-0522-4C72-BC18-41BE6A3BE880}.Release|Any CPU.Build.0 = Release|Any CPU {08A30572-2C5A-4F61-AF77-36F624A6020B}.Debug|Any CPU.ActiveCfg = Debug|Any CPU {08A30572-2C5A-4F61-AF77-36F624A6020B}.Debug|Any CPU.Build.0 = Debug|Any CPU {08A30572-2C5A-4F61-AF77-36F624A6020B}.Release|Any CPU.ActiveCfg = Release|Any CPU diff --git a/src/CommonUtilities.Tests/ArmEnvironmentEndpointsTests.cs b/src/CommonUtilities.Tests/ArmEnvironmentEndpointsTests.cs index 85033b6c7..2b054a40d 100644 --- a/src/CommonUtilities.Tests/ArmEnvironmentEndpointsTests.cs +++ b/src/CommonUtilities.Tests/ArmEnvironmentEndpointsTests.cs @@ -147,22 +147,17 @@ private static bool Equals(IReadOnlyDictionary x, T y) [DataRow("AzureChinaCloud", "https://management.chinacloudapi.cn/.default", DisplayName = "AzureChinaCloud")] public async Task FromKnownCloudNameAsync_ExpectedDefaultTokenScope(string cloud, string audience) { - var environment = await AzureCloudConfig.FromKnownCloudNameAsync(cloudName: cloud, retryPolicyOptions: Microsoft.Extensions.Options.Options.Create(new Options.RetryPolicyOptions())); + var environment = await SkipWhenTimeout(AzureCloudConfig.FromKnownCloudNameAsync(cloudName: cloud, retryPolicyOptions: Microsoft.Extensions.Options.Options.Create(new Options.RetryPolicyOptions()))); Assert.AreEqual(audience, GetPropertyFromEnvironment(environment, 
nameof(AzureCloudConfig.DefaultTokenScope))); } - private static T? GetPropertyFromEnvironment(AzureCloudConfig environment, string property) - { - return (T?)environment.GetType().GetProperty(property)?.GetValue(environment); - } - [DataTestMethod] [DataRow(Cloud.Public, "AzureCloud", DisplayName = "All generally available global Azure regions")] [DataRow(Cloud.USGovernment, "AzureUSGovernment", DisplayName = "Azure Government")] [DataRow(Cloud.China, "AzureChinaCloud", DisplayName = "Microsoft Azure operated by 21Vianet")] public async Task FromKnownCloudNameAsync_ExpectedValues(Cloud cloud, string cloudName) { - var environment = await AzureCloudConfig.FromKnownCloudNameAsync(cloudName: cloudName, retryPolicyOptions: Microsoft.Extensions.Options.Options.Create(new Options.RetryPolicyOptions())); + var environment = await SkipWhenTimeout(AzureCloudConfig.FromKnownCloudNameAsync(cloudName: cloudName, retryPolicyOptions: Microsoft.Extensions.Options.Options.Create(new Options.RetryPolicyOptions()))); foreach (var (property, value) in CloudEndpoints[cloud]) { switch (value) @@ -189,5 +184,23 @@ public async Task FromKnownCloudNameAsync_ExpectedValues(Cloud cloud, string clo } } } + + private static T? GetPropertyFromEnvironment(AzureCloudConfig environment, string property) + { + return (T?)environment.GetType().GetProperty(property)?.GetValue(environment); + } + + private static async Task SkipWhenTimeout(Task task) + { + try + { + return await task; + } + catch (TaskCanceledException e) when (e.InnerException is TimeoutException) + { + Assert.Inconclusive(e.Message); + throw new System.Diagnostics.UnreachableException(); + } + } } } diff --git a/src/CommonUtilities/AzureServicesConnectionStringCredential.cs b/src/CommonUtilities/AzureServicesConnectionStringCredential.cs index 6a9ee7cfa..957fdf5a6 100644 --- a/src/CommonUtilities/AzureServicesConnectionStringCredential.cs +++ b/src/CommonUtilities/AzureServicesConnectionStringCredential.cs @@ -71,7 +71,7 @@ private AzureServicesConnectionStringCredentialOptions() private void SetInitialState(AzureCloudConfig armEndpoints) { - (GetEnvironmentVariable("AZURE_ADDITIONALLY_ALLOWED_TENANTS") ?? string.Empty).Split((char[]?)[';'], StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries).ForEach(AdditionallyAllowedTenants.Add); + (GetEnvironmentVariable("AZURE_ADDITIONALLY_ALLOWED_TENANTS") ?? string.Empty).Split(';', StringSplitOptions.RemoveEmptyEntries | StringSplitOptions.TrimEntries).ForEach(AdditionallyAllowedTenants.Add); TenantId = GetEnvironmentVariable("AZURE_TENANT_ID")!; AuthorityHost = armEndpoints.AuthorityHost ?? new(armEndpoints.Authentication?.LoginEndpointUrl ?? throw new ArgumentException("AuthorityHost is missing", nameof(armEndpoints))); Audience = armEndpoints.ArmEnvironment?.Audience ?? armEndpoints.Authentication?.Audiences?.LastOrDefault() ?? throw new ArgumentException("Audience is missing", nameof(armEndpoints)); @@ -113,6 +113,11 @@ private void SetInitialState(AzureCloudConfig armEndpoints) /// public bool DisableInstanceDiscovery { get; set; } + /// + /// Options controlling the storage of the token cache. + /// + public Azure.Identity.TokenCachePersistenceOptions TokenCachePersistenceOptions { get; set; } + /// /// Specifies tenants in addition to the specified for which the credential may acquire tokens. /// Add the wildcard value "*" to allow the credential to acquire tokens for any tenant the logged in account can access. 
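The hunks that follow collapse the per-credential option plumbing into a single ConfigureOptions helper, which copies the shared settings (tenant, authority host, diagnostics, retry, and the new TokenCachePersistenceOptions above) onto whichever concrete Azure.Identity options type is passed in, using reflection so that properties exposed by only some option classes are copied when present. A minimal standalone sketch of that copy-if-present pattern, using hypothetical option types rather than the real Azure.Identity ones:

using System;

// Hypothetical stand-ins for the assorted *CredentialOptions classes, some of which
// expose TenantId and some of which do not.
public class BaseOptions { }
public class OptionsWithTenant : BaseOptions { public string? TenantId { get; set; } }
public class OptionsWithoutTenant : BaseOptions { }

public static class ReflectionCopyDemo
{
    // Sets TenantId only when the concrete options type actually exposes that property,
    // mirroring the CopyTenantId/CopyDisableInstanceDiscovery/CopyTokenCachePersistenceOptions
    // helpers added later in this file.
    public static void CopyTenantId<T>(T options, string tenantId) where T : BaseOptions
        => options?.GetType().GetProperty(nameof(OptionsWithTenant.TenantId))?.SetValue(options, tenantId);

    public static void Main()
    {
        OptionsWithTenant withTenant = new();
        OptionsWithoutTenant withoutTenant = new();

        CopyTenantId(withTenant, "contoso-tenant");    // property exists: value is set
        CopyTenantId(withoutTenant, "contoso-tenant"); // property absent: the call is a no-op
        Console.WriteLine(withTenant.TenantId);        // contoso-tenant
    }
}

Because the reflection lookup returns null when the property is absent, the copy degrades to a no-op instead of requiring a separate overload per options type.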
@@ -134,23 +139,17 @@ private void SetInitialState(AzureCloudConfig armEndpoints) internal Azure.Identity.AzureCliCredential CreateAzureCliCredential() { - var result = new Azure.Identity.AzureCliCredentialOptions { TenantId = TenantId, AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled }; - CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); - return new(result); + return new(ConfigureOptions(new Azure.Identity.AzureCliCredentialOptions())); } internal Azure.Identity.VisualStudioCredential CreateVisualStudioCredential() { - var result = new Azure.Identity.VisualStudioCredentialOptions { TenantId = TenantId, AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled }; - CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); - return new(result); + return new(ConfigureOptions(new Azure.Identity.VisualStudioCredentialOptions())); } internal Azure.Identity.VisualStudioCodeCredential CreateVisualStudioCodeCredential() { - var result = new Azure.Identity.VisualStudioCodeCredentialOptions { TenantId = TenantId, AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled }; - CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); - return new(result); + return new(ConfigureOptions(new Azure.Identity.VisualStudioCodeCredentialOptions())); } //internal Azure.Identity.InteractiveBrowserCredential CreateInteractiveBrowserCredential() @@ -169,40 +168,112 @@ internal Azure.Identity.VisualStudioCodeCredential CreateVisualStudioCodeCredent internal Azure.Identity.ClientSecretCredential CreateClientSecretCredential(string appId, string appKey, string tenantId) { - var result = new Azure.Identity.ClientSecretCredentialOptions { AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled, DisableInstanceDiscovery = DisableInstanceDiscovery }; - CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); - return new(string.IsNullOrEmpty(tenantId) ? TenantId : tenantId, appId, appKey, result); + return new(string.IsNullOrEmpty(tenantId) ? 
TenantId : tenantId, appId, appKey, ConfigureOptions(new Azure.Identity.ClientSecretCredentialOptions())); } - internal Azure.Identity.ManagedIdentityCredential CreateManagedIdentityCredential(int _1, string appId) + internal Azure.Identity.ManagedIdentityCredential CreateManagedIdentityCredential(string appId) { - return new(appId, this); + return new(appId, options: this); } - internal Azure.Identity.ManagedIdentityCredential CreateManagedIdentityCredential(int _1) + internal Azure.Identity.ManagedIdentityCredential CreateManagedIdentityCredential() { - return new(options: this); + return CreateManagedIdentityCredential(null!); } internal Azure.Identity.WorkloadIdentityCredential CreateWorkloadIdentityCredential(string appId) { - Azure.Identity.WorkloadIdentityCredentialOptions result = new() { ClientId = appId, AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled, DisableInstanceDiscovery = DisableInstanceDiscovery, TenantId = TenantId }; - CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); - return new(result); + return new(ConfigureOptions(new Azure.Identity.WorkloadIdentityCredentialOptions() { ClientId = appId })); } internal Azure.Identity.WorkloadIdentityCredential CreateWorkloadIdentityCredential() { - Azure.Identity.WorkloadIdentityCredentialOptions result = new() { AuthorityHost = AuthorityHost, IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled, DisableInstanceDiscovery = DisableInstanceDiscovery, TenantId = TenantId }; - CopyAdditionallyAllowedTenants(result.AdditionallyAllowedTenants); - return new(result); + return new(ConfigureOptions(new Azure.Identity.WorkloadIdentityCredentialOptions())); } - void CopyAdditionallyAllowedTenants(IList additionalTenants) + // Based on https://github.com/Azure/azure-sdk-for-net/blob/main/sdk/identity/Azure.Identity/src/Credentials/TokenCredentialOptions.cs#L50 method Clone + private T ConfigureOptions(T options) where T : Azure.Identity.TokenCredentialOptions { - foreach (var tenant in AdditionallyAllowedTenants) + CopyTenantId(options); + + // copy TokenCredentialOptions Properties + options.AuthorityHost = AuthorityHost; + + options.IsUnsafeSupportLoggingEnabled = IsUnsafeSupportLoggingEnabled; + + // copy TokenCredentialDiagnosticsOptions specific options + options.Diagnostics.IsAccountIdentifierLoggingEnabled = Diagnostics.IsAccountIdentifierLoggingEnabled; + + // copy ISupportsDisableInstanceDiscovery + CopyDisableInstanceDiscovery(options); + + // copy ISupportsTokenCachePersistenceOptions + CopyTokenCachePersistenceOptions(options); + + // copy ISupportsAdditionallyAllowedTenants + CopyAdditionallyAllowedTenants(options); + + // copy base ClientOptions properties + + // only copy transport if the original has changed from the default so as not to set IsCustomTransportSet unintentionally + if (Transport != Default.Transport) { - additionalTenants.Add(tenant); + options.Transport = Transport; + } + + // clone base Diagnostic options + options.Diagnostics.ApplicationId = Diagnostics.ApplicationId; + options.Diagnostics.IsLoggingEnabled = Diagnostics.IsLoggingEnabled; + options.Diagnostics.IsTelemetryEnabled = Diagnostics.IsTelemetryEnabled; + options.Diagnostics.LoggedContentSizeLimit = Diagnostics.LoggedContentSizeLimit; + options.Diagnostics.IsDistributedTracingEnabled = Diagnostics.IsDistributedTracingEnabled; + options.Diagnostics.IsLoggingContentEnabled = Diagnostics.IsLoggingContentEnabled; + + CopyListItems(Diagnostics.LoggedHeaderNames, 
options.Diagnostics.LoggedHeaderNames); + CopyListItems(Diagnostics.LoggedQueryParameters, options.Diagnostics.LoggedQueryParameters); + + // clone base RetryOptions + options.RetryPolicy = RetryPolicy; + + options.Retry.MaxRetries = Retry.MaxRetries; + options.Retry.Delay = Retry.Delay; + options.Retry.MaxDelay = Retry.MaxDelay; + options.Retry.Mode = Retry.Mode; + options.Retry.NetworkTimeout = Retry.NetworkTimeout; + + return options; + } + + private static void CopyListItems(IList source, IList destination) + { + foreach (var item in source) + { + destination.Add(item); + } + } + + private void CopyTenantId(T options) where T : Azure.Identity.TokenCredentialOptions + { + options?.GetType().GetProperty(nameof(TenantId))?.SetValue(options, TenantId); + } + + private void CopyDisableInstanceDiscovery(T options) where T : Azure.Identity.TokenCredentialOptions + { + options?.GetType().GetProperty(nameof(DisableInstanceDiscovery))?.SetValue(options, DisableInstanceDiscovery); + } + + private void CopyTokenCachePersistenceOptions(T options) where T : Azure.Identity.TokenCredentialOptions + { + options?.GetType().GetProperty(nameof(TokenCachePersistenceOptions))?.SetValue(options, TokenCachePersistenceOptions); + } + + void CopyAdditionallyAllowedTenants(T options) where T : Azure.Identity.TokenCredentialOptions + { + var additionalTenants = options?.GetType().GetProperty(nameof(AdditionallyAllowedTenants))?.GetValue(options) as IList; + + if (additionalTenants is not null) + { + CopyListItems(AdditionallyAllowedTenants, additionalTenants); } } } @@ -367,30 +438,24 @@ internal static TokenCredential Create(AzureServicesConnectionStringCredentialOp } else { - ValidateMsiRetryTimeout(connectionSettings, options.ConnectionString); + ValidateAndSetMsiRetryTimeout(connectionSettings, options); // If certificate or client secret are not specified, use the specified managed identity - azureServiceTokenCredential = options.CreateManagedIdentityCredential( - connectionSettings.TryGetValue(MsiRetryTimeout, out var value) - ? int.Parse(value) - : 0, - appId); + azureServiceTokenCredential = options.CreateManagedIdentityCredential(appId); } } else { - ValidateMsiRetryTimeout(connectionSettings, options.ConnectionString); + ValidateAndSetMsiRetryTimeout(connectionSettings, options); // If AppId is not specified, use Managed Service Identity - azureServiceTokenCredential = options.CreateManagedIdentityCredential( - connectionSettings.TryGetValue(MsiRetryTimeout, out var value) - ? 
int.Parse(value) - : 0); + azureServiceTokenCredential = options.CreateManagedIdentityCredential(); } } else if (string.Equals(runAs, Workload, StringComparison.OrdinalIgnoreCase)) { - // If RunAs=Workload use the specified Workload Identity + // RunAs=Workload + // Use the specified Workload Identity // If AppId key is present, use it as the ClientId if (connectionSettings.TryGetValue(AppId, out var appId)) { @@ -468,7 +533,7 @@ private static void ValidateAttribute(Dictionary connectionSetti // } //} - private static void ValidateMsiRetryTimeout(Dictionary connectionSettings, string connectionString) + private static void ValidateAndSetMsiRetryTimeout(Dictionary connectionSettings, AzureServicesConnectionStringCredentialOptions options) { if (connectionSettings != null && connectionSettings.TryGetValue(MsiRetryTimeout, out var value)) { @@ -476,10 +541,13 @@ private static void ValidateMsiRetryTimeout(Dictionary connectio { var timeoutString = value; - var parseSucceeded = int.TryParse(timeoutString, out _); - if (!parseSucceeded) + if (int.TryParse(timeoutString, out var timeoutValue) && timeoutValue >= 0) + { + options.Retry.NetworkTimeout = TimeSpan.FromSeconds(timeoutValue); + } + else { - throw new ArgumentException($"Connection string '{connectionString}' is not valid. MsiRetryTimeout '{timeoutString}' is not valid. Valid values are integers greater than or equal to 0.", nameof(connectionString)); + throw new ArgumentException($"Connection string '{options.ConnectionString}' is not valid. MsiRetryTimeout '{timeoutString}' is not valid. Valid values are integers greater than or equal to 0.", nameof(options)); } } } diff --git a/src/CommonUtilities/CommonUtilities.csproj b/src/CommonUtilities/CommonUtilities.csproj index 3cda6bcda..5374eeaed 100644 --- a/src/CommonUtilities/CommonUtilities.csproj +++ b/src/CommonUtilities/CommonUtilities.csproj @@ -18,6 +18,7 @@ + diff --git a/src/CommonUtilities/Models/NodeTask.cs b/src/CommonUtilities/Models/NodeTask.cs index 307d8a04e..1d3208314 100644 --- a/src/CommonUtilities/Models/NodeTask.cs +++ b/src/CommonUtilities/Models/NodeTask.cs @@ -10,33 +10,33 @@ public class NodeTask { public string? Id { get; set; } public string? WorkflowId { get; set; } - public string? ImageTag { get; set; } - public string? ImageName { get; set; } + public List? Executors { get; set; } public List? ContainerDeviceRequests { get; set; } - public string? ContainerWorkDir { get; set; } - public List? CommandsToExecute { get; set; } - - /// Path inside the container to a file which will be piped to the executor's stdin. Must be an absolute path. - public string? ContainerStdInPath { get; set; } - - /// Path inside the container to a file where the executor's stdout will be written to. Must be an absolute path. - public string? ContainerStdOutPath { get; set; } - - /// Path inside the container to a file where the executor's stderr will be written to. Must be an absolute path. - public string? ContainerStdErrPath { get; set; } - - public Dictionary? ContainerEnv { get; set; } + public List? ContainerVolumes { get; set; } public List? Inputs { get; set; } public List? Outputs { get; set; } + public List? TaskOutputs { get; set; } public string? MetricsFilename { get; set; } public string? InputsMetricsFormat { get; set; } public string? OutputsMetricsFormat { get; set; } public List? TimestampMetricsFormats { get; set; } public List? BashScriptMetricsFormats { get; set; } - public string? 
MountParentDirectoryPath { get; set; } public RuntimeOptions RuntimeOptions { get; set; } = null!; } + public class Executor + { + public string? ImageTag { get; set; } + public string? ImageName { get; set; } + public string? ContainerWorkDir { get; set; } + public List? CommandsToExecute { get; set; } + public string? ContainerStdInPath { get; set; } + public string? ContainerStdOutPath { get; set; } + public string? ContainerStdErrPath { get; set; } + public Dictionary? ContainerEnv { get; set; } + public bool IgnoreError { get; set; } + } + public class ContainerDeviceRequest { public string? Driver { get; set; } @@ -80,6 +80,7 @@ public class RuntimeOptions public AzureEnvironmentConfig? AzureEnvironmentConfig { get; set; } public bool? SetContentMd5OnUpload { get; set; } + public string? MountParentDirectoryPath { get; set; } } public class StorageTargetLocation diff --git a/src/CommonUtilities/PagedInterfaceExtensions.cs b/src/CommonUtilities/PagedInterfaceExtensions.cs index 3230bb56a..2042211bc 100644 --- a/src/CommonUtilities/PagedInterfaceExtensions.cs +++ b/src/CommonUtilities/PagedInterfaceExtensions.cs @@ -69,10 +69,10 @@ public static IAsyncEnumerable ExecuteWithRetryAsync(this AsyncRetryHandle private sealed class PollyAsyncEnumerable : IAsyncEnumerable { private readonly IAsyncEnumerable _source; - private readonly RetryHandler.AsyncRetryHandlerPolicy _retryPolicy; + private readonly AsyncRetryHandlerPolicy _retryPolicy; private readonly Polly.Context _ctx; - public PollyAsyncEnumerable(IAsyncEnumerable source, RetryHandler.AsyncRetryHandlerPolicy retryPolicy, Polly.Context ctx) + public PollyAsyncEnumerable(IAsyncEnumerable source, AsyncRetryHandlerPolicy retryPolicy, Polly.Context ctx) { ArgumentNullException.ThrowIfNull(source); ArgumentNullException.ThrowIfNull(retryPolicy); diff --git a/src/CommonUtilities/RetryHandler.cs b/src/CommonUtilities/RetryHandler.cs index da741fbd2..0d02ba7fa 100644 --- a/src/CommonUtilities/RetryHandler.cs +++ b/src/CommonUtilities/RetryHandler.cs @@ -6,7 +6,7 @@ namespace CommonUtilities; /// -/// Utility class that facilitates the retry policy implementations for HTTP clients. +/// Utility class that facilitates the retry policy implementations for HTTP clients. /// public static class RetryHandler { @@ -14,11 +14,16 @@ public static class RetryHandler /// Polly Context key for caller method name /// public const string CallerMemberNameKey = $"Tes.ApiClients.{nameof(RetryHandler)}.CallerMemberName"; + /// /// Polly Context key for backup skip increment setting /// public const string BackupSkipProvidedIncrementKey = $"Tes.ApiClients.{nameof(RetryHandler)}.BackupSkipProvidedIncrementCount"; + /// Polly Context key combined sleep method and enumerable duration policies + /// + public const string CombineSleepDurationsKey = $"Tes.ApiClients.{nameof(RetryHandler)}.CombineSleepDurations"; + #region RetryHandlerPolicies /// /// Non-generic synchronous retry policy @@ -100,7 +105,6 @@ public AsyncRetryHandlerPolicy(IAsyncPolicy retryPolicy) /// For mocking public AsyncRetryHandlerPolicy() { } - /// /// Executes a delegate with the configured async policy. /// @@ -156,6 +160,20 @@ public virtual Task ExecuteWithRetryAsync(Func action(ct), PrepareContext(caller), cancellationToken); } + + /// + /// Executes the specified asynchronous action within the policy and returns the captured result. + /// + /// The action to perform. + /// A cancellation token which can be used to cancel the action. 
When a retry policy in use, also cancels any further retries. + /// Name of method originating the retriable operation. + /// The captured result. + public virtual Task ExecuteAndCaptureAsync(Func action, CancellationToken cancellationToken, [System.Runtime.CompilerServices.CallerMemberName] string? caller = default) + { + ArgumentNullException.ThrowIfNull(action); + + return retryPolicy.ExecuteAndCaptureAsync((_, token) => action(token), PrepareContext(caller), cancellationToken); + } } /// @@ -277,6 +295,6 @@ public virtual async Task ExecuteWithRetryAndConversionAsync(Func new() { - [CallerMemberNameKey] = caller + [CallerMemberNameKey] = caller ?? throw new ArgumentNullException(nameof(caller)) }; } diff --git a/src/CommonUtilities/RetryPolicyBuilder.cs b/src/CommonUtilities/RetryPolicyBuilder.cs index 78934446f..df9f65dc1 100644 --- a/src/CommonUtilities/RetryPolicyBuilder.cs +++ b/src/CommonUtilities/RetryPolicyBuilder.cs @@ -11,7 +11,7 @@ namespace CommonUtilities; /// -/// Utility class that facilitates the retry policy implementations for HTTP clients. +/// Utility class that facilitates the retry policy implementations for HTTP clients. /// public class RetryPolicyBuilder { @@ -93,6 +93,12 @@ public interface IPolicyBuilderBase /// OnRetry hander IPolicyBuilderWait WithRetryPolicyOptionsWait(); + /// + /// Default exponential wait policy. + /// + /// OnRetry hander + IPolicyBuilderWait WithExponentialBackoffWait(); + /// /// Custom exponential wait policy. /// @@ -110,22 +116,29 @@ public interface IPolicyBuilderBase IPolicyBuilderWait WithCustomizedRetryPolicyOptionsWait(int maxRetryCount, Func waitDurationProvider); /// - /// Custom optional exception-based wait policy backed up by an exponential wait policy. + /// Custom optional exception-based wait policy backed up by the default wait policy. /// /// Wait policy that can return to use the backup wait policy. - /// Maximum number of retries. - /// Value in seconds which is raised by the power of the backup retry attempt. - /// True to pass backup wait provider its own attempt values, False to provide overall attemp values. /// OnRetry hander - IPolicyBuilderWait WithExceptionBasedWaitWithExponentialBackoffBackup(Func waitDurationProvider, int maxRetryCount, double exponentialBackOffExponent, bool backupSkipProvidedIncrements); + IPolicyBuilderWait WithExceptionBasedWaitWithRetryPolicyOptionsBackup(Func waitDurationProvider); /// - /// Custom optional exception-based wait policy backed up by the default wait policy. + /// Custom optional exception-based wait policy backed up by the default exponential wait policy. + /// + /// Wait policy that can return to use the backup wait policy. + /// True to pass backup wait provider its own attempt values, False to provide overall attempt values. + /// OnRetry hander + IPolicyBuilderWait WithExceptionBasedWaitWithExponentialBackoffBackup(Func waitDurationProvider, bool backupSkipProvidedIncrements); + + /// + /// Custom optional exception-based wait policy backed up by an exponential wait policy. /// /// Wait policy that can return to use the backup wait policy. - /// True to pass backup wait provider its own attempt values, False to provide overall attemp values. + /// Maximum number of retries. + /// Value in seconds which is raised by the power of the backup retry attempt. + /// True to pass backup wait provider its own attempt values, False to provide overall attempt values. 
/// OnRetry hander - IPolicyBuilderWait WithExceptionBasedWaitWithRetryPolicyOptionsBackup(Func waitDurationProvider, bool backupSkipProvidedIncrements); + IPolicyBuilderWait WithExceptionBasedWaitWithExponentialBackoffBackup(Func waitDurationProvider, int maxRetryCount, double exponentialBackOffExponent, bool backupSkipProvidedIncrements); } /// @@ -139,6 +152,12 @@ public interface IPolicyBuilderBase /// OnRetry hander IPolicyBuilderWait WithRetryPolicyOptionsWait(); + /// + /// Default exponential wait policy. + /// + /// OnRetry hander + IPolicyBuilderWait WithExponentialBackoffWait(); + /// /// Custom exponential wait policy. /// @@ -165,22 +184,29 @@ public interface IPolicyBuilderBase IPolicyBuilderWait WithCustomizedRetryPolicyOptionsWait(int maxRetryCount, Func waitDurationProvider); /// - /// Custom optional exception-based wait policy backed up by an exponential wait policy. + /// Custom optional exception-based wait policy backed up by the default wait policy. /// /// Wait policy that can return to use the backup wait policy. - /// Maximum number of retries. - /// Value in seconds which is raised by the power of the backup retry attempt. - /// True to pass backup wait provider its own attempt values, False to provide overall attemp values. /// OnRetry hander - IPolicyBuilderWait WithExceptionBasedWaitWithExponentialBackoffBackup(Func waitDurationProvider, int retryCount, double exponentialBackOffExponent, bool backupSkipProvidedIncrements); + IPolicyBuilderWait WithExceptionBasedWaitWithRetryPolicyOptionsBackup(Func waitDurationProvider); /// - /// Custom optional exception-based wait policy backed up by the default wait policy. + /// Custom optional exception-based wait policy backed up by the default exponential wait policy. + /// + /// Wait policy that can return to use the backup wait policy. + /// True to pass backup wait provider its own attempt values, False to provide overall attempt values. + /// OnRetry hander + IPolicyBuilderWait WithExceptionBasedWaitWithExponentialBackoffBackup(Func waitDurationProvider, bool backupSkipProvidedIncrements); + + /// + /// Custom optional exception-based wait policy backed up by an exponential wait policy. /// /// Wait policy that can return to use the backup wait policy. - /// True to pass backup wait provider its own attempt values, False to provide overall attemp values. + /// Maximum number of retries. + /// Value in seconds which is raised by the power of the backup retry attempt. + /// True to pass backup wait provider its own attempt values, False to provide overall attempt values. 
/// OnRetry hander - IPolicyBuilderWait WithExceptionBasedWaitWithRetryPolicyOptionsBackup(Func waitDurationProvider, bool backupSkipProvidedIncrements); + IPolicyBuilderWait WithExceptionBasedWaitWithExponentialBackoffBackup(Func waitDurationProvider, int retryCount, double exponentialBackOffExponent, bool backupSkipProvidedIncrements); } /// @@ -240,7 +266,7 @@ public interface IPolicyBuilderBuild IAsyncPolicy AsyncBuildPolicy(); /// - /// Retrives the instance of the retryhandler to accomodate extensions to the builder + /// Retrieves the instance of the retryhandler to accommodate extensions to the builder /// RetryPolicyBuilder PolicyBuilderBase { get; } } @@ -272,7 +298,7 @@ public interface IPolicyBuilderBuild IAsyncPolicy AsyncBuildPolicy(); /// - /// Retrives the instance of the retryhandler to accomodate extensions to the builder + /// Retrieves the instance of the retryhandler to accommodate extensions to the builder /// RetryPolicyBuilder PolicyBuilderBase { get; } } @@ -329,7 +355,10 @@ public PolicyBuilderBase(PolicyBuilder policyBuilder, Defaults defaults) Defaults = defaults; } - public static Func DefaultSleepDurationProvider(Defaults defaults) + public static IEnumerable DefaultSleepDurationProvider(Defaults defaults) + => Polly.Contrib.WaitAndRetry.Backoff.DecorrelatedJitterBackoffV2(TimeSpan.FromSeconds(defaults.PolicyOptions.ExponentialBackOffExponent), defaults.PolicyOptions.MaxRetryCount); + + public static Func DefaultExponentialSleepDurationProvider(Defaults defaults) => ExponentialSleepDurationProvider(defaults.PolicyOptions.ExponentialBackOffExponent); public static Func ExponentialSleepDurationProvider(double exponentialBackOffExponent) @@ -366,23 +395,31 @@ TimeSpan AdjustAttemptIfNeeded() /// IPolicyBuilderWait IPolicyBuilderBase.WithRetryPolicyOptionsWait() - => new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, DefaultSleepDurationProvider(Defaults)); + => new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, sleepDurationsEnumerable: DefaultSleepDurationProvider(Defaults)); + + /// + IPolicyBuilderWait IPolicyBuilderBase.WithExponentialBackoffWait() + => new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, sleepDurationProvider: DefaultExponentialSleepDurationProvider(Defaults)); /// IPolicyBuilderWait IPolicyBuilderBase.WithCustomizedRetryPolicyOptionsWait(int maxRetryCount, Func sleepDurationProvider) - => new PolicyBuilderWait(this, maxRetryCount, (attempt, outcome, _1) => sleepDurationProvider(attempt, outcome)); + => new PolicyBuilderWait(this, maxRetryCount, sleepDurationProvider: (attempt, outcome, _1) => sleepDurationProvider(attempt, outcome)); /// IPolicyBuilderWait IPolicyBuilderBase.WithExponentialBackoffWait(int retryCount, double exponentialBackOffExponent) - => new PolicyBuilderWait(this, retryCount, ExponentialSleepDurationProvider(exponentialBackOffExponent)); + => new PolicyBuilderWait(this, retryCount, sleepDurationProvider: ExponentialSleepDurationProvider(exponentialBackOffExponent)); /// - IPolicyBuilderWait IPolicyBuilderBase.WithExceptionBasedWaitWithRetryPolicyOptionsBackup(Func sleepDurationProvider, bool backupSkipProvidedIncrements) - => new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, ExceptionBasedSleepDurationProviderWithExponentialBackoffBackup(sleepDurationProvider, Defaults.PolicyOptions.ExponentialBackOffExponent, backupSkipProvidedIncrements)); + IPolicyBuilderWait IPolicyBuilderBase.WithExceptionBasedWaitWithRetryPolicyOptionsBackup(Func sleepDurationProvider) + => 
new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, sleepDurationProvider: (attempt, exception, _) => sleepDurationProvider(attempt, exception) ?? TimeSpan.Zero, sleepDurationsEnumerable: DefaultSleepDurationProvider(Defaults), combineSleepDurations: true); + + /// + IPolicyBuilderWait IPolicyBuilderBase.WithExceptionBasedWaitWithExponentialBackoffBackup(Func sleepDurationProvider, bool backupSkipProvidedIncrements) + => new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, sleepDurationProvider: ExceptionBasedSleepDurationProviderWithExponentialBackoffBackup(sleepDurationProvider, Defaults.PolicyOptions.ExponentialBackOffExponent, backupSkipProvidedIncrements)); /// IPolicyBuilderWait IPolicyBuilderBase.WithExceptionBasedWaitWithExponentialBackoffBackup(Func sleepDurationProvider, int retryCount, double exponentialBackOffExponent, bool backupSkipProvidedIncrements) - => new PolicyBuilderWait(this, retryCount, ExceptionBasedSleepDurationProviderWithExponentialBackoffBackup(sleepDurationProvider, exponentialBackOffExponent, backupSkipProvidedIncrements)); + => new PolicyBuilderWait(this, retryCount, sleepDurationProvider: ExceptionBasedSleepDurationProviderWithExponentialBackoffBackup(sleepDurationProvider, exponentialBackOffExponent, backupSkipProvidedIncrements)); } private readonly struct PolicyBuilderBase : IPolicyBuilderBase @@ -400,80 +437,205 @@ public PolicyBuilderBase(PolicyBuilder policyBuilder, Defaults defaults /// IPolicyBuilderWait IPolicyBuilderBase.WithRetryPolicyOptionsWait() - => new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, default, PolicyBuilderBase.DefaultSleepDurationProvider(Defaults)); + => new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, sleepDurationsEnumerable: PolicyBuilderBase.DefaultSleepDurationProvider(Defaults)); + + /// + IPolicyBuilderWait IPolicyBuilderBase.WithExponentialBackoffWait() + => new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, sleepDurationProviderException: PolicyBuilderBase.DefaultExponentialSleepDurationProvider(Defaults)); /// IPolicyBuilderWait IPolicyBuilderBase.WithCustomizedRetryPolicyOptionsWait(int maxRetryCount, Func waitDurationProvider) - => new PolicyBuilderWait(this, maxRetryCount, default, (attempt, outcome, _1) => waitDurationProvider(attempt, outcome)); + => new PolicyBuilderWait(this, maxRetryCount, sleepDurationProviderException: (attempt, outcome, _1) => waitDurationProvider(attempt, outcome)); /// IPolicyBuilderWait IPolicyBuilderBase.WithCustomizedRetryPolicyOptionsWait(int maxRetryCount, Func, TimeSpan> sleepDurationProvider) - => new PolicyBuilderWait(this, maxRetryCount, (attempt, outcome, _1) => sleepDurationProvider(attempt, outcome), default); + => new PolicyBuilderWait(this, maxRetryCount, sleepDurationProviderResult: (attempt, outcome, _1) => sleepDurationProvider(attempt, outcome)); /// IPolicyBuilderWait IPolicyBuilderBase.WithExponentialBackoffWait(int maxRetryCount, double exponentialBackOffExponent) - => new PolicyBuilderWait(this, maxRetryCount, default, PolicyBuilderBase.ExponentialSleepDurationProvider(exponentialBackOffExponent)); + => new PolicyBuilderWait(this, maxRetryCount, sleepDurationProviderException: PolicyBuilderBase.ExponentialSleepDurationProvider(exponentialBackOffExponent)); /// - IPolicyBuilderWait IPolicyBuilderBase.WithExceptionBasedWaitWithRetryPolicyOptionsBackup(Func sleepDurationProvider, bool backupSkipProvidedIncrements) - => new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, default, 
PolicyBuilderBase.ExceptionBasedSleepDurationProviderWithExponentialBackoffBackup(sleepDurationProvider, Defaults.PolicyOptions.ExponentialBackOffExponent, backupSkipProvidedIncrements)); + IPolicyBuilderWait IPolicyBuilderBase.WithExceptionBasedWaitWithRetryPolicyOptionsBackup(Func sleepDurationProvider) + => new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, sleepDurationProviderException: (attempt, exception, _) => sleepDurationProvider(attempt, exception) ?? TimeSpan.Zero, sleepDurationsEnumerable: PolicyBuilderBase.DefaultSleepDurationProvider(Defaults), combineSleepDurations: true); + + /// + IPolicyBuilderWait IPolicyBuilderBase.WithExceptionBasedWaitWithExponentialBackoffBackup(Func sleepDurationProvider, bool backupSkipProvidedIncrements) + => new PolicyBuilderWait(this, Defaults.PolicyOptions.MaxRetryCount, sleepDurationProviderException: PolicyBuilderBase.ExceptionBasedSleepDurationProviderWithExponentialBackoffBackup(sleepDurationProvider, Defaults.PolicyOptions.ExponentialBackOffExponent, backupSkipProvidedIncrements)); /// IPolicyBuilderWait IPolicyBuilderBase.WithExceptionBasedWaitWithExponentialBackoffBackup(Func sleepDurationProvider, int retryCount, double exponentialBackOffExponent, bool backupSkipProvidedIncrements) - => new PolicyBuilderWait(this, retryCount, default, PolicyBuilderBase.ExceptionBasedSleepDurationProviderWithExponentialBackoffBackup(sleepDurationProvider, exponentialBackOffExponent, backupSkipProvidedIncrements)); + => new PolicyBuilderWait(this, retryCount, sleepDurationProviderException: PolicyBuilderBase.ExceptionBasedSleepDurationProviderWithExponentialBackoffBackup(sleepDurationProvider, exponentialBackOffExponent, backupSkipProvidedIncrements)); } private readonly struct PolicyBuilderWait : IPolicyBuilderWait { public readonly PolicyBuilderBase builderBase; - public readonly Func sleepDurationProvider; + public readonly Func? sleepDurationProvider; + public readonly IEnumerable? sleepDurationsEnumerable; public readonly int maxRetryCount; - public PolicyBuilderWait(PolicyBuilderBase builderBase, int maxRetryCount, Func sleepDurationProvider) + private static Func CombineSleepDurations(Func provider, IEnumerable enumerable) { - ArgumentNullException.ThrowIfNull(sleepDurationProvider); + var combined = enumerable.Select(span => new Func(duration => + { + try + { + return duration + span; + } + catch (OverflowException) + { + return TimeSpan.MaxValue; + } + })).ToList(); + + return new((attempt, exception, context) => + { + List> stored; + + if (attempt == 1) + { + context[RetryHandler.CombineSleepDurationsKey] = stored = combined; + } + else if (context.TryGetValue(RetryHandler.CombineSleepDurationsKey, out var value) && value is List> foundValue) + { + stored = foundValue; + } + else + { + throw new System.Diagnostics.UnreachableException($"{RetryHandler.CombineSleepDurationsKey} should have been set in Polly Context at first retry"); + } + + var final = stored[attempt - 1](provider(attempt, exception, context)); + return final; + }); + } + + public PolicyBuilderWait(PolicyBuilderBase builderBase, int maxRetryCount, Func? sleepDurationProvider = default, IEnumerable? 
sleepDurationsEnumerable = default, bool combineSleepDurations = false) + { + if (sleepDurationProvider is null && sleepDurationsEnumerable is null) + { + throw new ArgumentNullException(null, $"At least one of {nameof(sleepDurationProvider)} or {nameof(sleepDurationsEnumerable)} must be provided."); + } + + if (combineSleepDurations && (sleepDurationProvider is null || sleepDurationsEnumerable is null)) + { + throw new ArgumentException("Both sleepDurationsEnumerable and a sleep durations provider must be provided.", nameof(combineSleepDurations)); + } + this.builderBase = builderBase; this.maxRetryCount = maxRetryCount; - this.sleepDurationProvider = sleepDurationProvider; + + if (combineSleepDurations) + { + this.sleepDurationProvider = CombineSleepDurations(sleepDurationProvider!, sleepDurationsEnumerable!); + this.sleepDurationsEnumerable = null; + } + else + { + this.sleepDurationProvider = sleepDurationProvider; + this.sleepDurationsEnumerable = sleepDurationsEnumerable; + } + + if (this.sleepDurationProvider is not null && this.sleepDurationsEnumerable is not null) + { + throw new ArgumentException($"{nameof(sleepDurationsEnumerable)} overrides {nameof(sleepDurationProvider)}", nameof(sleepDurationsEnumerable)); + } } /// IPolicyBuilderBuild IPolicyBuilderWait.SetOnRetryBehavior(ILogger? logger, RetryHandler.OnRetryHandler? onRetry, RetryHandler.OnRetryHandlerAsync? onRetryAsync) - => new PolicyBuilderBuild(this, sleepDurationProvider, logger, onRetry, onRetryAsync); + => new PolicyBuilderBuild(this, sleepDurationProvider, sleepDurationsEnumerable, logger, onRetry, onRetryAsync); } private readonly struct PolicyBuilderWait : IPolicyBuilderWait { public readonly PolicyBuilderBase builderBase; - public readonly Func? sleepDurationProvider; public readonly Func, Context, TimeSpan>? genericSleepDurationProvider; + public readonly IEnumerable? sleepDurationsEnumerable; public readonly int maxRetryCount; - private static Func, Context, TimeSpan> PickSleepDurationProvider(Func, Context, TimeSpan>? tResultProvider, Func? exceptionProvider) - => tResultProvider is null ? (attempt, outcome, ctx) => exceptionProvider!(attempt, outcome.Exception, ctx) : tResultProvider; + private static Func, Context, TimeSpan>? PickSleepDurationProvider(Func, Context, TimeSpan>? tResultProvider, Func? exceptionProvider) + => tResultProvider is null ? (exceptionProvider is null ? null : (attempt, outcome, ctx) => exceptionProvider(attempt, outcome.Exception, ctx)) : tResultProvider; - public PolicyBuilderWait(PolicyBuilderBase builderBase, int maxRetryCount, Func, Context, TimeSpan>? sleepDurationProviderResult, Func? 
sleepDurationProviderException) + private static Func, Context, TimeSpan> CombineSleepDurations(Func, Context, TimeSpan> provider, IEnumerable enumerable) { - if (sleepDurationProviderException is null && sleepDurationProviderResult is null) + var combined = enumerable.Select(span => new Func(duration => { - throw new ArgumentNullException(null, $"At least one of {nameof(sleepDurationProviderResult)} or {nameof(sleepDurationProviderException)} must be provided."); + try + { + return duration + span; + } + catch (OverflowException) + { + return TimeSpan.MaxValue; + } + })).ToList(); + + return new((attempt, result, context) => + { + List> stored; + + if (attempt == 1) + { + context[RetryHandler.CombineSleepDurationsKey] = stored = combined; + } + else if (context.TryGetValue(RetryHandler.CombineSleepDurationsKey, out var value) && value is List> foundValue) + { + stored = foundValue; + } + else + { + throw new System.Diagnostics.UnreachableException($"{RetryHandler.CombineSleepDurationsKey} should have been set in Polly Context at first retry"); + } + + var final = stored[attempt - 1](provider(attempt, result, context)); + return final; + }); + } + + public PolicyBuilderWait(PolicyBuilderBase builderBase, int maxRetryCount, Func, Context, TimeSpan>? sleepDurationProviderResult = default, Func? sleepDurationProviderException = default, IEnumerable? sleepDurationsEnumerable = default, bool combineSleepDurations = false) + { + if (sleepDurationProviderException is null && sleepDurationProviderResult is null && sleepDurationsEnumerable is null) + { + throw new ArgumentNullException(null, $"At least one of {nameof(sleepDurationProviderResult)}, {nameof(sleepDurationProviderException)} or {nameof(sleepDurationsEnumerable)} must be provided."); + } + + if (combineSleepDurations && ((sleepDurationProviderResult is null && sleepDurationProviderException is null) || sleepDurationsEnumerable is null)) + { + throw new ArgumentException("Both sleepDurationsEnumerable and a sleep durations provider must be provided.", nameof(combineSleepDurations)); } this.builderBase = builderBase; this.maxRetryCount = maxRetryCount; - this.sleepDurationProvider = sleepDurationProviderException; - this.genericSleepDurationProvider = sleepDurationProviderResult; + + if (combineSleepDurations) + { + this.genericSleepDurationProvider = CombineSleepDurations(PickSleepDurationProvider(genericSleepDurationProvider, sleepDurationProviderException)!, sleepDurationsEnumerable!); + this.sleepDurationsEnumerable = null; + } + else + { + this.genericSleepDurationProvider = PickSleepDurationProvider(genericSleepDurationProvider, sleepDurationProviderException); + this.sleepDurationsEnumerable = sleepDurationsEnumerable; + } + + if (this.genericSleepDurationProvider is not null && this.sleepDurationsEnumerable is not null) + { + throw new ArgumentException($"{nameof(sleepDurationsEnumerable)} overrides {nameof(sleepDurationProviderResult)} and {nameof(sleepDurationProviderException)}", nameof(sleepDurationsEnumerable)); + } } /// IPolicyBuilderBuild IPolicyBuilderWait.SetOnRetryBehavior(ILogger? logger, RetryHandler.OnRetryHandler? onRetry, RetryHandler.OnRetryHandlerAsync? 
onRetryAsync) - => new PolicyBuilderBuild(this, PickSleepDurationProvider(genericSleepDurationProvider, sleepDurationProvider), logger, onRetry, onRetryAsync); + => new PolicyBuilderBuild(this, genericSleepDurationProvider, sleepDurationsEnumerable, logger, onRetry, onRetryAsync); } private readonly struct PolicyBuilderBuild : IPolicyBuilderBuild { private readonly PolicyBuilderWait builderWait; - private readonly Func sleepDurationProvider; + private readonly Func? sleepDurationProvider; + public readonly IEnumerable? sleepDurationsEnumerable; private readonly ILogger? logger; private readonly RetryHandler.OnRetryHandler? onRetryHandler; private readonly RetryHandler.OnRetryHandlerAsync? onRetryHandlerAsync; @@ -481,11 +643,16 @@ IPolicyBuilderBuild IPolicyBuilderWait.SetOnRetryBehavior(ILog /// public RetryPolicyBuilder PolicyBuilderBase { get; } - public PolicyBuilderBuild(PolicyBuilderWait builderWait, Func sleepDurationProvider, ILogger? logger, RetryHandler.OnRetryHandler? onRetry, RetryHandler.OnRetryHandlerAsync? onRetryAsync) + public PolicyBuilderBuild(PolicyBuilderWait builderWait, Func? sleepDurationProvider, IEnumerable? sleepDurationsEnumerable, ILogger? logger, RetryHandler.OnRetryHandler? onRetry, RetryHandler.OnRetryHandlerAsync? onRetryAsync) { - ArgumentNullException.ThrowIfNull(sleepDurationProvider); + if (sleepDurationsEnumerable is null) + { + ArgumentNullException.ThrowIfNull(sleepDurationProvider); + } + this.builderWait = builderWait; this.sleepDurationProvider = sleepDurationProvider; + this.sleepDurationsEnumerable = sleepDurationsEnumerable; this.logger = logger; this.onRetryHandler = onRetry; this.onRetryHandlerAsync = onRetryAsync; @@ -528,19 +695,23 @@ public static Func OnRetryHandlerAsync( /// ISyncPolicy IPolicyBuilderBuild.SyncBuildPolicy() { - var waitProvider = sleepDurationProvider; + var waitProvider = sleepDurationProvider!; var onRetryProvider = OnRetryHandler(logger, onRetryHandler); - return builderWait.builderBase.policyBuilder.WaitAndRetry(builderWait.maxRetryCount, (attempt, ctx) => waitProvider(attempt, default, ctx), onRetryProvider); + return sleepDurationsEnumerable is null + ? builderWait.builderBase.policyBuilder.WaitAndRetry(builderWait.maxRetryCount, (attempt, ctx) => waitProvider(attempt, default, ctx), onRetryProvider) + : builderWait.builderBase.policyBuilder.WaitAndRetry(sleepDurationsEnumerable, onRetryProvider); } /// IAsyncPolicy IPolicyBuilderBuild.AsyncBuildPolicy() { - var waitProvider = sleepDurationProvider; + var waitProvider = sleepDurationProvider!; var onRetryProvider = OnRetryHandlerAsync(logger, onRetryHandler, onRetryHandlerAsync); - return builderWait.builderBase.policyBuilder.WaitAndRetryAsync(builderWait.maxRetryCount, waitProvider, onRetryProvider); + return sleepDurationsEnumerable is null + ? builderWait.builderBase.policyBuilder.WaitAndRetryAsync(builderWait.maxRetryCount, waitProvider, onRetryProvider) + : builderWait.builderBase.policyBuilder.WaitAndRetryAsync(sleepDurationsEnumerable, onRetryProvider); } /// @@ -555,7 +726,8 @@ RetryHandler.AsyncRetryHandlerPolicy IPolicyBuilderBuild.AsyncBuild() private readonly struct PolicyBuilderBuild : IPolicyBuilderBuild { private readonly PolicyBuilderWait builderWait; - private readonly Func, Context, TimeSpan> sleepDurationProvider; + private readonly Func, Context, TimeSpan>? sleepDurationProvider; + public readonly IEnumerable? sleepDurationsEnumerable; private readonly ILogger? logger; private readonly RetryHandler.OnRetryHandler? 
onRetryHandler; private readonly RetryHandler.OnRetryHandlerAsync? onRetryHandlerAsync; @@ -563,11 +735,16 @@ RetryHandler.AsyncRetryHandlerPolicy IPolicyBuilderBuild.AsyncBuild() /// public RetryPolicyBuilder PolicyBuilderBase { get; } - public PolicyBuilderBuild(PolicyBuilderWait builderWait, Func, Context, TimeSpan> sleepDurationProvider, ILogger? logger, RetryHandler.OnRetryHandler? onRetry, RetryHandler.OnRetryHandlerAsync? onRetryAsync) + public PolicyBuilderBuild(PolicyBuilderWait builderWait, Func, Context, TimeSpan>? sleepDurationProvider, IEnumerable? sleepDurationsEnumerable, ILogger? logger, RetryHandler.OnRetryHandler? onRetry, RetryHandler.OnRetryHandlerAsync? onRetryAsync) { - ArgumentNullException.ThrowIfNull(sleepDurationProvider); + if (sleepDurationsEnumerable is null) + { + ArgumentNullException.ThrowIfNull(sleepDurationProvider); + } + this.builderWait = builderWait; this.sleepDurationProvider = sleepDurationProvider; + this.sleepDurationsEnumerable = sleepDurationsEnumerable; this.logger = logger; this.onRetryHandler = onRetry; this.onRetryHandlerAsync = onRetryAsync; @@ -628,10 +805,12 @@ private static Func, TimeSpan, int, Context, Task> OnRet /// IAsyncPolicy IPolicyBuilderBuild.AsyncBuildPolicy() { - var waitProvider = sleepDurationProvider; + var waitProvider = sleepDurationProvider!; var onRetryProvider = OnRetryHandlerAsync(logger, onRetryHandler, onRetryHandlerAsync); - return builderWait.builderBase.policyBuilder.WaitAndRetryAsync(builderWait.maxRetryCount, waitProvider, onRetryProvider); + return sleepDurationsEnumerable is null + ? builderWait.builderBase.policyBuilder.WaitAndRetryAsync(builderWait.maxRetryCount, waitProvider, onRetryProvider) + : builderWait.builderBase.policyBuilder.WaitAndRetryAsync(sleepDurationsEnumerable, onRetryProvider); } ///// diff --git a/src/CommonUtilities/UtilityExtensions.cs b/src/CommonUtilities/UtilityExtensions.cs index 4c3378c2f..91f089c34 100644 --- a/src/CommonUtilities/UtilityExtensions.cs +++ b/src/CommonUtilities/UtilityExtensions.cs @@ -105,13 +105,13 @@ public static void ForEach(this IEnumerable values, Action action) #endregion #region AddRange - //public static void AddRange(this IList list, IEnumerable values) - //{ - // foreach (var value in values) - // { - // list.Add(value); - // }; - //} + public static void AddRange(this IList list, IEnumerable values) + { + foreach (var value in values) + { + list.Add(value); + }; + } public static void AddRange(this IDictionary dictionary, IDictionary values) { diff --git a/src/GenerateBatchVmSkus/AzureBatchSkuValidator.cs b/src/GenerateBatchVmSkus/AzureBatchSkuValidator.cs index 84ce680bb..3297f5d36 100644 --- a/src/GenerateBatchVmSkus/AzureBatchSkuValidator.cs +++ b/src/GenerateBatchVmSkus/AzureBatchSkuValidator.cs @@ -8,12 +8,12 @@ using System.Text.RegularExpressions; using System.Threading.Channels; using Azure.Core; +using CommonUtilities; using Microsoft.Azure.Batch; using Microsoft.Azure.Batch.Common; using Microsoft.Extensions.Primitives; -using Polly; -using Polly.Retry; using Tes.Models; +using static CommonUtilities.RetryHandler; using static GenerateBatchVmSkus.Program; /* @@ -121,9 +121,11 @@ async ValueTask GetResults(IAsyncEnumerable res } } - private static readonly AsyncRetryPolicy asyncRetryPolicy = Policy - .Handle() - .WaitAndRetryForeverAsync(i => TimeSpan.FromSeconds(0.05)); + private static readonly AsyncRetryHandlerPolicy asyncRetryPolicy = new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new 
CommonUtilities.Options.RetryPolicyOptions())) + .PolicyBuilder.OpinionatedRetryPolicy() + .WithCustomizedRetryPolicyOptionsWait(int.MaxValue, (_, _) => TimeSpan.FromSeconds(0.05)) + .SetOnRetryBehavior() + .AsyncBuild(); private static IDictionary? batchSkus; @@ -162,12 +164,12 @@ private async ValueTask> GetVmSkusAsync(TestContext co if (CanBatchAccountValidateSku(vm, context)) { result = result.Append(vm); - await asyncRetryPolicy.ExecuteAsync(WriteLog("sort", "process", vm), cancellationToken); + await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("sort", "process", vm), cancellationToken); } else { await resultSkus.Writer.WriteAsync(vm, cancellationToken); - await asyncRetryPolicy.ExecuteAsync(WriteLog("sort", "forward", vm), cancellationToken); + await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("sort", "forward", vm), cancellationToken); } } @@ -255,7 +257,7 @@ private async ValueTask ValidateSkus(CancellationToken cancellationToken) var StartLoadedTest = new Func>(async vmSize => { _ = Interlocked.Increment(ref started); - await asyncRetryPolicy.ExecuteAsync(WriteLog("process", "post", vmSize), cancellationToken); + await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("process", "post", vmSize), cancellationToken); return (vmSize, result: await TestVMSizeInBatchAsync(vmSize, cancellationToken)); }); @@ -274,7 +276,7 @@ private async ValueTask ValidateSkus(CancellationToken cancellationToken) { List skusToTest = [.. await GetVmSkusAsync(context, cancellationToken)]; await skusToTest.ToAsyncEnumerable() - .ForEachAwaitWithCancellationAsync(async (sku, token) => await asyncRetryPolicy.ExecuteAsync(WriteLog("process", "queue", sku), token), cancellationToken); + .ForEachAwaitWithCancellationAsync(async (sku, token) => await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("process", "queue", sku), token), cancellationToken); var loadedTests = skusToTest.Where(CanTestNow).ToList(); for (tests = [.. 
loadedTests.Select(StartLoadedTest)]; @@ -306,20 +308,20 @@ await skusToTest.ToAsyncEnumerable() _ = retries.Remove(vmSize.VmSku.Name); vmSize.Validated = true; await resultSkus.Writer.WriteAsync(vmSize, cancellationToken); - await asyncRetryPolicy.ExecuteAsync(WriteLog("process", "use", vmSize), cancellationToken); + await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("process", "use", vmSize), cancellationToken); break; case VerifyVMIResult.Skip: ++processed; _ = retries.Remove(vmSize.VmSku.Name); - await asyncRetryPolicy.ExecuteAsync(WriteLog("process", "skip", vmSize), cancellationToken); + await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("process", "skip", vmSize), cancellationToken); break; case VerifyVMIResult.NextRegion: ++processedDeferred; _ = retries.Remove(vmSize.VmSku.Name); await resultSkus.Writer.WriteAsync(vmSize, cancellationToken); - await asyncRetryPolicy.ExecuteAsync(WriteLog("process", "forward", vmSize), cancellationToken); + await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("process", "forward", vmSize), cancellationToken); break; case VerifyVMIResult.Retry: @@ -329,7 +331,7 @@ await skusToTest.ToAsyncEnumerable() retries[vmSize.VmSku.Name] = (false, lastRetry.RetryCount + 1, DateTime.UtcNow + AzureBatchSkuValidator.RetryWaitTime, vmSize); _ = Interlocked.Decrement(ref started); _ = Interlocked.Decrement(ref completed); - await asyncRetryPolicy.ExecuteAsync(WriteLog("process", $"wait{lastRetry.RetryCount}", vmSize), cancellationToken); + await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("process", $"wait{lastRetry.RetryCount}", vmSize), cancellationToken); } else { @@ -337,7 +339,7 @@ await skusToTest.ToAsyncEnumerable() ++processed; _ = retries.Remove(vmSize.VmSku.Name); await resultSkus.Writer.WriteAsync(vmSize, cancellationToken); - await asyncRetryPolicy.ExecuteAsync(WriteLog("process", "forwardRT", vmSize), cancellationToken); + await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("process", "forwardRT", vmSize), cancellationToken); } break; @@ -359,7 +361,7 @@ await skusToTest.ToAsyncEnumerable() skusToTest.AddRange(await (await GetVmSkusAsync(context, cancellationToken)).ToAsyncEnumerable() .WhereAwaitWithCancellation(async (vmSize, token) => { - await asyncRetryPolicy.ExecuteAsync(WriteLog("process", "queue", vmSize), cancellationToken); + await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("process", "queue", vmSize), cancellationToken); return true; }) .ToListAsync(cancellationToken)); @@ -386,7 +388,7 @@ await skusToTest.ToAsyncEnumerable() .ToAsyncEnumerable() .WhereAwaitWithCancellation(async (vmSize, token) => { - await asyncRetryPolicy.ExecuteAsync(WriteLog("process", "queueRT", vmSize), cancellationToken); + await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("process", "queueRT", vmSize), cancellationToken); return true; }) .ToListAsync(cancellationToken)); @@ -427,7 +429,7 @@ await skusToTest.ToAsyncEnumerable() await skusToTest.ToAsyncEnumerable().ForEachAwaitWithCancellationAsync(async (vmSize, token) => { await resultSkus.Writer.WriteAsync(vmSize, token); - await asyncRetryPolicy.ExecuteAsync(WriteLog("process", "dump", vmSize), cancellationToken); + await asyncRetryPolicy.ExecuteWithRetryAsync(WriteLog("process", "dump", vmSize), cancellationToken); ConsoleHelper.WriteLine(accountInfo.Name, ForegroundColorSpan.LightYellow(), $"Deferring '{vmSize.VmSku.Name}' due to quota (end of processing)."); }, cancellationToken); diff --git a/src/GenerateBatchVmSkus/Program.cs b/src/GenerateBatchVmSkus/Program.cs index 
c6f733834..63ab62ac5 100644 --- a/src/GenerateBatchVmSkus/Program.cs +++ b/src/GenerateBatchVmSkus/Program.cs @@ -9,6 +9,7 @@ using System.CommandLine.Parsing; using System.CommandLine.Rendering; using System.Globalization; +using System.Linq; using System.Reflection; using Azure.Core; using Azure.Identity; diff --git a/src/Tes.ApiClients.Tests/PriceApiClientTests.cs b/src/Tes.ApiClients.Tests/PriceApiClientTests.cs index 32008d756..32ee51039 100644 --- a/src/Tes.ApiClients.Tests/PriceApiClientTests.cs +++ b/src/Tes.ApiClients.Tests/PriceApiClientTests.cs @@ -35,7 +35,19 @@ public void Cleanup() [TestMethod] public async Task GetPricingInformationPageAsync_ReturnsSinglePageWithItemsWithMaxPageSize() { - var page = await pricingApiClient.GetPricingInformationPageAsync(DateTime.UtcNow, 0, "Virtual Machines", "westus2", CancellationToken.None); + Models.Pricing.RetailPricingData? page = default; + + try + { + page = await pricingApiClient.GetPricingInformationPageAsync(DateTime.UtcNow, 0, "Virtual Machines", "westus2", CancellationToken.None); + } + catch (HttpRequestException ex) + { + if (ex.StatusCode == System.Net.HttpStatusCode.TooManyRequests) + { + Assert.Inconclusive(); + } + } Assert.IsNotNull(page); Assert.IsTrue(page.Items.Length > 0); @@ -60,9 +72,21 @@ public async Task GetPricingInformationAsync_ReturnsMoreThan100Items() [TestMethod] public async Task GetAllPricingInformationForNonWindowsAndNonSpotVmsAsync_ReturnsOnlyNonWindowsAndNonSpotInstances() { - var pages = await pricingApiClient.GetAllPricingInformationForNonWindowsAndNonSpotVmsAsync("westus2", CancellationToken.None).ToListAsync(); + List? pages = default; + + try + { + pages = await pricingApiClient.GetAllPricingInformationForNonWindowsAndNonSpotVmsAsync("westus2", CancellationToken.None).ToListAsync(); + } + catch (HttpRequestException ex) + { + if (ex.StatusCode == System.Net.HttpStatusCode.TooManyRequests) + { + Assert.Inconclusive(); + } + } - Assert.IsTrue(pages.Count > 0); + Assert.IsTrue(pages?.Count > 0); Assert.IsFalse(pages.Any(r => r.productName.Contains(" Windows"))); Assert.IsFalse(pages.Any(r => r.productName.Contains(" Spot"))); } diff --git a/src/Tes.ApiClients.Tests/TerraWsmApiClientTests.cs b/src/Tes.ApiClients.Tests/TerraWsmApiClientTests.cs index 71bb960d8..474f78376 100644 --- a/src/Tes.ApiClients.Tests/TerraWsmApiClientTests.cs +++ b/src/Tes.ApiClients.Tests/TerraWsmApiClientTests.cs @@ -182,7 +182,6 @@ public async Task GetLandingZoneResourcesAsync_ListOfLandingZoneResourcesAndGets tokenCredential.Verify(t => t.GetTokenAsync(It.IsAny(), It.IsAny()), Times.Once); - } [TestMethod] @@ -192,7 +191,6 @@ public void GetLandingZoneResourcesApiUrl_CorrectUrlIsParsed() var expectedUrl = $"{TerraApiStubData.WsmApiHost}/api/workspaces/v1/{terraApiStubData.WorkspaceId}/resources/controlled/azure/landingzone"; Assert.AreEqual(expectedUrl, url.ToString()); - } [TestMethod] @@ -202,7 +200,6 @@ public void GetQuotaApiUrl_CorrectUrlIsParsed() var expectedUrl = $"{TerraApiStubData.WsmApiHost}/api/workspaces/v1/{terraApiStubData.WorkspaceId}/resources/controlled/azure/landingzone/quota?azureResourceId={Uri.EscapeDataString(terraApiStubData.BatchAccountId)}"; Assert.AreEqual(expectedUrl, url.ToString()); - } } } diff --git a/src/Tes.ApiClients/DrsHubApiClient.cs b/src/Tes.ApiClients/DrsHubApiClient.cs index cba923385..9a66c22e2 100644 --- a/src/Tes.ApiClients/DrsHubApiClient.cs +++ b/src/Tes.ApiClients/DrsHubApiClient.cs @@ -57,7 +57,7 @@ public virtual async Task ResolveDrsUriAsync(Uri drsUri, { var apiUrl = 
GetResolveDrsApiUrl(); - Logger.LogInformation(@"Resolving DRS URI calling: {uri}", apiUrl); + Logger.LogDebug(@"Resolving DRS URI calling: {uri}", apiUrl); response = await HttpSendRequestWithRetryPolicyAsync(() => new HttpRequestMessage(HttpMethod.Post, apiUrl) { Content = GetDrsResolveRequestContent(drsUri) }, @@ -65,7 +65,7 @@ public virtual async Task ResolveDrsUriAsync(Uri drsUri, var apiResponse = await GetDrsResolveApiResponseAsync(response, cancellationToken); - Logger.LogInformation(@"Successfully resolved URI: {drsUri}", drsUri); + Logger.LogDebug(@"Successfully resolved URI: {drsUri}", drsUri); return apiResponse; } diff --git a/src/Tes.ApiClients/HttpApiClient.cs b/src/Tes.ApiClients/HttpApiClient.cs index a7486b391..e1c8289ad 100644 --- a/src/Tes.ApiClients/HttpApiClient.cs +++ b/src/Tes.ApiClients/HttpApiClient.cs @@ -20,14 +20,18 @@ public abstract class HttpApiClient private static readonly HttpClient HttpClient = new(); private readonly TokenCredential tokenCredential = null!; private readonly SHA256 sha256 = SHA256.Create(); + private readonly string tokenScope = null!; + private readonly SemaphoreSlim semaphore = new(1, 1); + private AccessToken accessToken; + /// /// Logger instance /// protected readonly ILogger Logger = null!; - private readonly string tokenScope = null!; - private readonly SemaphoreSlim semaphore = new(1, 1); - private AccessToken accessToken; + /// + /// retry handler + /// protected readonly CachingRetryHandler.CachingAsyncRetryHandlerPolicy cachingRetryHandler; /// @@ -81,7 +85,8 @@ protected HttpApiClient() { } /// /// private RetryHandler.OnRetryHandler LogRetryErrorOnRetryHttpResponseMessageHandler() - => (result, timeSpan, retryCount, correlationId, caller) => + { + return new((result, timeSpan, retryCount, correlationId, caller) => { if (result.Exception is null) { @@ -93,7 +98,8 @@ private RetryHandler.OnRetryHandler LogRetryErrorOnRetryHtt Logger?.LogError(result.Exception, @"Retrying in {Method} due to '{Message}': RetryCount: {RetryCount} TimeSpan: {TimeSpan} CorrelationId: {CorrelationId}", caller, result.Exception.Message, retryCount, timeSpan.ToString("c"), correlationId.ToString("D")); } - }; + }); + } /// /// Sends request with a retry policy @@ -105,7 +111,8 @@ private RetryHandler.OnRetryHandler LogRetryErrorOnRetryHtt /// protected async Task HttpSendRequestWithRetryPolicyAsync( Func httpRequestFactory, CancellationToken cancellationToken, bool setAuthorizationHeader = false) - => await cachingRetryHandler.ExecuteWithRetryAsync(async ct => + { + return await cachingRetryHandler.ExecuteWithRetryAsync(async ct => { var request = httpRequestFactory(); @@ -116,6 +123,7 @@ protected async Task HttpSendRequestWithRetryPolicyAsync( return await HttpClient.SendAsync(request, ct); }, cancellationToken); + } /// /// Sends a Http Get request to the URL and deserializes the body response to the specified type @@ -128,14 +136,15 @@ protected async Task HttpSendRequestWithRetryPolicyAsync( /// Response's content deserialization type. 
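// Illustrative sketch only (not part of this patch): the surrounding retry helpers accept a
// Func<HttpRequestMessage> factory because an HttpRequestMessage cannot be sent twice, so each
// retry attempt must build a fresh request. A minimal BCL-only version of that pattern, with
// hypothetical names, might look like this:
using System;
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

internal static class RetryWithRequestFactoryDemo
{
    public static async Task<HttpResponseMessage> SendWithRetryAsync(
        HttpClient client, Func<HttpRequestMessage> requestFactory, int maxAttempts, CancellationToken ct)
    {
        for (var attempt = 1; ; attempt++)
        {
            using var request = requestFactory();                 // a new message for every attempt
            var response = await client.SendAsync(request, ct);

            if (response.IsSuccessStatusCode || attempt == maxAttempts)
            {
                return response;                                  // caller decides what to do with a final failure
            }

            response.Dispose();
            await Task.Delay(TimeSpan.FromSeconds(Math.Pow(2, attempt)), ct);   // simple exponential backoff
        }
    }
}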
/// protected async Task HttpGetRequestAsync(Uri requestUrl, bool setAuthorizationHeader, - bool cacheResults, JsonTypeInfo typeInfo, CancellationToken cancellationToken) + bool cacheResults, JsonTypeInfo typeInfo, CancellationToken cancellationToken, + [System.Runtime.CompilerServices.CallerMemberName] string caller = default) { if (cacheResults) { - return await HttpGetRequestWithCachingAndRetryPolicyAsync(requestUrl, typeInfo, cancellationToken, setAuthorizationHeader); + return await HttpGetRequestWithCachingAndRetryPolicyAsync(requestUrl, typeInfo, cancellationToken, setAuthorizationHeader, caller); } - return await HttpGetRequestWithRetryPolicyAsync(requestUrl, typeInfo, cancellationToken, setAuthorizationHeader); + return await HttpGetRequestWithRetryPolicyAsync(requestUrl, typeInfo, cancellationToken, setAuthorizationHeader, caller); } /// @@ -148,7 +157,8 @@ protected async Task HttpGetRequestAsync(Uri requestUrl, b /// Response's content deserialization type. /// protected async Task HttpGetRequestWithCachingAndRetryPolicyAsync(Uri requestUrl, - JsonTypeInfo typeInfo, CancellationToken cancellationToken, bool setAuthorizationHeader = false) + JsonTypeInfo typeInfo, CancellationToken cancellationToken, bool setAuthorizationHeader = false, + [System.Runtime.CompilerServices.CallerMemberName] string caller = default) { var cacheKey = await ToCacheKeyAsync(requestUrl, setAuthorizationHeader, cancellationToken); @@ -159,7 +169,8 @@ protected async Task HttpGetRequestWithCachingAndRetryPolicyAsync GetApiResponseContentAsync(m, typeInfo, ct), cancellationToken))!; + (m, ct) => GetApiResponseContentAsync(m, typeInfo, ct), + cancellationToken, caller))!; } /// @@ -198,15 +209,19 @@ protected async Task HttpGetRequestWithExpirableCachingAndRetryPolicy /// Response's content deserialization type. /// protected async Task HttpGetRequestWithRetryPolicyAsync(Uri requestUrl, - JsonTypeInfo typeInfo, CancellationToken cancellationToken, bool setAuthorizationHeader = false) - => await cachingRetryHandler.ExecuteWithRetryAndConversionAsync(async ct => + JsonTypeInfo typeInfo, CancellationToken cancellationToken, bool setAuthorizationHeader = false, + [System.Runtime.CompilerServices.CallerMemberName] string caller = default) + { + return await cachingRetryHandler.ExecuteWithRetryAndConversionAsync(async ct => { //request must be recreated in every retry. var httpRequest = await CreateGetHttpRequest(requestUrl, setAuthorizationHeader, ct); return await HttpClient.SendAsync(httpRequest, ct); }, - (m, ct) => GetApiResponseContentAsync(m, typeInfo, ct), cancellationToken); + (m, ct) => GetApiResponseContentAsync(m, typeInfo, ct), + cancellationToken, caller); + } /// /// Returns an query string key-value, with the value escaped. If the value is null or empty returns an empty string @@ -259,7 +274,8 @@ private async Task CreateGetHttpRequest(Uri requestUrl, bool /// Response's content deserialization type. 
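// Illustrative sketch only: the [CallerMemberName] parameters added to these helpers let the
// compiler stamp each call with the invoking member's name, so the caching retry handler receives
// the original entry point (presumably for its retry logging) without explicit plumbing by every
// caller. Minimal BCL-only demonstration with hypothetical names:
using System;
using System.Runtime.CompilerServices;

internal static class CallerMemberNameDemo
{
    public static string Describe(Uri requestUrl, [CallerMemberName] string caller = "")
        => $"{caller} -> {requestUrl}";

    public static string GetPricingInformation()
        => Describe(new Uri("https://example.invalid/prices"));   // yields "GetPricingInformation -> https://example.invalid/prices"
}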
/// protected async Task HttpGetRequestWithRetryPolicyAsync( - Func httpRequestFactory, JsonTypeInfo typeInfo, CancellationToken cancellationToken, bool setAuthorizationHeader = false) + Func httpRequestFactory, JsonTypeInfo typeInfo, CancellationToken cancellationToken, bool setAuthorizationHeader = false, + [System.Runtime.CompilerServices.CallerMemberName] string caller = default) { return await cachingRetryHandler.ExecuteWithRetryAndConversionAsync(async ct => { @@ -272,7 +288,8 @@ protected async Task HttpGetRequestWithRetryPolicyAsync( return await HttpClient.SendAsync(request, ct); }, - (m, ct) => GetApiResponseContentAsync(m, typeInfo, ct), cancellationToken); + (m, ct) => GetApiResponseContentAsync(m, typeInfo, ct), + cancellationToken, caller); } private async Task AddAuthorizationHeaderToRequestAsync(HttpRequestMessage requestMessage, CancellationToken cancellationToken) @@ -293,7 +310,7 @@ private async Task AddAuthorizationHeaderToRequestAsync(HttpRequestMessage reque } catch (Exception e) { - Logger.LogError(e, @"Failed to set authentication header with the access token for scope:{TokenScope}", + Logger.LogError(e, @"Failed to set authentication header with the access token for scope: {TokenScope}", tokenScope); throw; } diff --git a/src/Tes.ApiClients/TerraSamApiClient.cs b/src/Tes.ApiClients/TerraSamApiClient.cs index 254ba5122..38c26fdc3 100644 --- a/src/Tes.ApiClients/TerraSamApiClient.cs +++ b/src/Tes.ApiClients/TerraSamApiClient.cs @@ -52,7 +52,7 @@ private async Task GetActionManagedIdentity var url = GetSamActionManagedIdentityUrl(resourceType, resourceId, action); - Logger.LogInformation(@"Fetching action managed identity from Sam for {resourceId}", resourceId); + Logger.LogDebug(@"Fetching action managed identity from Sam for {resourceId}", resourceId); try { diff --git a/src/Tes/Repository/DatabaseOverloadedException.cs b/src/Tes.Repository/DatabaseOverloadedException.cs similarity index 100% rename from src/Tes/Repository/DatabaseOverloadedException.cs rename to src/Tes.Repository/DatabaseOverloadedException.cs diff --git a/src/Tes/Repository/ICache.cs b/src/Tes.Repository/ICache.cs similarity index 100% rename from src/Tes/Repository/ICache.cs rename to src/Tes.Repository/ICache.cs diff --git a/src/Tes/Repository/IRepository.cs b/src/Tes.Repository/IRepository.cs similarity index 100% rename from src/Tes/Repository/IRepository.cs rename to src/Tes.Repository/IRepository.cs diff --git a/src/Tes/Migrations/20230106185229_InitialCreate.Designer.cs b/src/Tes.Repository/Migrations/20230106185229_InitialCreate.Designer.cs similarity index 100% rename from src/Tes/Migrations/20230106185229_InitialCreate.Designer.cs rename to src/Tes.Repository/Migrations/20230106185229_InitialCreate.Designer.cs diff --git a/src/Tes/Migrations/20230106185229_InitialCreate.cs b/src/Tes.Repository/Migrations/20230106185229_InitialCreate.cs similarity index 100% rename from src/Tes/Migrations/20230106185229_InitialCreate.cs rename to src/Tes.Repository/Migrations/20230106185229_InitialCreate.cs diff --git a/src/Tes/Migrations/20230320202549_AddIndicesToJson.Designer.cs b/src/Tes.Repository/Migrations/20230320202549_AddIndicesToJson.Designer.cs similarity index 100% rename from src/Tes/Migrations/20230320202549_AddIndicesToJson.Designer.cs rename to src/Tes.Repository/Migrations/20230320202549_AddIndicesToJson.Designer.cs diff --git a/src/Tes/Migrations/20230320202549_AddIndicesToJson.cs b/src/Tes.Repository/Migrations/20230320202549_AddIndicesToJson.cs similarity index 100% rename from 
src/Tes/Migrations/20230320202549_AddIndicesToJson.cs rename to src/Tes.Repository/Migrations/20230320202549_AddIndicesToJson.cs diff --git a/src/Tes/Migrations/20230808235207_AddGinIndex.Designer.cs b/src/Tes.Repository/Migrations/20230808235207_AddGinIndex.Designer.cs similarity index 100% rename from src/Tes/Migrations/20230808235207_AddGinIndex.Designer.cs rename to src/Tes.Repository/Migrations/20230808235207_AddGinIndex.Designer.cs diff --git a/src/Tes/Migrations/20230808235207_AddGinIndex.cs b/src/Tes.Repository/Migrations/20230808235207_AddGinIndex.cs similarity index 100% rename from src/Tes/Migrations/20230808235207_AddGinIndex.cs rename to src/Tes.Repository/Migrations/20230808235207_AddGinIndex.cs diff --git a/src/Tes/Migrations/TesDbContextModelSnapshot.cs b/src/Tes.Repository/Migrations/TesDbContextModelSnapshot.cs similarity index 100% rename from src/Tes/Migrations/TesDbContextModelSnapshot.cs rename to src/Tes.Repository/Migrations/TesDbContextModelSnapshot.cs diff --git a/src/Tes.Repository/Models/KeyedDbItem.cs b/src/Tes.Repository/Models/KeyedDbItem.cs new file mode 100644 index 000000000..e377a395b --- /dev/null +++ b/src/Tes.Repository/Models/KeyedDbItem.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.ComponentModel.DataAnnotations.Schema; + +namespace Tes.Repository.Models +{ + public abstract class KeyedDbItem + { + [Column("id")] + public long Id { get; set; } + } +} diff --git a/src/Tes/Models/PostgreSqlOptions.cs b/src/Tes.Repository/Models/PostgreSqlOptions.cs similarity index 97% rename from src/Tes/Models/PostgreSqlOptions.cs rename to src/Tes.Repository/Models/PostgreSqlOptions.cs index 65ef70e46..2d389a5b6 100644 --- a/src/Tes/Models/PostgreSqlOptions.cs +++ b/src/Tes.Repository/Models/PostgreSqlOptions.cs @@ -1,7 +1,7 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. -namespace Tes.Models +namespace Tes.Repository.Models { /// /// PostgresSql configuration options diff --git a/src/Tes.Repository/Models/TesTaskPostgres.cs b/src/Tes.Repository/Models/TesTaskPostgres.cs new file mode 100644 index 000000000..3500c4834 --- /dev/null +++ b/src/Tes.Repository/Models/TesTaskPostgres.cs @@ -0,0 +1,25 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System.ComponentModel.DataAnnotations.Schema; +using Microsoft.EntityFrameworkCore.Metadata.Internal; + +namespace Tes.Repository.Models +{ + /// + /// Database schema for encapsulating a TesTask as Json for Postgresql. 
+ /// + [Table(TesDbContext.TesTasksPostgresTableName)] + public class TesTaskDatabaseItem : KeyedDbItem + { + [Column("json", TypeName = "jsonb")] + public Tes.Models.TesTask Json { get; set; } + + public TesTaskDatabaseItem Clone() + { + var result = (TesTaskDatabaseItem)MemberwiseClone(); + result.Json = Json.Clone(); + return result; + } + } +} diff --git a/src/Tes/Repository/PostgreSqlCachingRepository.cs b/src/Tes.Repository/PostgreSqlCachingRepository.cs similarity index 72% rename from src/Tes/Repository/PostgreSqlCachingRepository.cs rename to src/Tes.Repository/PostgreSqlCachingRepository.cs index 9911f67be..de4ab7679 100644 --- a/src/Tes/Repository/PostgreSqlCachingRepository.cs +++ b/src/Tes.Repository/PostgreSqlCachingRepository.cs @@ -9,35 +9,39 @@ using System.Threading; using System.Threading.Channels; using System.Threading.Tasks; +using CommonUtilities; using Microsoft.EntityFrameworkCore; using Microsoft.Extensions.Logging; -using Polly; +using static CommonUtilities.RetryHandler; namespace Tes.Repository { /// - /// A repository for storing in an Entity Framework Postgres table + /// A repository for storing in an Entity Framework Postgres table /// - /// Database table schema class - public abstract class PostgreSqlCachingRepository : IDisposable where T : class + /// Database table schema class + /// Corresponding type for + public abstract class PostgreSqlCachingRepository : IDisposable where TDbItem : Models.KeyedDbItem where TItem : RepositoryItem { private const int BatchSize = 1000; private static readonly TimeSpan defaultCompletedTaskCacheExpiration = TimeSpan.FromDays(1); - protected readonly AsyncPolicy asyncPolicy = Policy - .Handle(e => e.IsTransient) - .WaitAndRetryAsync(10, i => TimeSpan.FromSeconds(Math.Pow(2, i))); + protected readonly AsyncRetryHandlerPolicy asyncPolicy = new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new CommonUtilities.Options.RetryPolicyOptions() { ExponentialBackOffExponent = 2, MaxRetryCount = 10 })) + .PolicyBuilder.OpinionatedRetryPolicy(Polly.Policy.Handle(e => e.IsTransient)) + .WithRetryPolicyOptionsWait() + .SetOnRetryBehavior() + .AsyncBuild(); - private record struct WriteItem(T DbItem, WriteAction Action, TaskCompletionSource TaskSource); - private readonly Channel itemsToWrite = Channel.CreateUnbounded(); - private readonly ConcurrentDictionary updatingItems = new(); // Collection of all pending updates to be written, to faciliate detection of simultaneous parallel updates. + private record struct WriteItem(TDbItem DbItem, WriteAction Action, TaskCompletionSource TaskSource); + private readonly Channel itemsToWrite = Channel.CreateUnbounded(new() { SingleReader = true }); + private readonly ConcurrentDictionary updatingItems = new(); // Collection of all pending updates to be written, to faciliate detection of simultaneous parallel updates. private readonly CancellationTokenSource writerWorkerCancellationTokenSource = new(); private readonly Task writerWorkerTask; protected enum WriteAction { Add, Update, Delete } protected Func CreateDbContext { get; init; } - protected readonly ICache Cache; + protected readonly ICache Cache; protected readonly ILogger Logger; private bool _disposedValue; @@ -49,7 +53,7 @@ protected enum WriteAction { Add, Update, Delete } /// Logging interface. /// Memory cache for fast access to active items. 
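// Illustrative sketch only: querying the jsonb-mapped TesTaskDatabaseItem defined above with
// EF Core + Npgsql. Member access on the mapped POCO column (t.Json.Id) is translated to a JSON
// lookup server-side; the repository further below in this patch issues essentially the same
// query on a cache miss. Helper name here is hypothetical.
using System.Threading;
using System.Threading.Tasks;
using Microsoft.EntityFrameworkCore;
using Tes.Repository;
using Tes.Repository.Models;

internal static class JsonbQueryDemo
{
    public static Task<TesTaskDatabaseItem?> FindByTaskIdAsync(TesDbContext db, string id, CancellationToken ct)
        => db.TesTasks.FirstOrDefaultAsync(t => t.Json.Id == id, ct);
}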
/// - protected PostgreSqlCachingRepository(Microsoft.Extensions.Hosting.IHostApplicationLifetime hostApplicationLifetime, ILogger logger = default, ICache cache = default) + protected PostgreSqlCachingRepository(Microsoft.Extensions.Hosting.IHostApplicationLifetime hostApplicationLifetime, ILogger logger = default, ICache cache = default) { Logger = logger; Cache = cache; @@ -58,7 +62,7 @@ protected PostgreSqlCachingRepository(Microsoft.Extensions.Hosting.IHostApplicat writerWorkerTask = Task.Run(() => WriterWorkerAsync(writerWorkerCancellationTokenSource.Token)) .ContinueWith(async task => { - Logger?.LogInformation("The repository WriterWorkerAsync ended with TaskStatus: {TaskStatus}", task.Status); + Logger?.LogDebug("The repository WriterWorkerAsync ended with TaskStatus: {TaskStatus}", task.Status); if (task.Status == TaskStatus.Faulted) { @@ -73,7 +77,7 @@ protected PostgreSqlCachingRepository(Microsoft.Extensions.Hosting.IHostApplicat await Task.Delay(TimeSpan.FromSeconds(40)); // Give the logger time to flush; default flush is 30s hostApplicationLifetime?.StopApplication(); }, TaskContinuationOptions.NotOnCanceled) - .ContinueWith(task => Logger?.LogInformation("The repository WriterWorkerAsync ended normally"), TaskContinuationOptions.OnlyOnCanceled); + .ContinueWith(task => Logger?.LogDebug("The repository WriterWorkerAsync ended normally"), TaskContinuationOptions.OnlyOnCanceled); } /// @@ -84,7 +88,7 @@ protected PostgreSqlCachingRepository(Microsoft.Extensions.Hosting.IHostApplicat /// Predicate to determine if is active. /// Converts (extracts and/or copies) the desired portion of . /// (for convenience in fluent/LINQ usage patterns). - protected TResult EnsureActiveItemInCache(T item, Func GetKey, Predicate IsActive, Func GetResult = default) where TResult : class + protected TResult EnsureActiveItemInCache(TDbItem item, Func GetKey, Predicate IsActive, Func GetResult = default) where TResult : class { if (Cache is not null) { @@ -107,12 +111,12 @@ protected TResult EnsureActiveItemInCache(T item, Func GetKe /// The of to query. /// A for controlling the lifetime of the asynchronous operation. /// order-by function. - /// pagination selection (within the order-by). + /// pagination selection (within ). /// The WHERE clause parts for selection in the query. /// The WHERE clause for raw SQL for selection in the query. /// /// Ensure that the from which comes isn't disposed until the entire query completes. - protected async Task> GetItemsAsync(DbSet dbSet, CancellationToken cancellationToken, Func, IQueryable> orderBy = default, Func, IQueryable> pagination = default, IEnumerable>> efPredicates = default, FormattableString rawPredicate = default) + protected async Task> GetItemsAsync(DbSet dbSet, CancellationToken cancellationToken, Func, IQueryable> orderBy = default, Func, IQueryable> pagination = default, IEnumerable>> efPredicates = default, FormattableString rawPredicate = default) { ArgumentNullException.ThrowIfNull(dbSet); @@ -123,7 +127,7 @@ protected async Task> GetItemsAsync(DbSet dbSet, CancellationT var tableQuery = rawPredicate is null ? dbSet.AsQueryable() - : dbSet.FromSql(new PrependableFormattableString($"SELECT *\r\nFROM {dbSet.EntityType.GetTableName()}\r\nWHERE ", rawPredicate)); + : dbSet.FromSql(new Utilities.PrependableFormattableString($"SELECT *\r\nFROM {dbSet.EntityType.GetTableName()}\r\nWHERE ", rawPredicate)); tableQuery = efPredicates.Any() ? 
efPredicates.Aggregate(tableQuery, (query, efPredicate) => query.Where(efPredicate)) @@ -134,47 +138,52 @@ protected async Task> GetItemsAsync(DbSet dbSet, CancellationT //var sqlQuery = query.ToQueryString(); //System.Diagnostics.Debugger.Break(); - return await asyncPolicy.ExecuteAsync(query.ToListAsync, cancellationToken); + return await asyncPolicy.ExecuteWithRetryAsync(query.ToListAsync, cancellationToken); } /// /// Adds entry into WriterWorker queue. /// /// + /// /// /// - protected Task AddUpdateOrRemoveItemInDbAsync(T item, WriteAction action, CancellationToken cancellationToken) + protected Task AddUpdateOrRemoveItemInDbAsync(TDbItem item, Func getItem, WriteAction action, CancellationToken cancellationToken) { - var source = new TaskCompletionSource(); + ArgumentNullException.ThrowIfNull(getItem); + + var source = new TaskCompletionSource(); var result = source.Task; - if (action == WriteAction.Update) + if (WriteAction.Add != action) { - if (updatingItems.TryAdd(item, null)) + if (updatingItems.TryAdd(item.Id, null)) { result = source.Task.ContinueWith(RemoveUpdatingItem).Unwrap(); } else { - throw new RepositoryCollisionException(); + throw new RepositoryCollisionException( + "Respository concurrency failure: attempt to update item with previously queued update pending.", + getItem(item)); } } if (!itemsToWrite.Writer.TryWrite(new(item, action, source))) { - throw new InvalidOperationException("Failed to TryWrite to _itemsToWrite channel."); + throw new InvalidOperationException("Failed to add item to _itemsToWrite channel."); } return result; - Task RemoveUpdatingItem(Task task) + Task RemoveUpdatingItem(Task task) { - _ = updatingItems.Remove(item, out _); + _ = updatingItems.Remove(item.Id, out _); return task.Status switch { TaskStatus.RanToCompletion => Task.FromResult(task.Result), - TaskStatus.Faulted => Task.FromException(task.Exception), - _ => Task.FromCanceled(cancellationToken) + TaskStatus.Faulted => Task.FromException(task.Exception), + _ => Task.FromCanceled(cancellationToken) }; } } @@ -219,7 +228,7 @@ private async ValueTask WriteItemsAsync(IList dbItems, CancellationTo dbContext.AddRange(dbItems.Where(e => WriteAction.Add.Equals(e.Action)).Select(e => e.DbItem)); dbContext.UpdateRange(dbItems.Where(e => WriteAction.Update.Equals(e.Action)).Select(e => e.DbItem)); dbContext.RemoveRange(dbItems.Where(e => WriteAction.Delete.Equals(e.Action)).Select(e => e.DbItem)); - await asyncPolicy.ExecuteAsync(dbContext.SaveChangesAsync, cancellationToken); + await asyncPolicy.ExecuteWithRetryAsync(dbContext.SaveChangesAsync, cancellationToken); OperateOnAll(dbItems, ActionOnSuccess()); } catch (Exception ex) @@ -245,6 +254,7 @@ protected virtual void Dispose(bool disposing) { if (disposing) { + //_ = itemsToWrite.Writer.TryComplete(); writerWorkerCancellationTokenSource.Cancel(); try @@ -255,6 +265,8 @@ protected virtual void Dispose(bool disposing) { } // Expected return from Wait(). catch (TaskCanceledException ex) when (writerWorkerCancellationTokenSource.Token == ex.CancellationToken) { } // Expected return from Wait(). 
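// Illustrative sketch only: caller-side handling of the collision thrown above. It assumes the
// generic shape RepositoryCollisionException<TesTask> (the type is declared later in this patch)
// and the IRepository<TesTask>.UpdateItemAsync signature used elsewhere in this repository; the
// helper name is hypothetical.
using System.Threading;
using System.Threading.Tasks;
using Microsoft.Extensions.Logging;
using Tes.Models;
using Tes.Repository;

internal static class CollisionHandlingDemo
{
    public static async Task UpdateOrLogAsync(IRepository<TesTask> repository, TesTask task, ILogger logger, CancellationToken ct)
    {
        try
        {
            await repository.UpdateItemAsync(task, ct);
        }
        catch (RepositoryCollisionException<TesTask> ex)
        {
            // The colliding (already queued) item rides on the exception; a caller could reload it and retry.
            logger.LogWarning("Update of task {TaskId} collided with a pending write.", ex.RepositoryItem.Id);
        }
    }
}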
+ + writerWorkerCancellationTokenSource.Dispose(); } _disposedValue = true; @@ -266,39 +278,5 @@ public void Dispose() Dispose(disposing: true); GC.SuppressFinalize(this); } - - internal class PrependableFormattableString : FormattableString - { - private readonly FormattableString source; - private readonly string prefix; - - public PrependableFormattableString(string prefix, FormattableString formattableString) - { - ArgumentNullException.ThrowIfNull(formattableString); - ArgumentException.ThrowIfNullOrEmpty(prefix); - - source = formattableString; - this.prefix = prefix; - } - - public override int ArgumentCount => source.ArgumentCount; - - public override string Format => prefix + source.Format; - - public override object GetArgument(int index) - { - return source.GetArgument(index); - } - - public override object[] GetArguments() - { - return source.GetArguments(); - } - - public override string ToString(IFormatProvider formatProvider) - { - return prefix + source.ToString(formatProvider); - } - } } } diff --git a/src/Tes.Repository/RepositoryCollisionException.cs b/src/Tes.Repository/RepositoryCollisionException.cs new file mode 100644 index 000000000..3b6b8b9b6 --- /dev/null +++ b/src/Tes.Repository/RepositoryCollisionException.cs @@ -0,0 +1,27 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; + +namespace Tes.Repository +{ + public class RepositoryCollisionException : Exception where T : RepositoryItem + { + public T RepositoryItem { get; } + + public RepositoryCollisionException(T repositoryItem) + { + RepositoryItem = repositoryItem; + } + + public RepositoryCollisionException(string message, T repositoryItem) : base(message) + { + RepositoryItem = repositoryItem; + } + + public RepositoryCollisionException(string message, Exception innerException, T repositoryItem) : base(message, innerException) + { + RepositoryItem = repositoryItem; + } + } +} diff --git a/src/Tes.Repository/Tes.Repository.csproj b/src/Tes.Repository/Tes.Repository.csproj new file mode 100644 index 000000000..1df4830ea --- /dev/null +++ b/src/Tes.Repository/Tes.Repository.csproj @@ -0,0 +1,31 @@ + + + + net8.0 + $(Product) TES repository library + + + + + <_Parameter1>TesApi.Tests + + + + + + + all + runtime; build; native; contentfiles; analyzers; buildtransitive + + + + + + + + + + + + + diff --git a/src/Tes/Repository/TesDbContext.cs b/src/Tes.Repository/TesDbContext.cs similarity index 94% rename from src/Tes/Repository/TesDbContext.cs rename to src/Tes.Repository/TesDbContext.cs index 1b92677c2..867fe8d73 100644 --- a/src/Tes/Repository/TesDbContext.cs +++ b/src/Tes.Repository/TesDbContext.cs @@ -5,7 +5,6 @@ using Microsoft.EntityFrameworkCore; using Npgsql; using Npgsql.EntityFrameworkCore.PostgreSQL.Infrastructure; -using Tes.Models; namespace Tes.Repository { @@ -29,7 +28,7 @@ public TesDbContext(NpgsqlDataSource dataSource, Action ContextOptionsBuilder { get; set; } - public DbSet TesTasks { get; set; } + public DbSet TesTasks { get; set; } protected override void OnConfiguring(DbContextOptionsBuilder optionsBuilder) { diff --git a/src/Tes/Repository/TesRepositoryCache.cs b/src/Tes.Repository/TesRepositoryCache.cs similarity index 100% rename from src/Tes/Repository/TesRepositoryCache.cs rename to src/Tes.Repository/TesRepositoryCache.cs diff --git a/src/Tes/Repository/TesTaskPostgreSqlRepository.cs b/src/Tes.Repository/TesTaskPostgreSqlRepository.cs similarity index 88% rename from src/Tes/Repository/TesTaskPostgreSqlRepository.cs rename to 
src/Tes.Repository/TesTaskPostgreSqlRepository.cs index b4494b85b..8dfd1849e 100644 --- a/src/Tes/Repository/TesTaskPostgreSqlRepository.cs +++ b/src/Tes.Repository/TesTaskPostgreSqlRepository.cs @@ -5,7 +5,6 @@ namespace Tes.Repository { using System; using System.Collections.Generic; - using System.Diagnostics; using System.Linq; using System.Linq.Expressions; using System.Reflection; @@ -18,15 +17,16 @@ namespace Tes.Repository using Microsoft.Extensions.Logging; using Microsoft.Extensions.Options; using Npgsql; - using Polly; using Tes.Models; + using Tes.Repository.Models; + using Tes.Repository.Utilities; using Tes.Utilities; /// /// A TesTask specific repository for storing the TesTask as JSON within an Entity Framework Postgres table /// /// - public sealed class TesTaskPostgreSqlRepository : PostgreSqlCachingRepository, IRepository + public sealed class TesTaskPostgreSqlRepository : PostgreSqlCachingRepository, IRepository { // JsonSerializerOptions singleton factory private static readonly Lazy GetSerializerOptions = new(() => @@ -102,7 +102,6 @@ public TesTaskPostgreSqlRepository(IOptions options, Microsof { var dataSource = NpgsqlDataSourceFunc(ConnectionStringUtility.GetPostgresConnectionString(options)); // The datasource itself must be essentially a singleton. CreateDbContext = Initialize(() => new TesDbContext(dataSource, NpgsqlDbContextOptionsBuilder)); - WarmCacheAsync(CancellationToken.None).GetAwaiter().GetResult(); } /// @@ -122,42 +121,6 @@ private static Func Initialize(Func createDbContext) return createDbContext; } - private async Task WarmCacheAsync(CancellationToken cancellationToken) - { - if (Cache is null) - { - Logger?.LogWarning("Cache is null for TesTaskPostgreSqlRepository; no caching will be used."); - return; - } - - var sw = Stopwatch.StartNew(); - Logger?.LogInformation("Warming cache..."); - - // Don't allow the state of the system to change until the cache and system are consistent; - // this is a fast PostgreSQL query even for 1 million items - await Policy - .Handle() - .WaitAndRetryAsync(3, - retryAttempt => - { - Logger?.LogWarning("Warming cache retry attempt #{RetryAttempt}", retryAttempt); - return TimeSpan.FromSeconds(10); - }, - (ex, ts) => - { - Logger?.LogCritical(ex, "Couldn't warm cache, is the database online?"); - }) - .ExecuteAsync(async ct => - { - var activeTasksCount = (await InternalGetItemsAsync( - ct, - orderBy: q => q.OrderBy(t => t.Json.CreationTime), - efPredicates: Enumerable.Empty>>().Append(task => TesTask.ActiveStates.Contains(task.State)))) - .Count(); - Logger?.LogInformation("Cache warmed successfully in {TotalSeconds:n3} seconds. 
Added {TasksAddedCount:n0} items to the cache.", sw.Elapsed.TotalSeconds, activeTasksCount); - }, cancellationToken); - } - /// public async Task TryGetItemAsync(string id, CancellationToken cancellationToken, Action onSuccess = null) @@ -184,7 +147,7 @@ public async Task> GetItemsAsync(Expression CreateItemAsync(TesTask task, CancellationToken cancellationToken) { var item = new TesTaskDatabaseItem { Json = task }; - item = await ExecuteNpgsqlActionAsync(async () => await AddUpdateOrRemoveItemInDbAsync(item, WriteAction.Add, cancellationToken)); + item = await ExecuteNpgsqlActionAsync(async () => await AddUpdateOrRemoveItemInDbAsync(item, db => db.Json, WriteAction.Add, cancellationToken)); return EnsureActiveItemInCache(item, t => t.Json.Id, t => t.Json.IsActiveState(), CopyTesTask).TesTask; } @@ -193,6 +156,7 @@ public async Task CreateItemAsync(TesTask task, CancellationToken cance /// /// TesTask to store as JSON in the database /// A for controlling the lifetime of the asynchronous operation. + /// public async Task> CreateItemsAsync(List items, CancellationToken cancellationToken) => [.. (await Task.WhenAll(items.Select(task => CreateItemAsync(task, cancellationToken))))]; @@ -201,14 +165,14 @@ public async Task UpdateItemAsync(TesTask tesTask, CancellationToken ca { var item = await ExecuteNpgsqlActionAsync(async () => await GetItemFromCacheOrDatabase(tesTask.Id, true, cancellationToken)); item.Json = tesTask; - item = await ExecuteNpgsqlActionAsync(async () => await AddUpdateOrRemoveItemInDbAsync(item, WriteAction.Update, cancellationToken)); + item = await ExecuteNpgsqlActionAsync(async () => await AddUpdateOrRemoveItemInDbAsync(item, db => db.Json, WriteAction.Update, cancellationToken)); return EnsureActiveItemInCache(item, t => t.Json.Id, t => t.Json.IsActiveState(), CopyTesTask).TesTask; } /// public async Task DeleteItemAsync(string id, CancellationToken cancellationToken) { - _ = await ExecuteNpgsqlActionAsync(async () => await AddUpdateOrRemoveItemInDbAsync(await GetItemFromCacheOrDatabase(id, true, cancellationToken), WriteAction.Delete, cancellationToken)); + _ = await ExecuteNpgsqlActionAsync(async () => await AddUpdateOrRemoveItemInDbAsync(await GetItemFromCacheOrDatabase(id, true, cancellationToken), db => db.Json, WriteAction.Delete, cancellationToken)); _ = Cache?.TryRemove(id); } @@ -307,7 +271,7 @@ private async Task GetItemFromCacheOrDatabase(string id, bo using var dbContext = CreateDbContext(); // Search for Id within the JSON - item = await ExecuteNpgsqlActionAsync(async () => await asyncPolicy.ExecuteAsync(ct => dbContext.TesTasks.FirstOrDefaultAsync(t => t.Json.Id == id, ct), cancellationToken)); + item = await ExecuteNpgsqlActionAsync(async () => await asyncPolicy.ExecuteWithRetryAsync(ct => dbContext.TesTasks.FirstOrDefaultAsync(t => t.Json.Id == id, ct), cancellationToken)); if (throwIfNotFound && item is null) { @@ -315,7 +279,7 @@ private async Task GetItemFromCacheOrDatabase(string id, bo } } - return item; + return item.Clone(); } /// diff --git a/src/Tes/Utilities/ExpressionParameterSubstitute.cs b/src/Tes.Repository/Utilities/ExpressionParameterSubstitute.cs similarity index 100% rename from src/Tes/Utilities/ExpressionParameterSubstitute.cs rename to src/Tes.Repository/Utilities/ExpressionParameterSubstitute.cs diff --git a/src/Tes/Utilities/PostgresConnectionStringUtility.cs b/src/Tes.Repository/Utilities/PostgresConnectionStringUtility.cs similarity index 97% rename from src/Tes/Utilities/PostgresConnectionStringUtility.cs rename to 
src/Tes.Repository/Utilities/PostgresConnectionStringUtility.cs index 76b3f28cd..b1f175573 100644 --- a/src/Tes/Utilities/PostgresConnectionStringUtility.cs +++ b/src/Tes.Repository/Utilities/PostgresConnectionStringUtility.cs @@ -4,9 +4,9 @@ using System; using System.Text; using Microsoft.Extensions.Options; -using Tes.Models; +using Tes.Repository.Models; -namespace Tes.Utilities +namespace Tes.Repository.Utilities { public static class ConnectionStringUtility { diff --git a/src/Tes.Repository/Utilities/PrependableFormattableString.cs b/src/Tes.Repository/Utilities/PrependableFormattableString.cs new file mode 100644 index 000000000..6d6dc785c --- /dev/null +++ b/src/Tes.Repository/Utilities/PrependableFormattableString.cs @@ -0,0 +1,41 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +using System; + +namespace Tes.Repository.Utilities +{ + internal class PrependableFormattableString : FormattableString + { + private readonly FormattableString source; + private readonly string prefix; + + public PrependableFormattableString(string prefix, FormattableString formattableString) + { + ArgumentNullException.ThrowIfNull(formattableString); + ArgumentException.ThrowIfNullOrEmpty(prefix); + + source = formattableString; + this.prefix = prefix; + } + + public override int ArgumentCount => source.ArgumentCount; + + public override string Format => prefix + source.Format; + + public override object GetArgument(int index) + { + return source.GetArgument(index); + } + + public override object[] GetArguments() + { + return source.GetArguments(); + } + + public override string ToString(IFormatProvider formatProvider) + { + return prefix + source.ToString(formatProvider); + } + } +} diff --git a/src/Tes.Runner.Test/Commands/NodeTaskResolverTests.cs b/src/Tes.Runner.Test/Commands/NodeTaskResolverTests.cs index d19f65c78..704e5e662 100644 --- a/src/Tes.Runner.Test/Commands/NodeTaskResolverTests.cs +++ b/src/Tes.Runner.Test/Commands/NodeTaskResolverTests.cs @@ -133,7 +133,7 @@ public async Task ResolveNodeTaskAsyncWithUriWhenFileExistsDoesNotDownload() public async Task ResolveNodeTaskAsyncWithUriWhenFileNotExistsDoesDownload() { ConfigureBlobApiHttpUtils((_, _) => Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(@"{}") })); - SetEnvironment(new() { RuntimeOptions = new() }); + SetEnvironment(new() { RuntimeOptions = new() { MountParentDirectoryPath = Environment.CurrentDirectory + "/task" } }); taskFile = new(Path.Combine(Environment.CurrentDirectory, CommandFactory.DefaultTaskDefinitionFile)); Assert.IsFalse(taskFile.Exists); @@ -147,7 +147,7 @@ public async Task ResolveNodeTaskAsyncWithUriWhenFileNotExistsDoesDownload() public async Task ResolveNodeTaskAsyncWithUriWhenFileNotExistsDoesSave() { ConfigureBlobApiHttpUtils((_, _) => Task.FromResult(new HttpResponseMessage(HttpStatusCode.OK) { Content = new StringContent(@"{}") })); - SetEnvironment(new() { RuntimeOptions = new() }); + SetEnvironment(new() { RuntimeOptions = new() { MountParentDirectoryPath = Environment.CurrentDirectory + "/task" } }); taskFile = new(Path.Combine(Environment.CurrentDirectory, CommandFactory.DefaultTaskDefinitionFile)); Assert.IsFalse(taskFile.Exists); @@ -165,7 +165,7 @@ public async Task ResolveNodeTaskAsyncUsesResolutionPolicyResolver() var sendGetCalled = false; ConfigureBlobApiHttpUtils((request, _) => Task.FromResult(Send(request)), (options, apiVersion) => new MockableResolutionPolicyHandler(ApplySasResolutionToUrl, options, apiVersion)); - 
SetEnvironment(new() { RuntimeOptions = new() { Terra = new() }, TransformationStrategy = TransformationStrategy.CombinedTerra }); + SetEnvironment(new() { RuntimeOptions = new() { Terra = new(), MountParentDirectoryPath = Environment.CurrentDirectory + "/task" }, TransformationStrategy = TransformationStrategy.CombinedTerra }); taskFile = new(Path.Combine(Environment.CurrentDirectory, CommandFactory.DefaultTaskDefinitionFile)); var result = await nodeTaskResolver.ResolveNodeTaskAsync(file: taskFile, uri: new("http://localhost/task.json"), apiVersion: BlobPipelineOptions.DefaultApiVersion, saveDownload: false); diff --git a/src/Tes.Runner.Test/Events/EventsPublisherTests.cs b/src/Tes.Runner.Test/Events/EventsPublisherTests.cs index bb932bffb..7029619bc 100644 --- a/src/Tes.Runner.Test/Events/EventsPublisherTests.cs +++ b/src/Tes.Runner.Test/Events/EventsPublisherTests.cs @@ -20,9 +20,15 @@ public void SetUp() { Id = "testId", WorkflowId = "workflowID", - ImageName = "image", - ImageTag = "tag", - CommandsToExecute = ["echo hello"], + Executors = + [ + new() + { + ImageName = "image", + ImageTag = "tag", + CommandsToExecute = ["echo hello"], + } + ], Inputs = [ new() @@ -100,28 +106,30 @@ public async Task PublishDownloadEndEventAsync_EventIsPublished_EventContainsAll [TestMethod] public async Task PublishExecutorStartEventAsync_EventIsPublished_EventContainsAllExpectedData() { - await eventsPublisher.PublishExecutorStartEventAsync(nodeTask); + await eventsPublisher.PublishExecutorStartEventAsync(nodeTask, 0); await eventsPublisher.FlushPublishersAsync(); var eventMessage = ((TestEventSink)sinks[0]).EventsHandled[0]; AssertMessageBaseMapping(eventMessage, EventsPublisher.ExecutorStartEvent, EventsPublisher.StartedStatus); - Assert.AreEqual(nodeTask.ImageName, eventMessage.EventData!["image"]); - Assert.AreEqual(nodeTask.ImageTag, eventMessage.EventData!["imageTag"]); - Assert.AreEqual(nodeTask.CommandsToExecute!.First(), eventMessage.EventData!["commands"]); + Assert.AreEqual("1/1", eventMessage.EventData!["executor"]); + Assert.AreEqual(nodeTask.Executors?[0].ImageName, eventMessage.EventData!["image"]); + Assert.AreEqual(nodeTask.Executors?[0].ImageTag, eventMessage.EventData!["imageTag"]); + Assert.AreEqual(nodeTask.Executors?[0].CommandsToExecute?.First(), eventMessage.EventData!["commands"]); } [TestMethod] public async Task PublishExecutorEndEventAsync_EventIsPublished_EventContainsAllExpectedData() { - await eventsPublisher.PublishExecutorEndEventAsync(nodeTask, exitCode: 0, EventsPublisher.SuccessStatus, errorMessage: string.Empty); + await eventsPublisher.PublishExecutorEndEventAsync(nodeTask, 0, exitCode: 0, statusMessage: EventsPublisher.SuccessStatus, errorMessage: string.Empty); await eventsPublisher.FlushPublishersAsync(); var eventMessage = ((TestEventSink)sinks[0]).EventsHandled[0]; AssertMessageBaseMapping(eventMessage, EventsPublisher.ExecutorEndEvent, EventsPublisher.SuccessStatus); - Assert.AreEqual(nodeTask.ImageName, eventMessage.EventData!["image"]); - Assert.AreEqual(nodeTask.ImageTag, eventMessage.EventData!["imageTag"]); + Assert.AreEqual("1/1", eventMessage.EventData!["executor"]); + Assert.AreEqual(nodeTask.Executors?[0].ImageName, eventMessage.EventData!["image"]); + Assert.AreEqual(nodeTask.Executors?[0].ImageTag, eventMessage.EventData!["imageTag"]); Assert.AreEqual(0, int.Parse(eventMessage.EventData!["exitCode"])); Assert.AreEqual("", eventMessage.EventData!["errorMessage"]); } diff --git a/src/Tes.Runner.Test/ExecutorTests.cs 
b/src/Tes.Runner.Test/ExecutorTests.cs index cdedb862a..1d741e4af 100644 --- a/src/Tes.Runner.Test/ExecutorTests.cs +++ b/src/Tes.Runner.Test/ExecutorTests.cs @@ -43,7 +43,8 @@ public void SetUp() nodeTask = new() { - MountParentDirectoryPath = "/root/parent", + Executors = [new()], + RuntimeOptions = new() { MountParentDirectoryPath = "/root/parent" }, Outputs = [ new() @@ -59,7 +60,7 @@ public void SetUp() ] }; - executor = new Executor(nodeTask, fileOperationResolverMock.Object, eventsPublisherMock.Object, transferOperationFactoryMock.Object, null!); + executor = new(nodeTask, fileOperationResolverMock.Object, eventsPublisherMock.Object, transferOperationFactoryMock.Object, null!); } [TestMethod] @@ -138,7 +139,7 @@ public async Task UploadOutputsAsync_NoOutputProvided_StartSuccessEventsAreCreat var result = await executor.UploadOutputsAsync(blobPipelineOptions); Assert.AreEqual(Executor.ZeroBytesTransferred, result); eventsPublisherMock.Verify(p => p.PublishUploadStartEventAsync(It.IsAny()), Times.Once); - eventsPublisherMock.Verify(p => p.PublishUploadEndEventAsync(It.IsAny(), 0, 0, EventsPublisher.SuccessStatus, string.Empty), Times.Once); + eventsPublisherMock.Verify(p => p.PublishUploadEndEventAsync(It.IsAny(), 0, 0, EventsPublisher.SuccessStatus, string.Empty, It.IsAny?>()), Times.Once); } [TestMethod] @@ -147,7 +148,7 @@ public async Task UploadOutputAsync_NullOptionsThrowsError_StartFailureEventsAre await Assert.ThrowsExceptionAsync(() => executor.UploadOutputsAsync(null!)); eventsPublisherMock.Verify(p => p.PublishUploadStartEventAsync(It.IsAny()), Times.Once); - eventsPublisherMock.Verify(p => p.PublishUploadEndEventAsync(It.IsAny(), 0, 0, EventsPublisher.FailedStatus, It.Is((c) => !string.IsNullOrEmpty(c))), Times.Once); + eventsPublisherMock.Verify(p => p.PublishUploadEndEventAsync(It.IsAny(), 0, 0, EventsPublisher.FailedStatus, It.Is((c) => !string.IsNullOrEmpty(c)), It.IsAny?>()), Times.Once); } [TestMethod] @@ -156,7 +157,7 @@ public async Task ExecuteNodeContainerTaskAsync_SuccessfulExecution_ReturnsConta dockerExecutorMock.Setup(d => d.RunOnContainerAsync(It.IsAny(), It.IsAny>>())) .ReturnsAsync(new ContainerExecutionResult("taskId", Error: string.Empty, ExitCode: 0)); - var result = await executor.ExecuteNodeContainerTaskAsync(dockerExecutorMock.Object); + var result = await executor.ExecuteNodeContainerTaskAsync(dockerExecutorMock.Object, 0); Assert.AreEqual(0, result.ContainerResult.ExitCode); Assert.AreEqual(string.Empty, result.ContainerResult.Error); @@ -169,7 +170,7 @@ public async Task ExecuteNodeContainerTaskAsync_ExecutionFails_ReturnsContainerR dockerExecutorMock.Setup(d => d.RunOnContainerAsync(It.IsAny(), It.IsAny>>())) .ReturnsAsync(new ContainerExecutionResult("taskId", Error: "Error", ExitCode: 1)); - var result = await executor.ExecuteNodeContainerTaskAsync(dockerExecutorMock.Object); + var result = await executor.ExecuteNodeContainerTaskAsync(dockerExecutorMock.Object, 0); Assert.AreEqual(1, result.ContainerResult.ExitCode); Assert.AreEqual("Error", result.ContainerResult.Error); @@ -181,10 +182,10 @@ public async Task ExecuteNodeContainerTaskAsync_SuccessfulExecution_StartAndSucc dockerExecutorMock.Setup(d => d.RunOnContainerAsync(It.IsAny(), It.IsAny>>())) .ReturnsAsync(new ContainerExecutionResult("taskId", Error: string.Empty, ExitCode: 0)); - await executor.ExecuteNodeContainerTaskAsync(dockerExecutorMock.Object); + await executor.ExecuteNodeContainerTaskAsync(dockerExecutorMock.Object, 0); - eventsPublisherMock.Verify(p => 
p.PublishExecutorStartEventAsync(It.IsAny()), Times.Once); - eventsPublisherMock.Verify(p => p.PublishExecutorEndEventAsync(It.IsAny(), 0, EventsPublisher.SuccessStatus, string.Empty), Times.Once); + eventsPublisherMock.Verify(p => p.PublishExecutorStartEventAsync(It.IsAny(), 0), Times.Once); + eventsPublisherMock.Verify(p => p.PublishExecutorEndEventAsync(It.IsAny(), 0, 0, EventsPublisher.SuccessStatus, string.Empty), Times.Once); } [TestMethod] @@ -193,10 +194,10 @@ public async Task ExecuteNodeContainerTaskAsync_ExecutionFails_StartAndFailureEv dockerExecutorMock.Setup(d => d.RunOnContainerAsync(It.IsAny(), It.IsAny>>())) .ReturnsAsync(new ContainerExecutionResult("taskId", Error: "Error", ExitCode: 1)); - await executor.ExecuteNodeContainerTaskAsync(dockerExecutorMock.Object); + await executor.ExecuteNodeContainerTaskAsync(dockerExecutorMock.Object, 0); - eventsPublisherMock.Verify(p => p.PublishExecutorStartEventAsync(It.IsAny()), Times.Once); - eventsPublisherMock.Verify(p => p.PublishExecutorEndEventAsync(It.IsAny(), 1, EventsPublisher.FailedStatus, "Error"), Times.Once); + eventsPublisherMock.Verify(p => p.PublishExecutorStartEventAsync(It.IsAny(), 0), Times.Once); + eventsPublisherMock.Verify(p => p.PublishExecutorEndEventAsync(It.IsAny(), 0, 1, EventsPublisher.FailedStatus, "Error"), Times.Once); } [TestMethod] @@ -205,10 +206,10 @@ public async Task ExecuteNodeContainerTaskAsync_ExecutionThrows_StartAndFailureE dockerExecutorMock.Setup(d => d.RunOnContainerAsync(It.IsAny(), It.IsAny>>())) .ThrowsAsync(new Exception("Error")); - await Assert.ThrowsExceptionAsync(async () => await executor.ExecuteNodeContainerTaskAsync(dockerExecutorMock.Object)); + await Assert.ThrowsExceptionAsync(async () => await executor.ExecuteNodeContainerTaskAsync(dockerExecutorMock.Object, 0)); - eventsPublisherMock.Verify(p => p.PublishExecutorStartEventAsync(It.IsAny()), Times.Once); - eventsPublisherMock.Verify(p => p.PublishExecutorEndEventAsync(It.IsAny(), Executor.DefaultErrorExitCode, EventsPublisher.FailedStatus, "Error"), Times.Once); + eventsPublisherMock.Verify(p => p.PublishExecutorStartEventAsync(It.IsAny(), 0), Times.Once); + eventsPublisherMock.Verify(p => p.PublishExecutorEndEventAsync(It.IsAny(), 0, Executor.DefaultErrorExitCode, EventsPublisher.FailedStatus, "Error"), Times.Once); } } } diff --git a/src/Tes.Runner.Test/ResolutionPolicyHandlerTests.cs b/src/Tes.Runner.Test/ResolutionPolicyHandlerTests.cs index 14844609d..62c033bea 100644 --- a/src/Tes.Runner.Test/ResolutionPolicyHandlerTests.cs +++ b/src/Tes.Runner.Test/ResolutionPolicyHandlerTests.cs @@ -16,7 +16,7 @@ public class ResolutionPolicyHandlerTests [TestInitialize] public void SetUp() { - runtimeOptions = new RuntimeOptions(); + runtimeOptions = new RuntimeOptions() { MountParentDirectoryPath = "/task" }; resolutionPolicyHandler = new ResolutionPolicyHandler(runtimeOptions, Runner.Transfer.BlobPipelineOptions.DefaultApiVersion); } diff --git a/src/Tes.Runner.Test/Storage/FileOperationResolverTests.cs b/src/Tes.Runner.Test/Storage/FileOperationResolverTests.cs index ec8cdf06c..8dc84878b 100644 --- a/src/Tes.Runner.Test/Storage/FileOperationResolverTests.cs +++ b/src/Tes.Runner.Test/Storage/FileOperationResolverTests.cs @@ -22,7 +22,7 @@ public class FileOperationResolverTests [TestInitialize] public void SetUp() { - resolutionPolicyHandler = new(new(), BlobPipelineOptions.DefaultApiVersion); + resolutionPolicyHandler = new(new() { MountParentDirectoryPath = Environment.CurrentDirectory }, BlobPipelineOptions.DefaultApiVersion); 
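// Illustrative sketch only: the node-task shape these tests now assume. The mount parent
// directory moves onto RuntimeOptions and per-container settings move into the Executors list;
// the values below are placeholders mirroring the test data in this patch, and the helper name
// is hypothetical.
using Tes.Runner.Models;

internal static class NodeTaskShapeDemo
{
    public static NodeTask CreateSample() => new()
    {
        Id = "testId",
        WorkflowId = "workflowID",
        RuntimeOptions = new() { MountParentDirectoryPath = "/root/parent" },
        Executors =
        [
            new()
            {
                ImageName = "image",
                ImageTag = "tag",
                CommandsToExecute = ["echo hello"],
            }
        ],
    };
}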
singleFileInput = new() { diff --git a/src/Tes.Runner/Authentication/CredentialsManager.cs b/src/Tes.Runner/Authentication/CredentialsManager.cs index 31c06c9f7..ce2cf574b 100644 --- a/src/Tes.Runner/Authentication/CredentialsManager.cs +++ b/src/Tes.Runner/Authentication/CredentialsManager.cs @@ -3,12 +3,12 @@ using Azure.Core; using Azure.Identity; +using CommonUtilities; using Microsoft.Extensions.Logging; -using Polly; -using Polly.Retry; using Tes.Runner.Exceptions; using Tes.Runner.Models; using Tes.Runner.Transfer; +using static CommonUtilities.RetryHandler; namespace Tes.Runner.Authentication { @@ -16,24 +16,16 @@ public class CredentialsManager { private readonly ILogger logger = PipelineLoggerFactory.Create(); - private readonly RetryPolicy retryPolicy; + private readonly RetryHandlerPolicy retryPolicy; private const int MaxRetryCount = 7; private const int ExponentialBackOffExponent = 2; public CredentialsManager() { - retryPolicy = Policy - .Handle() - .WaitAndRetry(MaxRetryCount, - SleepDurationHandler); - } - - private TimeSpan SleepDurationHandler(int attempt) - { - logger.LogInformation("Attempt {Attempt} to get token credential", attempt); - var duration = TimeSpan.FromSeconds(Math.Pow(ExponentialBackOffExponent, attempt)); - logger.LogInformation("Waiting {Duration} before retrying", duration); - return duration; + retryPolicy = new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new CommonUtilities.Options.RetryPolicyOptions() { ExponentialBackOffExponent = ExponentialBackOffExponent, MaxRetryCount = MaxRetryCount })) + .DefaultRetryPolicyBuilder() + .SetOnRetryBehavior(logger) + .SyncBuild(); } public virtual TokenCredential GetTokenCredential(RuntimeOptions runtimeOptions, string? tokenScope = default) @@ -56,7 +48,7 @@ public virtual TokenCredential GetTokenCredential(RuntimeOptions runtimeOptions, tokenScope ??= runtimeOptions.AzureEnvironmentConfig!.TokenScope!; try { - return retryPolicy.Execute(() => GetTokenCredentialImpl(managedIdentityResourceId, tokenScope, runtimeOptions.AzureEnvironmentConfig!.AzureAuthorityHostUrl!)); + return retryPolicy.ExecuteWithRetry(() => GetTokenCredentialImpl(managedIdentityResourceId, tokenScope, runtimeOptions.AzureEnvironmentConfig!.AzureAuthorityHostUrl!)); } catch { @@ -73,7 +65,7 @@ private TokenCredential GetTokenCredentialImpl(string? managedIdentityResourceId if (!string.IsNullOrWhiteSpace(managedIdentityResourceId)) { - logger.LogInformation("Token credentials with Managed Identity and resource ID: {NodeManagedIdentityResourceId}", managedIdentityResourceId); + logger.LogDebug("Token credentials with Managed Identity and resource ID: {NodeManagedIdentityResourceId}", managedIdentityResourceId); var tokenCredentialOptions = new TokenCredentialOptions { AuthorityHost = authorityHost }; tokenCredential = new ManagedIdentityCredential( @@ -82,7 +74,7 @@ private TokenCredential GetTokenCredentialImpl(string? managedIdentityResourceId } else { - logger.LogInformation("Token credentials with DefaultAzureCredential"); + logger.LogDebug("Token credentials with DefaultAzureCredential"); var defaultAzureCredentialOptions = new DefaultAzureCredentialOptions { AuthorityHost = authorityHost }; tokenCredential = new DefaultAzureCredential(defaultAzureCredentialOptions); } diff --git a/src/Tes.Runner/CompletedUploadFile.cs b/src/Tes.Runner/CompletedUploadFile.cs new file mode 100644 index 000000000..671d5cde4 --- /dev/null +++ b/src/Tes.Runner/CompletedUploadFile.cs @@ -0,0 +1,13 @@ +// Copyright (c) Microsoft Corporation. 
+// Licensed under the MIT License. + +namespace Tes.Runner +{ + /// + /// Upload File Log Entry + /// + /// Size of file in bytes. + /// Target URL + /// Source Path + public record struct CompletedUploadFile(long Length, Uri? BlobUrl, string FileName); +} diff --git a/src/Tes.Runner/Docker/ContainerRegistryAuthorizationManager.cs b/src/Tes.Runner/Docker/ContainerRegistryAuthorizationManager.cs index be21c877a..62d150a08 100644 --- a/src/Tes.Runner/Docker/ContainerRegistryAuthorizationManager.cs +++ b/src/Tes.Runner/Docker/ContainerRegistryAuthorizationManager.cs @@ -48,11 +48,11 @@ public ContainerRegistryAuthorizationManager(CredentialsManager tokenCredentials if (string.IsNullOrWhiteSpace(acrAccessToken)) { - logger.LogInformation(@"The ACR instance is public. No authorization is required. Registry: {RegistryEndpoint}", registryAddress); + logger.LogDebug(@"The ACR instance is public. No authorization is required. Registry: {RegistryEndpoint}", registryAddress); return null; // image is available anonymously } - logger.LogInformation(@"The ACR instance is private. An access token was successfully obtained. Registry: {RegistryEndpoint}", registryAddress); + logger.LogDebug(@"The ACR instance is private. An access token was successfully obtained. Registry: {RegistryEndpoint}", registryAddress); return new AuthConfig { diff --git a/src/Tes.Runner/Docker/VolumeBindingsGenerator.cs b/src/Tes.Runner/Docker/VolumeBindingsGenerator.cs index 708ce16b0..6d19f5f10 100644 --- a/src/Tes.Runner/Docker/VolumeBindingsGenerator.cs +++ b/src/Tes.Runner/Docker/VolumeBindingsGenerator.cs @@ -25,7 +25,7 @@ protected VolumeBindingsGenerator(string mountParentDirectory, IFileInfoProvider this.mountParentDirectory = string.IsNullOrWhiteSpace(mountParentDirectory) ? null! : fileInfoProvider.GetExpandedFileName(mountParentDirectory); } - public List GenerateVolumeBindings(List? inputs, List? outputs) + public List GenerateVolumeBindings(List? inputs, List? outputs, List? containerVolumes = default) { var volumeBindings = new HashSet(); @@ -45,12 +45,18 @@ public List GenerateVolumeBindings(List? inputs, List volumeBindings, string path) { - var mountPath = ToVolumeBinding(mountParentDirectory, path); + var mountPath = ToVolumeBinding(path); if (!string.IsNullOrEmpty(mountPath)) { @@ -58,7 +64,7 @@ private void AddVolumeBindingIfRequired(HashSet volumeBindings, string p } } - private string? ToVolumeBinding(string? mountParentDirectory, string path) + private string? ToVolumeBinding(string path) { if (string.IsNullOrEmpty(mountParentDirectory)) { @@ -69,17 +75,17 @@ private void AddVolumeBindingIfRequired(HashSet volumeBindings, string p if (!expandedPath.StartsWith(mountParentDirectory)) { - logger.LogDebug( + logger.LogTrace( @"The expanded path value {ExpandedPath} does not contain the specified mount parent directory: {MountParentDirectory}. 
No volume binding will be created for this file in the container.", expandedPath, mountParentDirectory); return default; } - var targetDir = $"{expandedPath.Substring(mountParentDirectory.Length).Split('/', StringSplitOptions.RemoveEmptyEntries)[0].TrimStart('/')}"; + var targetDir = $"{expandedPath[mountParentDirectory.Length..].Split('/', StringSplitOptions.RemoveEmptyEntries)[0].TrimStart('/')}"; var volBinding = $"{mountParentDirectory.TrimEnd('/')}/{targetDir}:/{targetDir}"; - logger.LogDebug(@"Volume binding for {ExpandedPath} is {VolBinding}", expandedPath, volBinding); + logger.LogTrace(@"Volume binding for {ExpandedPath} is {VolBinding}", expandedPath, volBinding); return volBinding; } diff --git a/src/Tes.Runner/Events/BlobStorageEventSink.cs b/src/Tes.Runner/Events/BlobStorageEventSink.cs index 3e7d24669..0ba85e200 100644 --- a/src/Tes.Runner/Events/BlobStorageEventSink.cs +++ b/src/Tes.Runner/Events/BlobStorageEventSink.cs @@ -33,7 +33,7 @@ public override async Task HandleEventAsync(EventMessage eventMessage) var content = JsonSerializer.Serialize(eventMessage, EventMessageContext.Default.EventMessage); await blobApiHttpUtils.ExecuteHttpRequestAsync(() => - BlobApiHttpUtils.CreatePutBlobRequestAsync(ToEventUrl(storageUrl, eventMessage), content, ApiVersion, ToTags(eventMessage))); + BlobApiHttpUtils.CreatePutBlobRequestAsync(ToEventUrl(storageUrl, eventMessage), content, ApiVersion, ToEventTag(eventMessage))); } catch (Exception e) { @@ -58,17 +58,6 @@ private static Uri ToEventUrl(Uri uri, EventMessage message) return blobBuilder.ToUri(); } - private static Dictionary ToTags(EventMessage eventMessage) - { - return new Dictionary - { - { "task-id", eventMessage.EntityId }, - { "workflow-id", eventMessage.CorrelationId }, - { "event-name", eventMessage.Name }, - { "created", eventMessage.Created.ToString(Iso8601DateFormat) } - }; - } - private static string ToBlobName(EventMessage eventMessage) { var blobName = diff --git a/src/Tes.Runner/Events/EventMessage.cs b/src/Tes.Runner/Events/EventMessage.cs index 565d2f0fc..89574f17f 100644 --- a/src/Tes.Runner/Events/EventMessage.cs +++ b/src/Tes.Runner/Events/EventMessage.cs @@ -27,13 +27,15 @@ public sealed class EventMessage [JsonPropertyName("resources")] public List? Resources { get; set; } + [JsonPropertyName("created")] public DateTime Created { get; set; } + [JsonPropertyName("eventVersion")] - public string EventVersion { get; set; } = null!; + public Version EventVersion { get; set; } = null!; [JsonPropertyName("eventDataVersion")] - public string EventDataVersion { get; set; } = null!; + public Version EventDataVersion { get; set; } = null!; [JsonPropertyName("eventData")] public Dictionary? 
EventData { get; set; } diff --git a/src/Tes.Runner/Events/EventSink.cs b/src/Tes.Runner/Events/EventSink.cs index 2489dc86e..c989e1f58 100644 --- a/src/Tes.Runner/Events/EventSink.cs +++ b/src/Tes.Runner/Events/EventSink.cs @@ -26,7 +26,7 @@ public async Task PublishEventAsync(EventMessage eventMessage) public void Start() { - logger.LogDebug("Starting events processing handler"); + logger.LogTrace("Starting events processing handler"); eventHandlerTask = Task.Run(async () => await EventHandlerAsync()); } @@ -41,7 +41,7 @@ public async Task StopAsync() if (eventHandlerTask.IsCompleted) { - logger.LogDebug("Task already completed"); + logger.LogTrace("Task already completed"); return; } @@ -51,19 +51,17 @@ public async Task StopAsync() await eventHandlerTask.WaitAsync(TimeSpan.FromSeconds(StopWaitDurationInSeconds)); } - protected IDictionary ToEventTag(EventMessage eventMessage) + protected static IDictionary ToEventTag(EventMessage eventMessage) { return new Dictionary { - { "event_name", eventMessage.Name }, - { "event_id", eventMessage.Id }, - { "entity_type", eventMessage.EntityType }, - { "task_id", eventMessage.EntityId }, - { "workflow_id", eventMessage.CorrelationId }, - //format date to ISO 8601, which is URL friendly + { "task-id", eventMessage.EntityId }, + { "workflow-id", eventMessage.CorrelationId }, + { "event-name", eventMessage.Name }, { "created", eventMessage.Created.ToString(Iso8601DateFormat) } }; } + private async Task EventHandlerAsync() { while (await events.Reader.WaitToReadAsync()) @@ -72,11 +70,11 @@ private async Task EventHandlerAsync() { try { - logger.LogDebug($"Handling event. Event Name: {eventMessage.Name} Event ID: {eventMessage.Id} "); + logger.LogTrace($"Handling event. Event Name: {eventMessage.Name} Event ID: {eventMessage.Id} "); await HandleEventAsync(eventMessage); - logger.LogDebug($"Event handled. Event Name: {eventMessage.Name} Event ID: {eventMessage.Id} "); + logger.LogTrace($"Event handled. Event Name: {eventMessage.Name} Event ID: {eventMessage.Id} "); } catch (Exception e) { diff --git a/src/Tes.Runner/Events/EventsPublisher.cs b/src/Tes.Runner/Events/EventsPublisher.cs index 2269f5b16..adfecec40 100644 --- a/src/Tes.Runner/Events/EventsPublisher.cs +++ b/src/Tes.Runner/Events/EventsPublisher.cs @@ -11,8 +11,8 @@ namespace Tes.Runner.Events; public class EventsPublisher : IAsyncDisposable { - const string EventVersion = "1.0"; - const string EventDataVersion = "1.0"; + public static readonly Version EventVersion = new(1, 0); + public static readonly Version EventDataVersion = new(1, 0); public const string TesTaskRunnerEntityType = "TesRunnerTask"; public const string DownloadStartEvent = "downloadStart"; public const string DownloadEndEvent = "downloadEnd"; @@ -95,7 +95,7 @@ public virtual async Task PublishUploadStartEventAsync(NodeTask nodeTask) await PublishAsync(eventMessage); } - public virtual async Task PublishUploadEndEventAsync(NodeTask nodeTask, int numberOfFiles, long totalSizeInBytes, string statusMessage, string? errorMessage = default) + public virtual async Task PublishUploadEndEventAsync(NodeTask nodeTask, int numberOfFiles, long totalSizeInBytes, string statusMessage, string? errorMessage = default, IEnumerable? 
completedFiles = default) { var eventMessage = CreateNewEventMessage(nodeTask.Id, UploadEndEvent, statusMessage, nodeTask.WorkflowId); @@ -104,38 +104,61 @@ public virtual async Task PublishUploadEndEventAsync(NodeTask nodeTask, int numb { { "numberOfFiles", numberOfFiles.ToString()}, { "totalSizeInBytes", totalSizeInBytes.ToString()}, - { "errorMessage", errorMessage??string.Empty} + { "errorMessage", errorMessage ?? string.Empty} }; + if (completedFiles is not null) + { + completedFiles = completedFiles.ToList(); + eventMessage.EventData.Add(@"fileLog-Count", completedFiles.Count().ToString("D")); + + foreach (var (logEntry, index) in completedFiles.Select((logEntry, index) => (logEntry, index))) + { + eventMessage.EventData.Add($"fileSize-{index}", logEntry.Length.ToString("D")); + eventMessage.EventData.Add($"fileUri-{index}", logEntry.BlobUrl?.AbsoluteUri ?? string.Empty); + eventMessage.EventData.Add($"filePath-{index}", logEntry.FileName); + } + } + await PublishAsync(eventMessage); } - public virtual async Task PublishExecutorStartEventAsync(NodeTask nodeTask) + private static string ExecutorFormatted(NodeTask nodeTask, int selector) + // Maintain format with TesApi.Web.Events.RunnerEventsProcessor.GetMessageBatchStateAsync+ParseExecutorIndex() + => $"{selector + 1}/{nodeTask.Executors?.Count ?? 0}"; + + public virtual async Task PublishExecutorStartEventAsync(NodeTask nodeTask, int selector) { var eventMessage = CreateNewEventMessage(nodeTask.Id, ExecutorStartEvent, StartedStatus, nodeTask.WorkflowId); - var commands = nodeTask.CommandsToExecute ?? []; + var executor = nodeTask.Executors?[selector]; + var commands = executor?.CommandsToExecute ?? []; eventMessage.EventData = new() { - { "image", nodeTask.ImageName??string.Empty}, - { "imageTag", nodeTask.ImageTag??string.Empty}, + { "executor", ExecutorFormatted(nodeTask, selector) }, + { "image", executor?.ImageName ?? string.Empty}, + { "imageTag", executor?.ImageTag ?? string.Empty}, { "commands", string.Join(' ', commands) } }; await PublishAsync(eventMessage); } - public virtual async Task PublishExecutorEndEventAsync(NodeTask nodeTask, long exitCode, string statusMessage, string? errorMessage = default) + public virtual async Task PublishExecutorEndEventAsync(NodeTask nodeTask, int selector, long exitCode, string statusMessage, string? errorMessage = default) { var eventMessage = CreateNewEventMessage(nodeTask.Id, ExecutorEndEvent, statusMessage, nodeTask.WorkflowId); + + var executor = nodeTask.Executors?[selector]; + eventMessage.EventData = new() { - { "image", nodeTask.ImageName??string.Empty}, - { "imageTag", nodeTask.ImageTag??string.Empty}, + { "executor", ExecutorFormatted(nodeTask, selector) }, + { "image", executor?.ImageName ?? string.Empty}, + { "imageTag", executor?.ImageTag ?? string.Empty}, { "exitCode", exitCode.ToString()}, - { "errorMessage", errorMessage??string.Empty} + { "errorMessage", errorMessage ?? string.Empty} }; await PublishAsync(eventMessage); } @@ -156,7 +179,7 @@ public virtual async Task PublishDownloadEndEventAsync(NodeTask nodeTask, int nu { { "numberOfFiles", numberOfFiles.ToString()}, { "totalSizeInBytes", totalSizeInBytes.ToString()}, - { "errorMessage", errorMessage??string.Empty} + { "errorMessage", errorMessage ?? 
string.Empty} }; await PublishAsync(eventMessage); } @@ -176,7 +199,7 @@ public async Task PublishTaskCompletionEventAsync(NodeTask tesNodeTask, TimeSpan eventMessage.EventData = new() { { "duration", duration.ToString()}, - { "errorMessage", errorMessage??string.Empty} + { "errorMessage", errorMessage ?? string.Empty} }; await PublishAsync(eventMessage); @@ -211,7 +234,7 @@ private async Task PublishAsync(EventMessage message) foreach (var sink in sinks) { - logger.LogInformation("Publishing event {MessageName} to sink: {SinkType}", message.Name, sink.GetType().Name); + logger.LogDebug("Publishing event {MessageName} to sink: {SinkType}", message.Name, sink.GetType().Name); await sink.PublishEventAsync(message); } @@ -224,8 +247,14 @@ public async Task FlushPublishersAsync(int waitTimeInSeconds = 60) await Task.WhenAll(stopTasks).WaitAsync(TimeSpan.FromSeconds(waitTimeInSeconds)); } - public async ValueTask DisposeAsync() + protected async virtual ValueTask DisposeAsyncCore() { await FlushPublishersAsync(); } + + public async ValueTask DisposeAsync() + { + await DisposeAsyncCore(); + GC.SuppressFinalize(this); + } } diff --git a/src/Tes.Runner/Executor.cs b/src/Tes.Runner/Executor.cs index bc493439e..585cac85b 100644 --- a/src/Tes.Runner/Executor.cs +++ b/src/Tes.Runner/Executor.cs @@ -12,7 +12,7 @@ namespace Tes.Runner { - public class Executor : IAsyncDisposable + public sealed class Executor : IAsyncDisposable { public const long ZeroBytesTransferred = 0; public const long DefaultErrorExitCode = 1; @@ -50,19 +50,19 @@ public Executor(NodeTask tesNodeTask, FileOperationResolver operationResolver, E this.apiVersion = apiVersion; } - public async Task ExecuteNodeContainerTaskAsync(DockerExecutor dockerExecutor) + public async Task ExecuteNodeContainerTaskAsync(DockerExecutor dockerExecutor, int selector) { try { - await eventsPublisher.PublishExecutorStartEventAsync(tesNodeTask); + await eventsPublisher.PublishExecutorStartEventAsync(tesNodeTask, selector); - var bindings = new VolumeBindingsGenerator(tesNodeTask.MountParentDirectoryPath!).GenerateVolumeBindings(tesNodeTask.Inputs, tesNodeTask.Outputs); + var bindings = new VolumeBindingsGenerator(tesNodeTask.RuntimeOptions.MountParentDirectoryPath!).GenerateVolumeBindings(tesNodeTask.Inputs, tesNodeTask.Outputs, tesNodeTask.ContainerVolumes); - var executionOptions = CreateExecutionOptions(bindings); + var executionOptions = CreateExecutionOptions(tesNodeTask.Executors![selector], bindings); var result = await dockerExecutor.RunOnContainerAsync(executionOptions, prefix => LogPublisher.CreateStreamReaderLogPublisherAsync(executionOptions.RuntimeOptions, prefix, apiVersion)); - await eventsPublisher.PublishExecutorEndEventAsync(tesNodeTask, result.ExitCode, ToStatusMessage(result), result.Error); + await eventsPublisher.PublishExecutorEndEventAsync(tesNodeTask, selector, result.ExitCode, ToStatusMessage(result), result.Error); return new NodeTaskResult(result); } @@ -70,17 +70,17 @@ public async Task ExecuteNodeContainerTaskAsync(DockerExecutor d { logger.LogError(e, "Failed to execute container"); - await eventsPublisher.PublishExecutorEndEventAsync(tesNodeTask, DefaultErrorExitCode, EventsPublisher.FailedStatus, e.Message); + await eventsPublisher.PublishExecutorEndEventAsync(tesNodeTask, selector, DefaultErrorExitCode, EventsPublisher.FailedStatus, e.Message); throw; } } - private ExecutionOptions CreateExecutionOptions(List bindings) + private ExecutionOptions CreateExecutionOptions(Models.Executor executor, List bindings) { - return 
new(tesNodeTask.ImageName, tesNodeTask.ImageTag, tesNodeTask.CommandsToExecute, bindings, - tesNodeTask.ContainerWorkDir, tesNodeTask.RuntimeOptions, tesNodeTask.ContainerDeviceRequests, - tesNodeTask.ContainerEnv, tesNodeTask.ContainerStdInPath, tesNodeTask.ContainerStdOutPath, tesNodeTask.ContainerStdErrPath); + return new(executor.ImageName, executor.ImageTag, executor.CommandsToExecute, bindings, + executor.ContainerWorkDir, tesNodeTask.RuntimeOptions, tesNodeTask.ContainerDeviceRequests, + executor.ContainerEnv, executor.ContainerStdInPath, executor.ContainerStdOutPath, executor.ContainerStdErrPath); } private static string ToStatusMessage(ContainerExecutionResult result) @@ -119,6 +119,8 @@ public async Task UploadOutputsAsync(BlobPipelineOptions blobPipelineOptio var bytesTransferred = ZeroBytesTransferred; var numberOfOutputs = 0; var errorMessage = string.Empty; + IEnumerable? completedFiles = default; + try { await eventsPublisher.PublishUploadStartEventAsync(tesNodeTask); @@ -142,7 +144,7 @@ public async Task UploadOutputsAsync(BlobPipelineOptions blobPipelineOptio var optimizedOptions = OptimizeBlobPipelineOptionsForUpload(blobPipelineOptions, outputs); - bytesTransferred = await UploadOutputsAsync(optimizedOptions, outputs); + (bytesTransferred, completedFiles) = await UploadOutputsAsync(optimizedOptions, outputs); await AppendMetrics(tesNodeTask.OutputsMetricsFormat, bytesTransferred); @@ -153,30 +155,61 @@ public async Task UploadOutputsAsync(BlobPipelineOptions blobPipelineOptio logger.LogError(e, "Upload operation failed"); statusMessage = EventsPublisher.FailedStatus; errorMessage = e.Message; + completedFiles = default; throw; } finally { - await eventsPublisher.PublishUploadEndEventAsync(tesNodeTask, numberOfOutputs, bytesTransferred, statusMessage, errorMessage); + await eventsPublisher.PublishUploadEndEventAsync(tesNodeTask, numberOfOutputs, bytesTransferred, statusMessage, errorMessage, completedFiles); } } - private async Task UploadOutputsAsync(BlobPipelineOptions blobPipelineOptions, List outputs) + public async Task UploadTaskOutputsAsync(BlobPipelineOptions blobPipelineOptions) + { + try + { + ArgumentNullException.ThrowIfNull(blobPipelineOptions, nameof(blobPipelineOptions)); + + var outputs = await CreateUploadTaskOutputsAsync(); + + if (outputs is null) + { + return; + } + + if (outputs.Count == 0) + { + logger.LogWarning("No output files were found."); + return; + } + + var optimizedOptions = OptimizeBlobPipelineOptionsForUpload(blobPipelineOptions, outputs); + + _ = await UploadOutputsAsync(optimizedOptions, outputs); + } + catch (Exception e) + { + logger.LogError(e, "Upload operation failed"); + throw; + } + } + + private async Task UploadOutputsAsync(BlobPipelineOptions blobPipelineOptions, List outputs) { var uploader = await transferOperationFactory.CreateBlobUploaderAsync(blobPipelineOptions); var executionResult = await TimedExecutionAsync(async () => await uploader.UploadAsync(outputs)); - logger.LogInformation("Executed Upload. Time elapsed: {ElapsedTime} Bandwidth: {Bandwidth} MiB/s", executionResult.Elapsed, BlobSizeUtils.ToBandwidth(executionResult.Result, executionResult.Elapsed.TotalSeconds)); + logger.LogDebug("Executed Upload. 
Time elapsed: {ElapsedTime} Bandwidth: {BandwidthMiBpS} MiB/s", executionResult.Elapsed, BlobSizeUtils.ToBandwidth(executionResult.Result, executionResult.Elapsed.TotalSeconds)); - return executionResult.Result; + return new(executionResult.Result, uploader.CompletedFiles); } private async Task?> CreateUploadOutputsAsync() { if ((tesNodeTask.Outputs ?? []).Count == 0) { - logger.LogInformation("No outputs provided"); + logger.LogDebug("No outputs provided"); { return default; } @@ -185,6 +218,19 @@ private async Task UploadOutputsAsync(BlobPipelineOptions blobPipelineOpti return await operationResolver.ResolveOutputsAsync(); } + private async Task?> CreateUploadTaskOutputsAsync() + { + if ((tesNodeTask.Outputs ?? []).Count == 0) + { + logger.LogDebug("No outputs provided"); + { + return default; + } + } + + return await operationResolver.ResolveTaskOutputsAsync(); + } + private BlobPipelineOptions OptimizeBlobPipelineOptionsForUpload(BlobPipelineOptions blobPipelineOptions, List outputs) { var optimizedOptions = @@ -258,7 +304,7 @@ private async Task DownloadInputsAsync(BlobPipelineOptions blobPipelineOpt var executionResult = await TimedExecutionAsync(async () => await downloader.DownloadAsync(inputs)); - logger.LogInformation("Executed Download. Time elapsed: {ElapsedTime} Bandwidth: {Bandwidth} MiB/s", executionResult.Elapsed, BlobSizeUtils.ToBandwidth(executionResult.Result, executionResult.Elapsed.TotalSeconds)); + logger.LogInformation("Executed Download. Time elapsed: {ElapsedTime} Bandwidth: {BandwidthMiBpS} MiB/s", executionResult.Elapsed, BlobSizeUtils.ToBandwidth(executionResult.Result, executionResult.Elapsed.TotalSeconds)); return executionResult.Result; } @@ -267,7 +313,7 @@ private async Task DownloadInputsAsync(BlobPipelineOptions blobPipelineOpt { if (tesNodeTask.Inputs is null || tesNodeTask.Inputs.Count == 0) { - logger.LogInformation("No inputs provided"); + logger.LogDebug("No inputs provided"); { return default; } @@ -286,10 +332,10 @@ private static void ValidateBlockSize(int blockSizeBytes) private void LogStartConfig(BlobPipelineOptions blobPipelineOptions) { - logger.LogInformation("Writers: {NumberOfWriters}", blobPipelineOptions.NumberOfWriters); - logger.LogInformation("Readers: {NumberOfReaders}", blobPipelineOptions.NumberOfReaders); - logger.LogInformation("Capacity: {ReadWriteBuffersCapacity}", blobPipelineOptions.ReadWriteBuffersCapacity); - logger.LogInformation("BlockSize: {BlockSizeBytes}", blobPipelineOptions.BlockSizeBytes); + logger.LogDebug("Writers: {NumberOfWriters}", blobPipelineOptions.NumberOfWriters); + logger.LogDebug("Readers: {NumberOfReaders}", blobPipelineOptions.NumberOfReaders); + logger.LogDebug("Capacity: {ReadWriteBuffersCapacity}", blobPipelineOptions.ReadWriteBuffersCapacity); + logger.LogDebug("BlockSize: {BlockSizeBytes}", blobPipelineOptions.BlockSizeBytes); } private static async Task> TimedExecutionAsync(Func> execution) @@ -301,9 +347,10 @@ private static async Task> TimedExecutionAsync(Func(TimeSpan Elapsed, T Result); + private record struct UploadResults(long BytesTransferred, IEnumerable CompletedFiles); + private record struct TimedExecutionResult(TimeSpan Elapsed, T Result); - public async ValueTask DisposeAsync() + async ValueTask IAsyncDisposable.DisposeAsync() { await eventsPublisher.FlushPublishersAsync(); GC.SuppressFinalize(this); diff --git a/src/Tes.Runner/Logs/AppendBlobLogPublisher.cs b/src/Tes.Runner/Logs/AppendBlobLogPublisher.cs index 2170f9c55..6413ba538 100644 --- 
a/src/Tes.Runner/Logs/AppendBlobLogPublisher.cs +++ b/src/Tes.Runner/Logs/AppendBlobLogPublisher.cs @@ -13,7 +13,7 @@ namespace Tes.Runner.Logs public class AppendBlobLogPublisher : StreamLogReader { private readonly string targetUrl; - private readonly BlobApiHttpUtils blobApiHttpUtils = new BlobApiHttpUtils(); + private readonly BlobApiHttpUtils blobApiHttpUtils = new(); private readonly string stdOutLogNamePrefix; private readonly string stdErrLogNamePrefix; private readonly ILogger logger = PipelineLoggerFactory.Create(); @@ -36,7 +36,7 @@ public AppendBlobLogPublisher(string targetUrl, string logNamePrefix) stdErrLogNamePrefix = $"{logNamePrefix}_stderr_{prefixTimeStamp}"; } - private string GetBlobNameConsideringBlockCountCurrentState(int blockCount, string logName) + private static string GetBlobNameConsideringBlockCountCurrentState(int blockCount, string logName) { var blobNumber = blockCount / BlobSizeUtils.MaxBlobBlocksCount; diff --git a/src/Tes.Runner/Storage/FileOperationResolver.cs b/src/Tes.Runner/Storage/FileOperationResolver.cs index 89ec12b15..937ade59f 100644 --- a/src/Tes.Runner/Storage/FileOperationResolver.cs +++ b/src/Tes.Runner/Storage/FileOperationResolver.cs @@ -52,6 +52,13 @@ public FileOperationResolver(NodeTask nodeTask, ResolutionPolicyHandler resoluti return await resolutionPolicyHandler.ApplyResolutionPolicyAsync(expandedOutputs); } + public virtual async Task?> ResolveTaskOutputsAsync() + { + var expandedOutputs = ExpandTaskOutputs(); + + return await resolutionPolicyHandler.ApplyResolutionPolicyAsync(expandedOutputs); + } + private List ExpandInputs() { List expandedInputs = []; @@ -104,6 +111,18 @@ private List ExpandOutputs() return outputs; } + private List ExpandTaskOutputs() + { + List outputs = []; + + foreach (var output in nodeTask.TaskOutputs ?? 
[]) + { + outputs.AddRange(ExpandOutput(output)); + } + + return outputs; + } + private IEnumerable ExpandOutput(FileOutput output) { ValidateFileOutput(output); @@ -159,7 +178,7 @@ private IEnumerable ExpandFileOutput(FileOutput output) if (fileInfoProvider.FileExists(expandedPath)) { //treat the output as a single file and use the target URL as is - logger.LogInformation("Adding file: {ExpandedPath} to the output list with a target URL as is", expandedPath); + logger.LogDebug("Adding file: {ExpandedPath} to the output list with a target URL as is", expandedPath); yield return CreateExpandedFileOutputUsingTargetUrl(output, absoluteFilePath: expandedPath); @@ -182,14 +201,14 @@ private IEnumerable ExpandFileOutput(FileOutput output) foreach (var file in fileInfoProvider.GetFilesBySearchPattern(rootPathPair.Root, rootPathPair.RelativePath)) { - logger.LogInformation("Adding file: {RelativePathToSearchPath} with absolute path: {AbsolutePath} to the output list with a combined target URL", file.RelativePathToSearchPath, file.AbsolutePath); + logger.LogDebug("Adding file: {RelativePathToSearchPath} with absolute path: {AbsolutePath} to the output list with a combined target URL", file.RelativePathToSearchPath, file.AbsolutePath); yield return CreateExpandedFileOutputWithCombinedTargetUrl(output, absoluteFilePath: file.AbsolutePath, relativePathToSearchPath: file.RelativePathToSearchPath); } } } - private static FileOutput CreateExpandedFileOutputWithCombinedTargetUrl(FileOutput output, string absoluteFilePath, string relativePathToSearchPath) + private FileOutput CreateExpandedFileOutputWithCombinedTargetUrl(FileOutput output, string absoluteFilePath, string relativePathToSearchPath) { return new() { @@ -200,7 +219,7 @@ private static FileOutput CreateExpandedFileOutputWithCombinedTargetUrl(FileOutp }; } - private static FileOutput CreateExpandedFileOutputUsingTargetUrl(FileOutput output, string absoluteFilePath) + private FileOutput CreateExpandedFileOutputUsingTargetUrl(FileOutput output, string absoluteFilePath) { return new() { diff --git a/src/Tes.Runner/Storage/ResolutionPolicyHandler.cs b/src/Tes.Runner/Storage/ResolutionPolicyHandler.cs index c0ee4d4a2..c97e077e5 100644 --- a/src/Tes.Runner/Storage/ResolutionPolicyHandler.cs +++ b/src/Tes.Runner/Storage/ResolutionPolicyHandler.cs @@ -15,6 +15,7 @@ public class ResolutionPolicyHandler const BlobSasPermissions UploadBlobSasPermissions = BlobSasPermissions.Read | BlobSasPermissions.Write | BlobSasPermissions.Create | BlobSasPermissions.List; private readonly RuntimeOptions runtimeOptions = null!; + private readonly string mountParentDirectoryPath = null!; private readonly string apiVersion; public ResolutionPolicyHandler(RuntimeOptions runtimeOptions, string apiVersion) @@ -24,6 +25,7 @@ public ResolutionPolicyHandler(RuntimeOptions runtimeOptions, string apiVersion) this.runtimeOptions = runtimeOptions; this.apiVersion = apiVersion; + this.mountParentDirectoryPath = Environment.ExpandEnvironmentVariables(runtimeOptions.MountParentDirectoryPath ?? 
throw new ArgumentException($"{nameof(runtimeOptions.MountParentDirectoryPath)} is missing.", nameof(runtimeOptions))); } /// @@ -98,7 +100,7 @@ private async Task CreateUploadInfoWithStrategyAsync(FileOutput outp { var uri = await ApplySasResolutionToUrlAsync(output.TargetUrl, output.TransformationStrategy, uploadBlobSasPermissions, runtimeOptions, apiVersion); - return new UploadInfo(output.Path!, uri); + return new UploadInfo(output.Path!, uri, mountParentDirectoryPath); } protected virtual async Task ApplySasResolutionToUrlAsync(string? sourceUrl, TransformationStrategy? strategy, diff --git a/src/Tes.Runner/Storage/TerraUrlTransformationStrategy.cs b/src/Tes.Runner/Storage/TerraUrlTransformationStrategy.cs index 19887cece..1eca3f59e 100644 --- a/src/Tes.Runner/Storage/TerraUrlTransformationStrategy.cs +++ b/src/Tes.Runner/Storage/TerraUrlTransformationStrategy.cs @@ -80,7 +80,7 @@ private async Task GetMappedSasUrlFromWsmAsync(TerraBlobInfo blobInfo, Blob { var tokenInfo = await GetWorkspaceSasTokenFromWsmAsync(blobInfo, blobSasPermissions); - logger.LogInformation("Successfully obtained the SAS URL from Terra. WSM resource ID:{ContainerResourceId}", blobInfo.WsmContainerResourceId); + logger.LogDebug("Successfully obtained the SAS URL from Terra. WSM resource ID:{ContainerResourceId}", blobInfo.WsmContainerResourceId); var uriBuilder = new UriBuilder(tokenInfo.Url); @@ -99,7 +99,7 @@ private async Task GetWorkspaceSasTokenFromWsmAsync(Terr { var tokenParams = CreateTokenParamsFromOptions(sasBlobPermissions); - logger.LogInformation( + logger.LogDebug( "Getting SAS URL from Terra. WSM workspace ID:{WorkspaceId}", blobInfo.WorkspaceId); var cacheKey = $"{blobInfo.WorkspaceId}-{blobInfo.WsmContainerResourceId}-{tokenParams.SasPermission}"; @@ -111,7 +111,7 @@ private async Task GetWorkspaceSasTokenFromWsmAsync(Terr throw new InvalidOperationException("The value retrieved from the cache is null"); } - logger.LogInformation("SAS URL found in cache. WSM resource ID:{ContainerResourceId}", blobInfo.WsmContainerResourceId); + logger.LogDebug("SAS URL found in cache. WSM resource ID:{ContainerResourceId}", blobInfo.WsmContainerResourceId); return tokenInfo; } @@ -196,11 +196,11 @@ private async Task GetTerraBlobInfoFromContainerNameAsync(string CheckIfAccountIsTerraStorageAccount(blobUriBuilder.AccountName); - logger.LogInformation("Getting Workspace ID from the Container Name: {BlobContainerName}", blobUriBuilder.BlobContainerName); + logger.LogDebug("Getting Workspace ID from the Container Name: {BlobContainerName}", blobUriBuilder.BlobContainerName); var workspaceId = ToWorkspaceId(blobUriBuilder.BlobContainerName); - logger.LogInformation("Workspace ID to use: {WorkspaceId}", workspaceId); + logger.LogDebug("Workspace ID to use: {WorkspaceId}", workspaceId); var wsmContainerResourceId = await GetWsmContainerResourceIdFromCacheOrWsmAsync(workspaceId, blobUriBuilder.BlobContainerName); @@ -242,7 +242,7 @@ private Guid ToWorkspaceId(string segmentsContainerName) private async Task GetWsmContainerResourceIdFromCacheOrWsmAsync(Guid workspaceId, string containerName) { - logger.LogInformation("Getting container resource information from WSM. Workspace ID: {WorkspaceId} Container Name: {BlobContainerName}", workspaceId, containerName); + logger.LogDebug("Getting container resource information from WSM. 
Workspace ID: {WorkspaceId} Container Name: {BlobContainerName}", workspaceId, containerName); try { @@ -250,7 +250,7 @@ private async Task GetWsmContainerResourceIdFromCacheOrWsmAsync(Guid works if (memoryCache.TryGetValue(cacheKey, out Guid wsmContainerResourceId)) { - logger.LogInformation("Found the container resource ID in cache. Resource ID: {ContainerResourceId} Container Name: {BlobContainerName}", wsmContainerResourceId, containerName); + logger.LogDebug("Found the container resource ID in cache. Resource ID: {ContainerResourceId} Container Name: {BlobContainerName}", wsmContainerResourceId, containerName); return wsmContainerResourceId; } @@ -263,7 +263,7 @@ private async Task GetWsmContainerResourceIdFromCacheOrWsmAsync(Guid works r.ResourceAttributes.AzureStorageContainer.StorageContainerName.Equals(containerName, StringComparison.OrdinalIgnoreCase)).Metadata; - logger.LogInformation("Found the resource ID for storage container resource. Resource ID: {ContainerResourceId} Container Name: {BlobContainerName}", metadata.ResourceId, containerName); + logger.LogDebug("Found the resource ID for storage container resource. Resource ID: {ContainerResourceId} Container Name: {BlobContainerName}", metadata.ResourceId, containerName); var resourceId = Guid.Parse(metadata.ResourceId); diff --git a/src/Tes.Runner/Transfer/BlobApiHttpUtils.cs b/src/Tes.Runner/Transfer/BlobApiHttpUtils.cs index 744f15329..9703d6f62 100644 --- a/src/Tes.Runner/Transfer/BlobApiHttpUtils.cs +++ b/src/Tes.Runner/Transfer/BlobApiHttpUtils.cs @@ -7,13 +7,13 @@ using System.Text; using Azure.Storage.Blobs; using Microsoft.Extensions.Logging; -using Polly.Retry; +using static CommonUtilities.RetryHandler; namespace Tes.Runner.Transfer; /// /// A class containing the logic to create and make the HTTP requests for the blob block API. /// -public class BlobApiHttpUtils(HttpClient httpClient, AsyncRetryPolicy retryPolicy) +public class BlobApiHttpUtils(HttpClient httpClient, AsyncRetryHandlerPolicy retryPolicy) { //https://learn.microsoft.com/en-us/rest/api/storageservices/understanding-block-blobs--append-blobs--and-page-blobs public const string DefaultApiVersion = "2023-05-03"; @@ -22,7 +22,7 @@ public class BlobApiHttpUtils(HttpClient httpClient, AsyncRetryPolicy retryPolic private readonly HttpClient httpClient = httpClient; private static readonly ILogger Logger = PipelineLoggerFactory.Create(); - private readonly AsyncRetryPolicy retryPolicy = retryPolicy; + private readonly AsyncRetryHandlerPolicy retryPolicy = retryPolicy; public const string RootHashMetadataName = "md5_4mib_hashlist_root_hash"; @@ -57,7 +57,7 @@ public static HttpRequestMessage CreatePutAppendBlockRequestAsync(string data, U } public static HttpRequestMessage CreatePutBlobRequestAsync(Uri blobUrl, string? content, string apiVersion, - Dictionary? tags, string blobType = BlockBlobType) + IDictionary? tags, string blobType = BlockBlobType) { ArgumentNullException.ThrowIfNull(blobUrl); ArgumentException.ThrowIfNullOrEmpty(apiVersion, nameof(apiVersion)); @@ -75,7 +75,7 @@ public static HttpRequestMessage CreatePutBlobRequestAsync(Uri blobUrl, string? return request; } - private static void AddPutBlobHeaders(HttpRequestMessage request, string apiVersion, Dictionary? tags, string blobType) + private static void AddPutBlobHeaders(HttpRequestMessage request, string apiVersion, IDictionary? 
tags, string blobType) { request.Headers.Add("x-ms-blob-type", blobType); @@ -167,7 +167,7 @@ public static HttpRequestMessage CreateBlobBlockListRequest(long length, Uri blo public async Task ExecuteHttpRequestAsync(Func requestFactory, CancellationToken cancellationToken = default) { - return await retryPolicy.ExecuteAsync(ct => ExecuteHttpRequestImplAsync(requestFactory, ct), cancellationToken); + return await retryPolicy.ExecuteWithRetryAsync(ct => ExecuteHttpRequestImplAsync(requestFactory, ct), cancellationToken); } public static bool UrlContainsSasToken(string sourceUrl) @@ -262,7 +262,7 @@ private static void HandleHttpRequestException(HttpStatusCode? status, HttpReque public async Task ExecuteHttpRequestAndReadBodyResponseAsync(PipelineBuffer buffer, Func requestFactory, CancellationToken cancellationToken = default) { - return await retryPolicy.ExecuteAsync(ct => ExecuteHttpRequestAndReadBodyResponseImplAsync(buffer, requestFactory, ct), cancellationToken); + return await retryPolicy.ExecuteWithRetryAsync(ct => ExecuteHttpRequestAndReadBodyResponseImplAsync(buffer, requestFactory, ct), cancellationToken); } private static bool ContainsRetriableException(Exception? ex) diff --git a/src/Tes.Runner/Transfer/BlobDownloader.cs b/src/Tes.Runner/Transfer/BlobDownloader.cs index acf92ce3c..f3d368d8d 100644 --- a/src/Tes.Runner/Transfer/BlobDownloader.cs +++ b/src/Tes.Runner/Transfer/BlobDownloader.cs @@ -108,7 +108,7 @@ public override async Task GetSourceLengthAsync(string source) /// public override Task OnCompletionAsync(long length, Uri? blobUrl, string fileName, string? rootHash, string? contentMd5) { - Logger.LogInformation($"Completed download. Total bytes: {length:n0} Filename: {fileName}"); + Logger.LogDebug($"Completed download. Total bytes: {length:n0} Filename: {fileName}"); return Task.CompletedTask; } diff --git a/src/Tes.Runner/Transfer/BlobOperationPipeline.cs b/src/Tes.Runner/Transfer/BlobOperationPipeline.cs index f7011f203..a79b69a94 100644 --- a/src/Tes.Runner/Transfer/BlobOperationPipeline.cs +++ b/src/Tes.Runner/Transfer/BlobOperationPipeline.cs @@ -64,13 +64,13 @@ protected BlobOperationPipeline(BlobPipelineOptions pipelineOptions, Channel ExecutePipelineAsync(List operatio { await WhenAllFailFast(pipelineTasks); - Logger.LogInformation("Pipeline processing completed."); + Logger.LogDebug("Pipeline processing completed."); } catch (Exception e) { @@ -99,9 +99,9 @@ protected async Task ExecutePipelineAsync(List operatio throw; } - Logger.LogInformation("Waiting for processed part processor to complete."); + Logger.LogDebug("Waiting for processed part processor to complete."); var bytesProcessed = await processedPartsProcessorTask; - Logger.LogInformation("Processed parts completed."); + Logger.LogDebug("Processed parts completed."); return bytesProcessed; } diff --git a/src/Tes.Runner/Transfer/BlobUploader.cs b/src/Tes.Runner/Transfer/BlobUploader.cs index d1f53ecaa..e01d7a6ff 100644 --- a/src/Tes.Runner/Transfer/BlobUploader.cs +++ b/src/Tes.Runner/Transfer/BlobUploader.cs @@ -14,6 +14,9 @@ public class BlobUploader : BlobOperationPipeline { private readonly ConcurrentDictionary hashListProviders = new(); + internal readonly IDictionary mapPathToMountPrefixLength = new Dictionary(); + internal readonly ConcurrentBag CompletedFiles = new(); + public BlobUploader(BlobPipelineOptions pipelineOptions, Channel memoryBufferPool) : base(pipelineOptions, memoryBufferPool) { } @@ -144,6 +147,14 @@ public override async Task OnCompletionAsync(long length, Uri? 
blobUrl, string f { response?.Dispose(); } + + if (mapPathToMountPrefixLength.TryGetValue(fileName, out var prefixLength)) + { + CompletedFiles.Add(new( + length, + new Azure.Storage.Blobs.BlobUriBuilder(blobUrl) { Sas = null }.ToUri(), + fileName[prefixLength..])); + } } /// @@ -156,6 +167,11 @@ public virtual async Task UploadAsync(List uploadList) { ValidateUploadList(uploadList); + foreach (var upload in uploadList.Where(upload => !string.IsNullOrWhiteSpace(upload.MountParentDirectory)).Where(upload => upload.FullFilePath.StartsWith(upload.MountParentDirectory!))) + { + mapPathToMountPrefixLength[upload.FullFilePath] = upload.MountParentDirectory!.Length; + } + var operationList = uploadList.Select(d => new BlobOperationInfo(d.TargetUri, d.FullFilePath, d.FullFilePath, true)).ToList(); return await ExecutePipelineAsync(operationList); diff --git a/src/Tes.Runner/Transfer/DefaultFileInfoProvider.cs b/src/Tes.Runner/Transfer/DefaultFileInfoProvider.cs index 69d6b140d..cd71368b4 100644 --- a/src/Tes.Runner/Transfer/DefaultFileInfoProvider.cs +++ b/src/Tes.Runner/Transfer/DefaultFileInfoProvider.cs @@ -15,25 +15,25 @@ public class DefaultFileInfoProvider : IFileInfoProvider public long GetFileSize(string fileName) { - logger.LogDebug("Getting file size for file: {Path}", fileName); + logger.LogTrace("Getting file size for file: {Path}", fileName); return GetFileInfoOrThrowIfFileDoesNotExist(fileName).Length; } public string GetExpandedFileName(string fileName) { - logger.LogDebug("Expanding file name: {Path}", fileName); + logger.LogTrace("Expanding file name: {Path}", fileName); var expandedValue = Environment.ExpandEnvironmentVariables(fileName); - logger.LogDebug("Expanded file name: {ExpandedPath}", expandedValue); + logger.LogTrace("Expanded file name: {ExpandedPath}", expandedValue); return expandedValue; } public bool FileExists(string fileName) { - logger.LogDebug("Checking if file exists: {Path}", fileName); + logger.LogTrace("Checking if file exists: {Path}", fileName); var fileInfo = new FileInfo(Environment.ExpandEnvironmentVariables(fileName)); @@ -43,7 +43,7 @@ public bool FileExists(string fileName) public List GetFilesBySearchPattern(string searchPath, string searchPattern) { - logger.LogInformation("Searching for files in the search path: {Path} with search pattern: {SearchPattern}", searchPath, searchPattern); + logger.LogDebug("Searching for files in the search path: {Path} with search pattern: {SearchPattern}", searchPath, searchPattern); return Directory.GetFiles(Environment.ExpandEnvironmentVariables(searchPath), Environment.ExpandEnvironmentVariables(searchPattern), SearchOption.AllDirectories) .Select(f => new FileResult(f, ToRelativePathToSearchPath(searchPath, searchPattern, f), searchPath)) @@ -62,7 +62,7 @@ private string ToRelativePathToSearchPath(string searchPath, string searchPatter if (!string.IsNullOrWhiteSpace(prefixToRemove) && absolutePath.StartsWith(prefixToRemove)) { - logger.LogInformation("Removing prefix: {Prefix} from absolute path: {Path}", prefixToRemove, absolutePath); + logger.LogDebug("Removing prefix: {Prefix} from absolute path: {Path}", prefixToRemove, absolutePath); return absolutePath[(prefixToRemove.Length + 1)..]; } @@ -74,7 +74,7 @@ public List GetAllFilesInDirectory(string path) { var expandedPath = Environment.ExpandEnvironmentVariables(path); - logger.LogInformation("Getting all files in directory: {Path}", expandedPath); + logger.LogDebug("Getting all files in directory: {Path}", expandedPath); if (!Directory.Exists(expandedPath)) 
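For illustration only (not part of this patch): a sketch of how BlobUploader's completed-file log derives the path it records for each upload, based on the mount-prefix mapping shown above. The helper name and sample paths are hypothetical.

    // Sketch of the prefix stripping used for completed-file entries: the mount parent directory
    // length is captured per file before the pipeline runs, and the logged path keeps only the
    // portion relative to that mount.
    static string ToLoggedPath(string fullFilePath, string mountParentDirectory)
        => fullFilePath.StartsWith(mountParentDirectory)
            ? fullFilePath[mountParentDirectory.Length..]
            : fullFilePath;

    // e.g. ToLoggedPath("/mnt/data/wd/task1/stdout.txt", "/mnt/data") returns "/wd/task1/stdout.txt"

The blob URL stored alongside that path has its SAS removed (BlobUriBuilder with Sas = null), so the completed-file log never carries credentials.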
{ diff --git a/src/Tes.Runner/Transfer/HttpRetryPolicyDefinition.cs b/src/Tes.Runner/Transfer/HttpRetryPolicyDefinition.cs index 8dc051e91..b733b6331 100644 --- a/src/Tes.Runner/Transfer/HttpRetryPolicyDefinition.cs +++ b/src/Tes.Runner/Transfer/HttpRetryPolicyDefinition.cs @@ -1,9 +1,9 @@ // Copyright (c) Microsoft Corporation. // Licensed under the MIT License. +using CommonUtilities; using Microsoft.Extensions.Logging; -using Polly; -using Polly.Retry; +using static CommonUtilities.RetryHandler; namespace Tes.Runner.Transfer { @@ -13,20 +13,13 @@ public class HttpRetryPolicyDefinition public const int RetryExponent = 2; private static readonly ILogger Logger = PipelineLoggerFactory.Create(); - public static AsyncRetryPolicy DefaultAsyncRetryPolicy(int maxRetryCount = DefaultMaxRetryCount) + public static AsyncRetryHandlerPolicy DefaultAsyncRetryPolicy(int maxRetryCount = DefaultMaxRetryCount) { - return Policy - .Handle() - .WaitAndRetryAsync(maxRetryCount, retryAttempt => - { - return TimeSpan.FromSeconds(Math.Pow(RetryExponent, retryAttempt)); - }, - onRetryAsync: - (exception, _, retryCount, _) => - { - Logger.LogError(exception, "Retrying failed request. Retry count: {retryCount}", retryCount); - return Task.CompletedTask; - }); + return new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new CommonUtilities.Options.RetryPolicyOptions() { MaxRetryCount = maxRetryCount, ExponentialBackOffExponent = RetryExponent })) + .PolicyBuilder.OpinionatedRetryPolicy(Polly.Policy.Handle()) + .WithRetryPolicyOptionsWait() + .SetOnRetryBehavior(Logger) + .AsyncBuild(); } } } diff --git a/src/Tes.Runner/Transfer/Md5HashListProvider.cs b/src/Tes.Runner/Transfer/Md5HashListProvider.cs index b6717b851..11a99a34c 100644 --- a/src/Tes.Runner/Transfer/Md5HashListProvider.cs +++ b/src/Tes.Runner/Transfer/Md5HashListProvider.cs @@ -43,7 +43,7 @@ public string GetRootHash() var rootHash = CreateBlockMd5CheckSumValue(data, 0, data.Length); - logger.LogInformation($"Root Hash: {rootHash} set in property: {BlobApiHttpUtils.RootHashMetadataName}"); + logger.LogDebug($"Root Hash: {rootHash} set in property: {BlobApiHttpUtils.RootHashMetadataName}"); return rootHash; } diff --git a/src/Tes.Runner/Transfer/PartsProcessor.cs b/src/Tes.Runner/Transfer/PartsProcessor.cs index cb700f394..fe4ad4ab3 100644 --- a/src/Tes.Runner/Transfer/PartsProcessor.cs +++ b/src/Tes.Runner/Transfer/PartsProcessor.cs @@ -73,7 +73,7 @@ protected Task StartProcessorsWithScalingStrategyAsync(int numberOfProcessors, C if (!scalingStrategy.IsScalingAllowed(p, currentMaxPartProcessingTime)) { - logger.LogInformation("The maximum number of tasks for the transfer operation has been set. Max part processing time is: {currentMaxPartProcessingTimeInMs} ms. Processing tasks count: {processorCount}.", currentMaxPartProcessingTime, p); + logger.LogDebug("The maximum number of tasks for the transfer operation has been set. Max part processing time is: {currentMaxPartProcessingTimeInMs} ms. Processing tasks count: {processorCount}.", currentMaxPartProcessingTime, p); break; } } @@ -84,13 +84,13 @@ protected Task StartProcessorsWithScalingStrategyAsync(int numberOfProcessors, C if (readFromChannel.Reader.Completion.IsCompleted) { - logger.LogInformation("The readFromChannel is completed, no need to add more processing tasks. Processing tasks count: {processorCount}.", p); + logger.LogDebug("The readFromChannel is completed, no need to add more processing tasks. 
Processing tasks count: {processorCount}.", p); break; } var delay = scalingStrategy.GetScalingDelay(p); - logger.LogInformation("Increasing the number of processing tasks to {processorCount}", p + 1); + logger.LogDebug("Increasing the number of processing tasks to {processorCount}", p + 1); tasks.Add(StartProcessorTaskAsync(readFromChannel, processorAsync, cancellationSource)); diff --git a/src/Tes.Runner/Transfer/PartsProducer.cs b/src/Tes.Runner/Transfer/PartsProducer.cs index 9010a45e5..e4d741fa3 100644 --- a/src/Tes.Runner/Transfer/PartsProducer.cs +++ b/src/Tes.Runner/Transfer/PartsProducer.cs @@ -49,7 +49,7 @@ public async Task StartPartsProducersAsync(List blobOperation { await Task.WhenAll(partsProducerTasks); - logger.LogInformation("All parts from requested operations were created."); + logger.LogDebug("All parts from requested operations were created."); } catch (Exception e) { diff --git a/src/Tes.Runner/Transfer/PartsReader.cs b/src/Tes.Runner/Transfer/PartsReader.cs index af87fe119..1aaf554ae 100644 --- a/src/Tes.Runner/Transfer/PartsReader.cs +++ b/src/Tes.Runner/Transfer/PartsReader.cs @@ -53,6 +53,6 @@ async Task ReadPartAsync(PipelineBuffer buffer, CancellationToken cancellationTo writeBufferChannel.Writer.Complete(); } - logger.LogInformation("All part read operations completed successfully."); + logger.LogDebug("All part read operations completed successfully."); } } diff --git a/src/Tes.Runner/Transfer/PartsWriter.cs b/src/Tes.Runner/Transfer/PartsWriter.cs index 5931ba43a..c5de2fd71 100644 --- a/src/Tes.Runner/Transfer/PartsWriter.cs +++ b/src/Tes.Runner/Transfer/PartsWriter.cs @@ -53,7 +53,7 @@ async Task WritePartAsync(PipelineBuffer buffer, CancellationToken cancellationT processedBufferChannel.Writer.Complete(); } - logger.LogInformation("All part write operations completed successfully."); + logger.LogDebug("All part write operations completed successfully."); } private ProcessedBuffer ToProcessedBuffer(PipelineBuffer buffer) diff --git a/src/Tes.Runner/Transfer/ProcessedPartsProcessor.cs b/src/Tes.Runner/Transfer/ProcessedPartsProcessor.cs index 29cd85bdc..f19ed27d6 100644 --- a/src/Tes.Runner/Transfer/ProcessedPartsProcessor.cs +++ b/src/Tes.Runner/Transfer/ProcessedPartsProcessor.cs @@ -70,7 +70,7 @@ public async ValueTask StartProcessedPartsProcessorAsync(int expectedNumbe readBufferChannel.Writer.Complete(); } - logger.LogInformation("All parts were successfully processed."); + logger.LogDebug("All parts were successfully processed."); return totalBytes; } @@ -99,7 +99,7 @@ private async Task CompleteFileProcessingAsync(ProcessedBuffer buffer, Cancellat if (!cancellationTokenSource.IsCancellationRequested) { - logger.LogDebug("Cancelling tasks in the processed parts processor."); + logger.LogTrace("Cancelling tasks in the processed parts processor."); cancellationTokenSource.Cancel(); } throw; diff --git a/src/Tes.Runner/Transfer/UploadInfo.cs b/src/Tes.Runner/Transfer/UploadInfo.cs index d1b0d6655..bc879b6fb 100644 --- a/src/Tes.Runner/Transfer/UploadInfo.cs +++ b/src/Tes.Runner/Transfer/UploadInfo.cs @@ -3,5 +3,5 @@ namespace Tes.Runner.Transfer { - public record UploadInfo(string FullFilePath, Uri TargetUri); + public record UploadInfo(string FullFilePath, Uri TargetUri, string? 
MountParentDirectory = null); } diff --git a/src/Tes.RunnerCLI/Commands/CommandFactory.cs b/src/Tes.RunnerCLI/Commands/CommandFactory.cs index 4941e7ad0..45e7c0485 100644 --- a/src/Tes.RunnerCLI/Commands/CommandFactory.cs +++ b/src/Tes.RunnerCLI/Commands/CommandFactory.cs @@ -18,6 +18,7 @@ internal static class CommandFactory internal const string DownloadCommandName = "download"; internal const string ExecutorCommandName = "exec"; internal const string DockerUriOption = "docker-url"; + internal const string ExecutorSelectorOption = "executor"; private static readonly IReadOnlyCollection /// The file path to convert. Two-part path is treated as container path. Paths with three or more parts are treated as blobs. + /// Requested permissions to include in the SAS token. /// A for controlling the lifetime of the asynchronous operation. /// Duration SAS should be valid. - /// Get the container SAS even if path is longer than two parts. /// An Azure Block Blob or Container URL with SAS token - public Task MapLocalPathToSasUrlAsync(string path, CancellationToken cancellationToken, TimeSpan? sasTokenDuration = default, bool getContainerSas = false); + public Task MapLocalPathToSasUrlAsync(string path, BlobSasPermissions sasPermissions, CancellationToken cancellationToken, TimeSpan? sasTokenDuration = default); /// /// Returns an Azure Storage Blob URL with a SAS token for the specified blob path in the TES internal storage location. /// - /// Path within the reserved blob storage area. + /// A relative path within the blob storage space reserved for the TES server. + /// Requested permissions to include in the SAS token. /// A for controlling the lifetime of the asynchronous operation. /// An Azure Block Blob or Container URL with SAS token - public Task GetInternalTesBlobUrlAsync(string blobPath, CancellationToken cancellationToken); + public Task GetInternalTesBlobUrlAsync(string blobPath, BlobSasPermissions sasPermissions, CancellationToken cancellationToken); /// /// Returns an Azure Storage Blob URL with a SAS token for the specified blob path in the TES task internal storage location. /// - /// A . - /// Path within the reserved blob storage area. + /// A + /// A relative path within the blob storage space reserved for the . + /// Requested permissions to include in the SAS token. /// A for controlling the lifetime of the asynchronous operation. - /// An Azure Block Blob URL with SAS token in the area reserved for . - public Task GetInternalTesTaskBlobUrlAsync(TesTask task, string blobPath, CancellationToken cancellationToken); + /// An Azure Block Blob storage URL with SAS token. + public Task GetInternalTesTaskBlobUrlAsync(TesTask task, string blobPath, BlobSasPermissions sasPermissions, CancellationToken cancellationToken); /// /// Returns an Azure Storage Blob URL without a SAS token for the specified blob path in the TES task internal storage location. /// - /// A . - /// Path within the reserved blob storage area. + /// A + /// A relative path within the blob storage space reserved for the . /// An Azure Block Blob URL without SAS token in the area reserved for . public Uri GetInternalTesTaskBlobUrlWithoutSasToken(TesTask task, string blobPath); /// /// Returns an Azure Storage Blob URL without a SAS token for the specified blob path in the TES internal storage location. /// - /// Path within the reserved blob storage area. + /// A relative path within the blob storage space reserved for the TES server. /// An Azure Block Blob or Container URL without SAS token. 
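For illustration only (not part of this patch): a sketch of how a caller might use the permission-based IStorageAccessProvider methods above, which replace the old getContainerSas flag. The provider instance, task, and paths are hypothetical placeholders; BlobSasPermissions comes from Azure.Storage.Sas.

    // Hypothetical usage sketch of the revised interface.
    static async Task<Uri> GetTaskOutputUrlAsync(IStorageAccessProvider storageAccessProvider, TesTask tesTask, CancellationToken cancellationToken)
    {
        // Request a read/write blob SAS inside the task's reserved storage area.
        return await storageAccessProvider.GetInternalTesTaskBlobUrlAsync(
            tesTask, "outputs/stdout.txt", BlobSasPermissions.Read | BlobSasPermissions.Write, cancellationToken);
    }

    static async Task<Uri> GetContainerListingUrlAsync(IStorageAccessProvider storageAccessProvider, CancellationToken cancellationToken)
    {
        // Including List signals that a container-scoped SAS is wanted where the provider supports it.
        return await storageAccessProvider.MapLocalPathToSasUrlAsync(
            "/storageaccount/container", BlobSasPermissions.Read | BlobSasPermissions.List, cancellationToken);
    }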
public Uri GetInternalTesBlobUrlWithoutSasToken(string blobPath); diff --git a/src/TesApi.Web/Storage/StorageAccessProvider.cs b/src/TesApi.Web/Storage/StorageAccessProvider.cs index 35377cc31..40bff10ed 100644 --- a/src/TesApi.Web/Storage/StorageAccessProvider.cs +++ b/src/TesApi.Web/Storage/StorageAccessProvider.cs @@ -6,6 +6,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using Azure.Storage.Sas; using Microsoft.Extensions.Logging; using Tes.Models; @@ -45,7 +46,7 @@ public abstract class StorageAccessProvider : IStorageAccessProvider /// /// Logger /// Azure proxy - public StorageAccessProvider(ILogger logger, IAzureProxy azureProxy) + protected StorageAccessProvider(ILogger logger, IAzureProxy azureProxy) { this.Logger = logger; this.AzureProxy = azureProxy; @@ -54,7 +55,7 @@ public StorageAccessProvider(ILogger logger, IAzureProxy azureProxy) /// public async Task DownloadBlobAsync(string blobRelativePath, CancellationToken cancellationToken) { - var blobUrl = await MapLocalPathToSasUrlAsync(blobRelativePath, cancellationToken); + var blobUrl = await MapLocalPathToSasUrlAsync(blobRelativePath, BlobSasPermissions.Read, cancellationToken, sasTokenDuration: default); if (blobUrl is null) { @@ -94,31 +95,19 @@ public async Task UploadBlobAsync(Uri blobAbsoluteUrl, string content, /// public async Task> GetBlobUrlsAsync(Uri blobVirtualDirectory, CancellationToken cancellationToken) - { - Azure.Storage.Blobs.BlobUriBuilder blobBuilder = new(blobVirtualDirectory) { Sas = null }; - return (await AzureProxy.ListBlobsAsync(blobVirtualDirectory, cancellationToken)).Select(GetBlobUri).ToList(); - - Uri GetBlobUri(Azure.Storage.Blobs.Models.BlobItem blob) - { - // This implementation reuses the BlobUriBuilder in the parent method, so GetBlobUri cannot be called in parallel with the same instance of BlobUriBuilder. - // It is safe for concurrent instances of GetBlobUrlsAsync to run simultaneously, however. - // Refactor if the ListBlobsAsync enumeration is ever parallelized at the stage of calling this converter method. - blobBuilder.BlobName = blob.Name; - return blobBuilder.ToUri(); - } - } + => await AzureProxy.ListBlobsAsync(blobVirtualDirectory, cancellationToken).Select(blob => blob.BlobUri).ToListAsync(cancellationToken); /// public abstract Task IsPublicHttpUrlAsync(string uriString, CancellationToken cancellationToken); /// - public abstract Task MapLocalPathToSasUrlAsync(string path, CancellationToken cancellationToken, TimeSpan? sasTokenDuration = default, bool getContainerSas = false); + public abstract Task MapLocalPathToSasUrlAsync(string path, BlobSasPermissions sasPermissions, CancellationToken cancellationToken, TimeSpan? 
sasTokenDuration);

        ///
-        public abstract Task<Uri> GetInternalTesBlobUrlAsync(string blobPath, CancellationToken cancellationToken);
+        public abstract Task<Uri> GetInternalTesBlobUrlAsync(string blobPath, BlobSasPermissions sasPermissions, CancellationToken cancellationToken);

        ///
-        public abstract Task<Uri> GetInternalTesTaskBlobUrlAsync(TesTask task, string blobPath, CancellationToken cancellationToken);
+        public abstract Task<Uri> GetInternalTesTaskBlobUrlAsync(TesTask task, string blobPath, BlobSasPermissions sasPermissions, CancellationToken cancellationToken);

        ///
        public abstract Uri GetInternalTesTaskBlobUrlWithoutSasToken(TesTask task, string blobPath);
diff --git a/src/TesApi.Web/Storage/TerraStorageAccessProvider.cs b/src/TesApi.Web/Storage/TerraStorageAccessProvider.cs
index 781fc6c73..ab524886b 100644
--- a/src/TesApi.Web/Storage/TerraStorageAccessProvider.cs
+++ b/src/TesApi.Web/Storage/TerraStorageAccessProvider.cs
@@ -8,6 +8,7 @@ using System.Threading.Tasks;
 using System.Web;
 using Azure.Storage.Blobs;
+using Azure.Storage.Sas;
 using CommonUtilities;
 using Microsoft.Extensions.Logging;
 using Microsoft.Extensions.Options;
@@ -47,7 +48,7 @@ public class TerraStorageAccessProvider : StorageAccessProvider
        public TerraStorageAccessProvider(Lazy<TerraWsmApiClient> terraWsmApiClient, IAzureProxy azureProxy, IOptions<TerraOptions> terraOptions, IOptions<BatchSchedulingOptions> batchSchedulingOptions, AzureEnvironmentConfig azureEnvironmentConfig, ILogger<TerraStorageAccessProvider> logger) : base(
-                logger, azureProxy)
+                logger, azureProxy)
        {
            ArgumentNullException.ThrowIfNull(terraOptions);
            ArgumentNullException.ThrowIfNull(batchSchedulingOptions);
@@ -87,9 +88,12 @@ public override Task IsPublicHttpUrlAsync(string uriString, CancellationTo
        }

        ///
-        public override async Task<Uri> MapLocalPathToSasUrlAsync(string path, CancellationToken cancellationToken, TimeSpan? sasTokenDuration = default, bool getContainerSas = false)
+        public override async Task<Uri> MapLocalPathToSasUrlAsync(string path, BlobSasPermissions sasPermissions, CancellationToken cancellationToken, TimeSpan? sasTokenDuration)
        {
+            // Currently all SAS tokens with Terra are R/W regardless of sasPermissions, so only the List flag is used here, to select between a container SAS and a blob SAS.
+            ArgumentException.ThrowIfNullOrEmpty(path);
+
            if (sasTokenDuration is not null)
            {
                throw new ArgumentException("Terra does not support extended length SAS tokens.");
            }
@@ -102,12 +106,12 @@ public override async Task MapLocalPathToSasUrlAsync(string path, Cancellat

            var terraBlobInfo = await GetTerraBlobInfoFromContainerNameAsync(path, cancellationToken);

-            if (getContainerSas)
+            if (sasPermissions.HasFlag(BlobSasPermissions.List) || string.IsNullOrWhiteSpace(terraBlobInfo.BlobName))
            {
-                return await GetMappedSasContainerUrlFromWsmAsync(terraBlobInfo, cancellationToken);
+                return await GetMappedSasContainerUrlFromWsmAsync(terraBlobInfo, false, cancellationToken);
            }

-            return await GetMappedSasUrlFromWsmAsync(terraBlobInfo, cancellationToken);
+            return await GetMappedSasUrlFromWsmAsync(terraBlobInfo, false, cancellationToken);
        }

        ///
@@ -116,16 +120,18 @@ public override async Task MapLocalPathToSasUrlAsync(string path, Cancellat
        /// If the blobPath is not provided (empty), a container SAS token is generated.
        /// If the blobPath is provided, a SAS token to the blobPath prefixed with the TES internal segments is generated.
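For illustration only (not part of this patch): how requests map to token types with the Terra provider above. The account and blob names are placeholders.

    // Hypothetical call sketch against TerraStorageAccessProvider.
    static async Task<Uri> GetTerraContainerSasAsync(TerraStorageAccessProvider provider, CancellationToken cancellationToken)
    {
        // Requests that include List, or that resolve to an empty blob name, come back as a
        // container-scoped SAS from WSM; anything else comes back as a blob SAS.
        return await provider.MapLocalPathToSasUrlAsync(
            "/lzaccount/sc-aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee",
            BlobSasPermissions.Read | BlobSasPermissions.List,
            cancellationToken,
            sasTokenDuration: null);
    }

    static async Task<Uri> GetTaggableEventBlobSasAsync(TerraStorageAccessProvider provider, CancellationToken cancellationToken)
    {
        // For GetInternalTesBlobUrlAsync, the Tag flag appends the "t" permission to the WSM token
        // request so blob index tags (such as the runner's event tags) can be written.
        return await provider.GetInternalTesBlobUrlAsync(
            "events/runnerEvent.json", BlobSasPermissions.Write | BlobSasPermissions.Tag, cancellationToken);
    }

Expressing intent through BlobSasPermissions leaves each provider free to decide how to honor it; as the comments above note, Terra currently consults only the List flag (container versus blob scope) and, for the internal TES blob URL, the Tag flag.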
        ///
-        public override async Task<Uri> GetInternalTesBlobUrlAsync(string blobPath, CancellationToken cancellationToken)
+        public override async Task<Uri> GetInternalTesBlobUrlAsync(string blobPath, BlobSasPermissions sasPermissions, CancellationToken cancellationToken)
        {
+            // Currently all SAS tokens with Terra are R/W; of the requested sasPermissions, only the Tag flag is honored here.
+
            var blobInfo = GetTerraBlobInfoForInternalTes(blobPath);

            if (string.IsNullOrEmpty(blobPath))
            {
-                return await GetMappedSasContainerUrlFromWsmAsync(blobInfo, cancellationToken);
+                return await GetMappedSasContainerUrlFromWsmAsync(blobInfo, sasPermissions.HasFlag(BlobSasPermissions.Tag), cancellationToken);
            }

-            return await GetMappedSasUrlFromWsmAsync(blobInfo, cancellationToken);
+            return await GetMappedSasUrlFromWsmAsync(blobInfo, sasPermissions.HasFlag(BlobSasPermissions.Tag), cancellationToken);
        }

        ///
@@ -134,16 +140,18 @@ public override async Task GetInternalTesBlobUrlAsync(string blobPath, Canc
        /// If the blobPath is not provided (empty), a container SAS token is generated.
        /// If the blobPath is provided, a SAS token to the blobPath prefixed with the TES task internal segments is generated.
        ///
-        public override async Task<Uri> GetInternalTesTaskBlobUrlAsync(TesTask task, string blobPath, CancellationToken cancellationToken)
+        public override async Task<Uri> GetInternalTesTaskBlobUrlAsync(TesTask task, string blobPath, BlobSasPermissions sasPermissions, CancellationToken cancellationToken)
        {
+            // Currently all SAS tokens with Terra are R/W and sasPermissions is not used here; it is reserved for future use.
+
            var blobInfo = GetTerraBlobInfoForInternalTesTask(task, blobPath);

            if (string.IsNullOrEmpty(blobPath))
            {
-                return await GetMappedSasContainerUrlFromWsmAsync(blobInfo, cancellationToken);
+                return await GetMappedSasContainerUrlFromWsmAsync(blobInfo, false, cancellationToken);
            }

-            return await GetMappedSasUrlFromWsmAsync(blobInfo, cancellationToken);
+            return await GetMappedSasUrlFromWsmAsync(blobInfo, false, cancellationToken);
        }

        ///
@@ -233,11 +241,11 @@ private async Task GetTerraBlobInfoFromContainerNameAsync(string
            CheckIfAccountIsTerraStorageAccount(segments.AccountName);

-            Logger.LogInformation($"Getting Workspace ID from the Container Name: {segments.ContainerName}");
+            Logger.LogDebug($"Getting Workspace ID from the Container Name: {segments.ContainerName}");

            var workspaceId = ToWorkspaceId(segments.ContainerName);

-            Logger.LogInformation($"Workspace ID to use: {segments.ContainerName}");
+            Logger.LogDebug($"Workspace ID to use: {workspaceId}");

            var wsmContainerResourceId = await GetWsmContainerResourceIdAsync(workspaceId, segments.ContainerName, cancellationToken);
@@ -246,7 +254,7 @@ private async Task GetTerraBlobInfoFromContainerNameAsync(string
        private async Task<Guid> GetWsmContainerResourceIdAsync(Guid workspaceId, string containerName, CancellationToken cancellationToken)
        {
-            Logger.LogInformation($"Getting container resource information from WSM. Workspace ID: {workspaceId} Container Name: {containerName}");
+            Logger.LogDebug($"Getting container resource information from WSM. Workspace ID: {workspaceId} Container Name: {containerName}");

            try
            {
@@ -258,7 +266,7 @@ private async Task GetWsmContainerResourceIdAsync(Guid workspaceId, string
                    r.ResourceAttributes.AzureStorageContainer.StorageContainerName.Equals(containerName, StringComparison.OrdinalIgnoreCase)).Metadata;

-                Logger.LogInformation($"Found the resource id for storage container resource.
Resource ID: {metadata.ResourceId} Container Name: {containerName}"); + Logger.LogDebug($"Found the resource id for storage container resource. Resource ID: {metadata.ResourceId} Container Name: {containerName}"); return Guid.Parse(metadata.ResourceId); } @@ -281,7 +289,7 @@ private Guid ToWorkspaceId(string segmentsContainerName) { ArgumentException.ThrowIfNullOrEmpty(segmentsContainerName); - var guidString = segmentsContainerName.Substring(3); // remove the sc- prefix + var guidString = segmentsContainerName[3..]; // remove the sc- prefix return Guid.Parse(guidString); // throws if not a guid } @@ -292,9 +300,9 @@ private Guid ToWorkspaceId(string segmentsContainerName) } } - private async Task GetMappedSasContainerUrlFromWsmAsync(TerraBlobInfo blobInfo, CancellationToken cancellationToken) + private async Task GetMappedSasContainerUrlFromWsmAsync(TerraBlobInfo blobInfo, bool? needsTags, CancellationToken cancellationToken) { - var tokenInfo = await GetWorkspaceContainerSasTokenFromWsmAsync(blobInfo, cancellationToken); + var tokenInfo = await GetWorkspaceContainerSasTokenFromWsmAsync(blobInfo, needsTags, cancellationToken); var urlBuilder = new UriBuilder(tokenInfo.Url); @@ -310,13 +318,14 @@ private async Task GetMappedSasContainerUrlFromWsmAsync(TerraBlobInfo blobI /// Returns a Url with a SAS token for the given input /// /// + /// /// A for controlling the lifetime of the asynchronous operation. /// URL with a SAS token - public async Task GetMappedSasUrlFromWsmAsync(TerraBlobInfo blobInfo, CancellationToken cancellationToken) + internal async Task GetMappedSasUrlFromWsmAsync(TerraBlobInfo blobInfo, bool? needsTags, CancellationToken cancellationToken) { - var tokenInfo = await GetWorkspaceBlobSasTokenFromWsmAsync(blobInfo, cancellationToken); + var tokenInfo = await GetWorkspaceBlobSasTokenFromWsmAsync(blobInfo, needsTags, cancellationToken); - Logger.LogInformation($"Successfully obtained the Sas Url from Terra. Wsm resource id:{terraOptions.WorkspaceStorageContainerResourceId}"); + Logger.LogDebug($"Successfully obtained the Sas Url from Terra. Wsm resource id:{terraOptions.WorkspaceStorageContainerResourceId}"); var uriBuilder = new UriBuilder(tokenInfo.Url); @@ -338,11 +347,11 @@ private SasTokenApiParameters CreateTokenParamsFromOptions(string blobName, stri sasPermissions, blobName); - private async Task GetWorkspaceBlobSasTokenFromWsmAsync(TerraBlobInfo blobInfo, CancellationToken cancellationToken) + private async Task GetWorkspaceBlobSasTokenFromWsmAsync(TerraBlobInfo blobInfo, bool? needsTags, CancellationToken cancellationToken) { - var tokenParams = CreateTokenParamsFromOptions(blobInfo.BlobName, SasBlobPermissions); + var tokenParams = CreateTokenParamsFromOptions(blobInfo.BlobName, SasBlobPermissions + (needsTags.GetValueOrDefault() ? "t" : string.Empty)); - Logger.LogInformation( + Logger.LogDebug( $"Getting Sas Url from Terra. Wsm workspace id:{blobInfo.WorkspaceId}"); return await terraWsmApiClient.Value.GetSasTokenAsync( @@ -351,12 +360,12 @@ private async Task GetWorkspaceBlobSasTokenFromWsmAsync( tokenParams, cancellationToken); } - private async Task GetWorkspaceContainerSasTokenFromWsmAsync(TerraBlobInfo blobInfo, CancellationToken cancellationToken) + private async Task GetWorkspaceContainerSasTokenFromWsmAsync(TerraBlobInfo blobInfo, bool? 
needsTags, CancellationToken cancellationToken)
        {
            // an empty blob name gets a container Sas token
-            var tokenParams = CreateTokenParamsFromOptions(blobName: "", SasContainerPermissions);
+            var tokenParams = CreateTokenParamsFromOptions(blobName: "", SasContainerPermissions + (needsTags.GetValueOrDefault() ? "t" : string.Empty));

-            Logger.LogInformation(
+            Logger.LogDebug(
                $"Getting Sas container Url from Terra. Wsm workspace id:{blobInfo.WorkspaceId}");

            return await terraWsmApiClient.Value.GetSasTokenAsync(
diff --git a/src/TesApi.Web/TaskScheduler.cs b/src/TesApi.Web/TaskScheduler.cs
new file mode 100644
index 000000000..d8786a002
--- /dev/null
+++ b/src/TesApi.Web/TaskScheduler.cs
@@ -0,0 +1,554 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT License.
+
+using System;
+using System.Collections.Concurrent;
+using System.Collections.Generic;
+using System.Collections.Immutable;
+using System.Linq;
+using System.Threading;
+using System.Threading.Channels;
+using System.Threading.Tasks;
+using CommonUtilities;
+using Microsoft.Extensions.Logging;
+using Tes.Models;
+using Tes.Repository;
+using TesApi.Web.Events;
+using TesApi.Web.Extensions;
+
+namespace TesApi.Web
+{
+    ///
+    /// An interface for scheduling <see cref="TesTask"/>s.
+    ///
+    public interface ITaskScheduler
+    {
+
+        ///
+        /// Schedules a <see cref="TesTask"/>
+        ///
+        /// A <see cref="TesTask"/> to schedule on the batch system.
+        void QueueTesTask(TesTask tesTask);
+
+        ///
+        /// Updates <see cref="TesTask"/>s with task-related state
+        ///
+        /// <see cref="TesTask"/>s to schedule on the batch system.
+        /// <see cref="AzureBatchTaskState"/>s corresponding to each <see cref="TesTask"/>.
+        /// A <see cref="CancellationToken"/> for controlling the lifetime of the asynchronous operation.
+        /// True for each corresponding <see cref="TesTask"/> that needs to be persisted.
+        IAsyncEnumerable<RelatedTask<TesTask, bool>> ProcessTesTaskBatchStatesAsync(IEnumerable<TesTask> tesTasks, AzureBatchTaskState[] taskStates, CancellationToken cancellationToken);
+    }
+
+    ///
+    /// A background service that schedules <see cref="TesTask"/>s in the batch system, orchestrates their lifecycle, and updates their state.
+    /// This should only be used as a system-wide singleton service. This class does not support scale-out on multiple machines,
+    /// nor does it implement a leasing mechanism. In the future, consider using the Lease Blob operation.
+    ///
+    /// The task node event processor.
+    /// Used for requesting termination of the current application during initialization.
+    /// The main TES task database repository implementation.
+    /// The batch scheduler implementation.
+    /// The logger instance.
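For illustration only (not part of this patch): a sketch of a consumer of the ITaskScheduler surface declared above; the task and state arrays and the persistence step are placeholders.

    // Hypothetical consumer sketch of ITaskScheduler.
    static async Task HandleBatchStateUpdatesAsync(
        ITaskScheduler taskScheduler, TesTask[] tesTasks, AzureBatchTaskState[] taskStates, CancellationToken cancellationToken)
    {
        // One result is yielded per submitted (task, state) pair; it reports whether the
        // corresponding TesTask now needs to be persisted.
        await foreach (var result in taskScheduler.ProcessTesTaskBatchStatesAsync(tesTasks, taskStates, cancellationToken))
        {
            // persist the corresponding TesTask here when the result indicates a change
        }
    }

QueueTesTask is the corresponding fire-and-forget entry point: newly created tasks are enqueued and picked up by the scheduler's background loops.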
+ internal class TaskScheduler(RunnerEventsProcessor nodeEventProcessor, Microsoft.Extensions.Hosting.IHostApplicationLifetime hostApplicationLifetime, IRepository repository, IBatchScheduler batchScheduler, ILogger taskSchedulerLogger) + : OrchestrateOnBatchSchedulerServiceBase(hostApplicationLifetime, repository, batchScheduler, taskSchedulerLogger) + , ITaskScheduler + { + private static readonly TimeSpan blobRunInterval = TimeSpan.FromSeconds(15); + private static readonly TimeSpan queuedRunInterval = TimeSpan.FromMilliseconds(100); + internal static readonly TimeSpan BatchRunInterval = TimeSpan.FromSeconds(30); // The very fastest processes inside of Azure Batch accessing anything within pools or jobs appears to use a 30 second polling interval + private static readonly TimeSpan shortBackgroundRunInterval = TimeSpan.FromSeconds(1); + private static readonly TimeSpan longBackgroundRunInterval = TimeSpan.FromSeconds(2.5); + private static readonly TimeSpan orphanedTaskInterval = TimeSpan.FromMinutes(10); + private readonly RunnerEventsProcessor nodeEventProcessor = nodeEventProcessor; + + /// + /// Checks to see if the hosted service is running. + /// + /// False if the service hasn't started up yet, True if it has started, throws TaskCanceledException if service is/has shutdown. + private bool IsRunning => stoppingToken is not null && (stoppingToken.Value.IsCancellationRequested ? throw new TaskCanceledException() : true); + + private CancellationToken? stoppingToken = null; + private readonly ConcurrentQueue queuedTesTasks = []; + private readonly ConcurrentQueue<(TesTask[] TesTasks, AzureBatchTaskState[] TaskStates, ChannelWriter> Channel)> tesTaskBatchStates = []; + + /// + protected override async ValueTask ExecuteSetupAsync(CancellationToken cancellationToken) + { + try + { + // Delay "starting" TaskScheduler until this completes to finish initializing BatchScheduler. + await BatchScheduler.UploadTaskRunnerIfNeededAsync(cancellationToken); + // Ensure BatchScheduler has loaded existing pools before "starting". + //await BatchScheduler.LoadExistingPoolsAsync(cancellationToken); + } + catch (Exception exc) + { + Logger.LogError(exc, @"Checking/storing the node task runner binary failed with {Message}", exc.Message); + throw; + } + + if (cancellationToken.IsCancellationRequested) + { + return; + } + + Logger.LogTrace(@"Querying active tasks"); + + foreach (var tesTask in + (await Repository.GetItemsAsync( + predicate: t => !TesTask.TerminalStates.Contains(t.State), + cancellationToken: cancellationToken)) + .OrderBy(t => t.CreationTime)) + { + try + { + if (TesState.QUEUED.Equals(tesTask.State) && string.IsNullOrWhiteSpace(tesTask.PoolId)) + { + Logger.LogTrace(@"Adding queued task from repository"); + queuedTesTasks.Enqueue(tesTask); + } + else + { + var pool = BatchScheduler.GetPools().SingleOrDefault(pool => tesTask.PoolId.Equals(pool.PoolId, StringComparison.OrdinalIgnoreCase)); + + if (pool is null) + { + Logger.LogDebug(@"Adding task w/o pool id from repository"); + queuedTesTasks.Enqueue(tesTask); // TODO: is there a better way to treat tasks that are not "queued" that are also not associated with any known pool? 
+ } + else + { + Logger.LogTrace(@"Adding task to pool w/o cloudtask"); + _ = pool.AssociatedTesTasks.AddOrUpdate(tesTask.Id, key => null, (key, value) => value); + } + } + } + catch (Exception ex) + { + await ProcessOrchestratedTesTaskAsync("Initialization", new(Task.FromException(ex), tesTask), ex => { Logger.LogCritical(ex, "Unexpected repository failure in initialization with {TesTask}", ex.RepositoryItem.Id); return ValueTask.CompletedTask; }, cancellationToken); + } + } + + Logger.LogTrace(@"Active tasks processed"); + } + + /// + protected override async ValueTask ExecuteCoreAsync(CancellationToken cancellationToken) + { + stoppingToken = cancellationToken; + List queuedTasks = []; + + while (!cancellationToken.IsCancellationRequested && tesTaskBatchStates.TryDequeue(out var result)) + { + queuedTasks.Add(ProcessQueuedTesTaskStatesRequestAsync(result.TesTasks, result.TaskStates, result.Channel, cancellationToken)); + } + + if (cancellationToken.IsCancellationRequested) + { + return; + } + + queuedTasks.Add(ExecuteShortBackgroundTasksAsync(cancellationToken)); + queuedTasks.Add(ExecuteLongBackgroundTasksAsync(cancellationToken)); + queuedTasks.Add(ExecuteQueuedTesTasksOnBatchAsync(cancellationToken)); + queuedTasks.Add(ExecuteCancelledTesTasksOnBatchAsync(cancellationToken)); + queuedTasks.Add(ExecuteUpdateTesTaskFromEventBlobAsync(cancellationToken)); + queuedTasks.Add(ExecuteProcessOrphanedTasksAsync(cancellationToken)); + + if (cancellationToken.IsCancellationRequested) + { + return; + } + + Logger.LogTrace(@"Task load: {TaskCount}", queuedTasks.Count); + await Task.WhenAll(queuedTasks); + } + + private async Task ProcessQueuedTesTaskStatesRequestAsync(TesTask[] tesTasks, AzureBatchTaskState[] taskStates, ChannelWriter> channel, CancellationToken cancellationToken) + { + try + { + await foreach (var relatedTask in ((ITaskScheduler)this).ProcessTesTaskBatchStatesAsync(tesTasks, taskStates, cancellationToken)) + { + await channel.WriteAsync(relatedTask, cancellationToken); + } + + channel.Complete(); + } + catch (Exception ex) + { + channel.Complete(ex); + } + } + + /// + /// Retrieves TesTasks queued via ProcessQueuedTesTaskAsync and schedules them for execution. + /// + /// Triggered when Microsoft.Extensions.Hosting.IHostedService.StopAsync(System.Threading.CancellationToken) is called. + /// + private Task ExecuteQueuedTesTasksOnBatchAsync(CancellationToken cancellationToken) + { + return ExecuteActionOnIntervalAsync(queuedRunInterval, ProcessQueuedTesTasksAsync, cancellationToken); + } + + /// + /// Schedules queued TesTasks via !BatchScheduler.ProcessQueuedTesTaskAsync. + /// + /// Triggered when Microsoft.Extensions.Hosting.IHostedService.StopAsync(System.Threading.CancellationToken) is called. + /// + private async ValueTask ProcessQueuedTesTasksAsync(CancellationToken cancellationToken) + { + while (!cancellationToken.IsCancellationRequested && queuedTesTasks.TryDequeue(out var tesTask)) + { + await ProcessOrchestratedTesTaskAsync("Queued", new(BatchScheduler.ProcessQueuedTesTaskAsync(tesTask, cancellationToken), tesTask), Requeue, cancellationToken); + } + + async ValueTask Requeue(RepositoryCollisionException exception) + { + TesTask tesTask = default; + + if (await Repository.TryGetItemAsync(exception.RepositoryItem.Id, cancellationToken, task => tesTask = task) && (tesTask?.IsActiveState() ?? 
false) && tesTask?.State != TesState.CANCELING) + { + queuedTesTasks.Enqueue(tesTask); + } + } + } + + /// + /// Retrieves all event blobs from storage and updates the resultant state. + /// + /// Triggered when Microsoft.Extensions.Hosting.IHostedService.StopAsync(System.Threading.CancellationToken) is called. + /// + private Task ExecuteShortBackgroundTasksAsync(CancellationToken cancellationToken) + { + return ExecuteActionOnIntervalAsync(shortBackgroundRunInterval, BatchScheduler.PerformShortBackgroundTasksAsync, cancellationToken); + } + + /// + /// Retrieves all event blobs from storage and updates the resultant state. + /// + /// Triggered when Microsoft.Extensions.Hosting.IHostedService.StopAsync(System.Threading.CancellationToken) is called. + /// + private async Task ExecuteLongBackgroundTasksAsync(CancellationToken cancellationToken) + { + await ExecuteActionOnIntervalAsync(longBackgroundRunInterval, + async token => await Task.WhenAll(BatchScheduler.PerformLongBackgroundTasksAsync(token).ToBlockingEnumerable(token)), + cancellationToken); + } + + /// + /// Retrieves all cancelled TES tasks from the database, performs an action in the batch system, and updates the resultant state + /// + /// Triggered when Microsoft.Extensions.Hosting.IHostedService.StopAsync(System.Threading.CancellationToken) is called. + /// + private Task ExecuteCancelledTesTasksOnBatchAsync(CancellationToken cancellationToken) + { + Func>> query = new( + async token => (await Repository.GetItemsAsync( + predicate: t => t.State == TesState.CANCELING, + cancellationToken: token)) + .OrderByDescending(t => t.CreationTime) + .ToAsyncEnumerable()); + + return ExecuteActionOnIntervalAsync(BatchRunInterval, + async token => + { + ConcurrentBag requeues = []; + List tasks = []; + + await foreach (var task in await query(cancellationToken)) + { + tasks.Add(task); + } + + do + { + requeues.Clear(); + await OrchestrateTesTasksOnBatchAsync( + "Cancelled", + _ => ValueTask.FromResult(tasks.ToAsyncEnumerable()), + (tasks, ct) => ((ITaskScheduler)this).ProcessTesTaskBatchStatesAsync( + tasks, + Enumerable.Repeat(new(AzureBatchTaskState.TaskState.CancellationRequested), tasks.Length).ToArray(), + ct), + ex => { requeues.Add(ex.RepositoryItem.Id); return ValueTask.CompletedTask; }, token); + + // Fetch updated TesTasks from the repository + ConcurrentBag requeuedTasks = []; + await Parallel.ForEachAsync(requeues, cancellationToken, async (id, token) => + { + TesTask tesTask = default; + + if (await Repository.TryGetItemAsync(id, token, task => tesTask = task)) + { + requeuedTasks.Add(tesTask); + } + }); + + // Stage next loop + tasks.Clear(); + requeuedTasks.ForEach(tasks.Add); + } + while (!requeues.IsEmpty); + }, + cancellationToken); + } + + /// + /// Retrieves all event blobs from storage and updates the resultant state. + /// + /// Triggered when Microsoft.Extensions.Hosting.IHostedService.StopAsync(System.Threading.CancellationToken) is called. + /// + private Task ExecuteUpdateTesTaskFromEventBlobAsync(CancellationToken cancellationToken) + { + return ExecuteActionOnIntervalAsync(blobRunInterval, + async token => + await UpdateTesTasksFromAvailableEventsAsync( + await ParseAvailableEvents(token), + token), + cancellationToken); + } + + /// + /// Determines the s from each event available for processing and their associated s. + /// + /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// s and s from all events. 
+ private async ValueTask MarkProcessedAsync)>> ParseAvailableEvents(CancellationToken cancellationToken) + { + var tasks = new ConcurrentDictionary>(StringComparer.OrdinalIgnoreCase); // TODO: Are tesTask.Ids case sensitive? + var messages = new ConcurrentBag<(RunnerEventsMessage Message, TesTask Task, AzureBatchTaskState State, Func MarkProcessedAsync)>(); + + // Get tasks for event blobs + await Parallel.ForEachAsync(BatchScheduler.GetEventMessagesAsync(cancellationToken), cancellationToken, async (eventMessage, token) => + { + TesTask tesTask = default; + + try + { + tesTask = await GetTesTaskAsync(eventMessage.Tags["task-id"], eventMessage.Tags["event-name"]); + + if (tesTask is null) + { + return; + } + + nodeEventProcessor.ValidateMessageMetadata(eventMessage); + tasks.AddOrUpdate(tesTask.Id, _ => [(eventMessage, tesTask)], (_, list) => list.Add((eventMessage, tesTask))); + } + catch (OperationCanceledException) when (token.IsCancellationRequested) + { + throw; + } + catch (ArgumentException ex) + { + Logger.LogError(ex, @"Verifying event metadata failed: {ErrorMessage}", ex.Message); + + messages.Add(( + eventMessage, + tesTask, + new(AzureBatchTaskState.TaskState.InfoUpdate, Warning: + [ + "EventParsingFailed", + $"{ex.GetType().FullName}: {ex.Message}" + ]), + ct => nodeEventProcessor.RemoveMessageFromReattemptsAsync(eventMessage, ct))); + } + + // Helpers + async ValueTask GetTesTaskAsync(string id, string @event) + { + TesTask tesTask = default; + if (await Repository.TryGetItemAsync(id, token, task => tesTask = task) && tesTask is not null) + { + Logger.LogTrace("Attempting to complete event '{TaskEvent}' for task {TesTask}.", @event, tesTask.Id); + return tesTask; + } + else + { + Logger.LogDebug("Could not find task {TesTask} for event '{TaskEvent}'.", id, @event); + return null; + } + } + }); + + // Parse event blobs, deferring later events for the same TesTask + await Parallel.ForEachAsync(tasks.Select(pair => nodeEventProcessor.OrderProcessedByExecutorSequence(pair.Value, m => m.Event).First()), cancellationToken, async (tuple, token) => + { + var (eventMessage, tesTask) = tuple; + + try + { + eventMessage = await nodeEventProcessor.DownloadAndValidateMessageContentAsync(eventMessage, token); + var state = await nodeEventProcessor.GetMessageBatchStateAsync(eventMessage, tesTask, token); + messages.Add((eventMessage, tesTask, state, ct => nodeEventProcessor.MarkMessageProcessedAsync(eventMessage, ct))); + } + catch (OperationCanceledException) when (token.IsCancellationRequested) + { + throw; + } + catch (Exception ex) + { + Logger.LogError(ex, @"Downloading and parsing event failed: {ErrorMessage}", ex.Message); + + messages.Add(( + eventMessage, + tesTask, + new(AzureBatchTaskState.TaskState.InfoUpdate, Warning: + [ + "EventParsingFailed", + $"{ex.GetType().FullName}: {ex.Message}" + ]), + (ex is System.Diagnostics.UnreachableException || ex is RunnerEventsProcessor.DownloadOrParseException || ex is ArgumentException) + ? ct => nodeEventProcessor.MarkMessageProcessedAsync(eventMessage, ct) // Mark event processed to prevent retries + : default)); // Retry this event. + } + }); + + return nodeEventProcessor.OrderProcessedByExecutorSequence(messages, @event => @event.Message).Select(@event => (@event.Task, @event.State, @event.MarkProcessedAsync)); + } + + /// + /// Updates each task based on the provided state. + /// + /// A collection of associated s, s, and a method to mark the source event processed. 
+ /// A System.Threading.CancellationToken for controlling the lifetime of the asynchronous operation. + /// + private async ValueTask UpdateTesTasksFromAvailableEventsAsync(IEnumerable<(TesTask Task, AzureBatchTaskState State, Func MarkProcessedAsync)> eventStates, CancellationToken cancellationToken) + { + eventStates = eventStates.ToList(); + + if (!eventStates.Any()) + { + return; + } + + ConcurrentBag requeues = []; + ConcurrentDictionary MarkProcessedAsync)>> statesByTask = new(StringComparer.Ordinal); + HashSet tasks = []; + + eventStates.ForEach(t => + { + _ = tasks.Add(t.Task); + _ = statesByTask.AddOrUpdate(t.Task.Id, _ => [(t.State, t.MarkProcessedAsync)], (_, array) => array.Add((t.State, t.MarkProcessedAsync))); + }); + + do + { + // Update TesTasks one event each per loop + requeues.Clear(); + await OrchestrateTesTasksOnBatchAsync( + "NodeEvent", + _ => ValueTask.FromResult(tasks.ToAsyncEnumerable()), + (tesTasks, token) => ((ITaskScheduler)this).ProcessTesTaskBatchStatesAsync(tesTasks, tesTasks.Select(task => statesByTask[task.Id][0].State).ToArray(), token), + ex => { requeues.Add(ex.RepositoryItem.Id); return ValueTask.CompletedTask; }, + cancellationToken, + "events"); + + // Get next state for each task (if any) for next loop + _ = Parallel.ForEach(tasks, task => + { + // Don't remove current state if there was a repository conflict + if (!requeues.Contains(task.Id)) + { + var states = statesByTask[task.Id].RemoveAt(0); + + if (!states.IsEmpty) + { + statesByTask[task.Id] = states; + requeues.Add(task.Id); + } + } + }); + + // Fetch updated TesTasks from the repository + ConcurrentBag requeuedTasks = []; + await Parallel.ForEachAsync(requeues, cancellationToken, async (id, token) => + { + TesTask tesTask = default; + + if (await Repository.TryGetItemAsync(id, token, task => tesTask = task)) + { + requeuedTasks.Add(tesTask); + } + }); + + // Stage next loop + tasks.Clear(); + requeuedTasks.ForEach(task => _ = tasks.Add(task)); + } + while (!requeues.IsEmpty); + + await Parallel.ForEachAsync(eventStates.Select(@event => @event.MarkProcessedAsync).Where(func => func is not null), cancellationToken, async (markEventProcessed, token) => + { + try + { + await markEventProcessed(token); + } + catch (OperationCanceledException) when (token.IsCancellationRequested) + { + throw; + } + catch (Exception ex) + { + Logger.LogError(ex, @"Failed to tag event as processed."); + } + }); + } + + /// + void ITaskScheduler.QueueTesTask(TesTask tesTask) + { + queuedTesTasks.Enqueue(tesTask); + } + + private async Task ExecuteProcessOrphanedTasksAsync(CancellationToken cancellationToken) + { + List statesToSkip = [TesState.QUEUED, TesState.CANCELING]; + statesToSkip.AddRange(TesTask.TerminalStates); + + await ExecuteActionOnIntervalAsync(orphanedTaskInterval, + async token => + { + var pools = BatchScheduler.GetPools().Select(p => p.PoolId).ToArray(); + var now = DateTimeOffset.UtcNow; + + await OrchestrateTesTasksOnBatchAsync( + $"OrphanedTasks", + async cancellation => (await Repository.GetItemsAsync(task => !statesToSkip.Contains(task.State), cancellation)) + .Where(task => !pools.Contains(task.PoolId, StringComparer.OrdinalIgnoreCase)) + .ToAsyncEnumerable(), + (tesTasks, cancellation) => ((ITaskScheduler)this).ProcessTesTaskBatchStatesAsync(tesTasks, tesTasks.Select(_ => new AzureBatchTaskState(AzureBatchTaskState.TaskState.CompletedWithErrors, BatchTaskEndTime: now, Failure: new(AzureBatchTaskState.SystemError, ["RemovedPoolOrJob", "Batch pool or job was removed."]))).ToArray(), 
cancellation), + ex => { Logger.LogError(ex, "Repository collision while failing task ('{TesTask}') due to pool or job removal.", ex.RepositoryItem?.Id ?? ""); return ValueTask.CompletedTask; }, + token); + }, + cancellationToken); + } + + /// + IAsyncEnumerable> ITaskScheduler.ProcessTesTaskBatchStatesAsync(IEnumerable tesTasks, AzureBatchTaskState[] taskStates, CancellationToken cancellationToken) + { + ArgumentNullException.ThrowIfNull(tesTasks); + ArgumentNullException.ThrowIfNull(taskStates); + + if (IsRunning) + { + return taskStates.Zip(tesTasks, (TaskState, TesTask) => (TaskState, TesTask)) + .Select(entry => new RelatedTask(entry.TesTask?.IsActiveState() ?? false // Removes already terminal (and null) TesTasks from being further processed. + ? WrapHandleTesTaskTransitionAsync(entry.TesTask, entry.TaskState, cancellationToken) + : Task.FromResult(false), entry.TesTask)) + .WhenEach(cancellationToken, tesTaskTask => tesTaskTask.Task); + + async Task WrapHandleTesTaskTransitionAsync(TesTask tesTask, AzureBatchTaskState azureBatchTaskState, CancellationToken cancellationToken) + => await BatchScheduler.ProcessTesTaskBatchStateAsync(tesTask, azureBatchTaskState, cancellationToken); + } + else + { + var channel = Channel.CreateBounded>(new BoundedChannelOptions(taskStates.Length) { SingleReader = true, SingleWriter = true }); + tesTaskBatchStates.Enqueue((tesTasks.ToArray(), taskStates, channel.Writer)); + return channel.Reader.ReadAllAsync(cancellationToken); + } + } + } +} diff --git a/src/TesApi.Web/TerraActionIdentityProvider.cs b/src/TesApi.Web/TerraActionIdentityProvider.cs index 3778edbf9..06535c80c 100644 --- a/src/TesApi.Web/TerraActionIdentityProvider.cs +++ b/src/TesApi.Web/TerraActionIdentityProvider.cs @@ -51,12 +51,12 @@ public async Task GetAcrPullActionIdentity(CancellationToken cancellatio if (response is null) { // Corresponds to no identity existing in Sam, or the user not having access to it. 
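One detail worth calling out from the TaskScheduler.ProcessTesTaskBatchStatesAsync implementation above: when the hosted service has not started yet, the request is parked on an internal queue together with a bounded channel writer, and the caller immediately receives the channel's reader as its async stream; results flow back once ExecuteCoreAsync drains the queue. A small self-contained sketch of that handoff pattern (generic over an arbitrary result type rather than the repository's RelatedTask type):

using System;
using System.Collections.Generic;
using System.Threading;
using System.Threading.Channels;
using System.Threading.Tasks;

internal static class ChannelHandoffSketch
{
    // Caller side: gets an async stream immediately, even though production starts later.
    internal static IAsyncEnumerable<TResult> Defer<TResult>(
        int capacity,
        out ChannelWriter<TResult> writer,
        CancellationToken cancellationToken)
    {
        var channel = Channel.CreateBounded<TResult>(
            new BoundedChannelOptions(capacity) { SingleReader = true, SingleWriter = true });
        writer = channel.Writer;
        return channel.Reader.ReadAllAsync(cancellationToken);
    }

    // Producer side: runs later (e.g. from the background loop) and completes the channel,
    // propagating any failure to the awaiting consumer.
    internal static async Task ProduceAsync<TResult>(
        ChannelWriter<TResult> writer,
        IAsyncEnumerable<TResult> results,
        CancellationToken cancellationToken)
    {
        try
        {
            await foreach (var result in results.WithCancellation(cancellationToken))
            {
                await writer.WriteAsync(result, cancellationToken);
            }

            writer.Complete();
        }
        catch (Exception ex)
        {
            writer.Complete(ex);
        }
    }
}

Completing the channel with the exception, rather than letting it escape the producer, is what surfaces failures to the original caller, mirroring ProcessQueuedTesTaskStatesRequestAsync above.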
- Logger.LogInformation(@"Found no ACR Pull action identity in Sam for {id}", samResourceIdForAcrPull); + Logger.LogDebug(@"Found no ACR Pull action identity in Sam for {id}", samResourceIdForAcrPull); return null; } else { - Logger.LogInformation(@"Successfully fetched ACR action identity from Sam: {ObjectId}", response.ObjectId); + Logger.LogDebug(@"Successfully fetched ACR action identity from Sam: {ObjectId}", response.ObjectId); return response.ObjectId; } } diff --git a/src/TesApi.Web/TesApi.Web.csproj b/src/TesApi.Web/TesApi.Web.csproj index 66492e950..defdc0f89 100644 --- a/src/TesApi.Web/TesApi.Web.csproj +++ b/src/TesApi.Web/TesApi.Web.csproj @@ -5,6 +5,7 @@ net8.0 true true + true tesapi false GA4GH Task Execution Service @@ -80,10 +81,10 @@ - + - + diff --git a/src/TesApi.Web/appsettings.json b/src/TesApi.Web/appsettings.json index 294a017ad..2f1e87b30 100644 --- a/src/TesApi.Web/appsettings.json +++ b/src/TesApi.Web/appsettings.json @@ -5,7 +5,14 @@ }, "Logging": { "LogLevel": { - "Default": "Warning" + "Azure": "Warning", + "Microsoft": "Warning", + "TesApi.Web.AzureProxy": "Trace", + "TesApi.Web.BatchPool": "Trace", + "TesApi.Web.BatchScheduler": "Trace", + "TesApi.Web.PoolScheduler": "Trace", + "TesApi.Web.TaskScheduler": "Trace", + "Default": "Information" } }, "AllowedHosts": "*", diff --git a/src/deploy-tes-on-azure.Tests/deploy-tes-on-azure.Tests.csproj b/src/deploy-tes-on-azure.Tests/deploy-tes-on-azure.Tests.csproj index 125cc27b4..bb157d31d 100644 --- a/src/deploy-tes-on-azure.Tests/deploy-tes-on-azure.Tests.csproj +++ b/src/deploy-tes-on-azure.Tests/deploy-tes-on-azure.Tests.csproj @@ -1,4 +1,4 @@ - + net8.0 diff --git a/src/deploy-tes-on-azure/Deployer.cs b/src/deploy-tes-on-azure/Deployer.cs index cc608b35f..c1fab46d8 100644 --- a/src/deploy-tes-on-azure/Deployer.cs +++ b/src/deploy-tes-on-azure/Deployer.cs @@ -47,15 +47,13 @@ using CommonUtilities; using CommonUtilities.AzureCloud; using k8s; -using Microsoft.EntityFrameworkCore; using Microsoft.Graph; using Newtonsoft.Json; using Polly; -using Polly.Retry; using Polly.Utilities; -using Tes.Extensions; using Tes.Models; using Tes.SDK; +using static CommonUtilities.RetryHandler; using Batch = Azure.ResourceManager.Batch.Models; using Storage = Azure.ResourceManager.Storage.Models; @@ -63,20 +61,26 @@ namespace TesDeployer { public class Deployer(Configuration configuration) { - private static readonly AsyncRetryPolicy roleAssignmentHashConflictRetryPolicy = Policy - .Handle(requestFailedException => - "HashConflictOnDifferentRoleAssignmentIds".Equals(requestFailedException.ErrorCode, StringComparison.OrdinalIgnoreCase)) - .RetryAsync(); - - private static readonly AsyncRetryPolicy operationNotAllowedConflictRetryPolicy = Policy - .Handle(azureException => + private static readonly AsyncRetryHandlerPolicy roleAssignmentHashConflictRetryPolicy = new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new CommonUtilities.Options.RetryPolicyOptions())) + .PolicyBuilder.OpinionatedRetryPolicy(Polly.Policy.Handle(requestFailedException => + "HashConflictOnDifferentRoleAssignmentIds".Equals(requestFailedException.ErrorCode, StringComparison.OrdinalIgnoreCase))) + .WithCustomizedRetryPolicyOptionsWait(int.MaxValue, (_, _) => TimeSpan.Zero) + .SetOnRetryBehavior() + .AsyncBuild(); + + private static readonly AsyncRetryHandlerPolicy operationNotAllowedConflictRetryPolicy = new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new CommonUtilities.Options.RetryPolicyOptions())) + 
.PolicyBuilder.OpinionatedRetryPolicy(Polly.Policy.Handle(azureException => (int)HttpStatusCode.Conflict == azureException.Status && - "OperationNotAllowed".Equals(azureException.ErrorCode, StringComparison.OrdinalIgnoreCase)) - .WaitAndRetryAsync(30, retryAttempt => TimeSpan.FromSeconds(10)); + "OperationNotAllowed".Equals(azureException.ErrorCode, StringComparison.OrdinalIgnoreCase))) + .WithCustomizedRetryPolicyOptionsWait(30, (_, _) => TimeSpan.FromSeconds(10)) + .SetOnRetryBehavior() + .AsyncBuild(); - private static readonly AsyncRetryPolicy buildPushAcrRetryPolicy = Policy - .Handle(AsyncRetryExceptionPolicy) - .WaitAndRetryAsync(3, retryAttempt => TimeSpan.FromSeconds(1)); + private static readonly AsyncRetryHandlerPolicy buildPushAcrRetryPolicy = new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new CommonUtilities.Options.RetryPolicyOptions())) + .PolicyBuilder.OpinionatedRetryPolicy(Polly.Policy.Handle(AsyncRetryExceptionPolicy)) + .WithCustomizedRetryPolicyOptionsWait(3, (_, _) => TimeSpan.FromSeconds(1)) + .SetOnRetryBehavior() + .AsyncBuild(); private static bool AsyncRetryExceptionPolicy(Exception ex) { @@ -96,19 +100,25 @@ private static bool AsyncRetryExceptionPolicy(Exception ex) return !dontRetry; } - private static readonly AsyncRetryPolicy acrGetDigestRetryPolicy = Policy - .Handle(azureException => (int)HttpStatusCode.NotFound == azureException.Status) - .WaitAndRetryAsync(30, retryAttempt => TimeSpan.FromSeconds(10)); + private static readonly AsyncRetryHandlerPolicy acrGetDigestRetryPolicy = new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new CommonUtilities.Options.RetryPolicyOptions())) + .PolicyBuilder.OpinionatedRetryPolicy(Polly.Policy.Handle(azureException => (int)HttpStatusCode.NotFound == azureException.Status)) + .WithCustomizedRetryPolicyOptionsWait(30, (_, _) => TimeSpan.FromSeconds(10)) + .SetOnRetryBehavior() + .AsyncBuild(); - private static readonly AsyncRetryPolicy generalRetryPolicy = Policy - .Handle() - .WaitAndRetryAsync(3, retryAttempt => TimeSpan.FromSeconds(1)); + private static readonly AsyncRetryHandlerPolicy generalRetryPolicy = new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new CommonUtilities.Options.RetryPolicyOptions())) + .PolicyBuilder.OpinionatedRetryPolicy() + .WithCustomizedRetryPolicyOptionsWait(3, (_, _) => TimeSpan.FromSeconds(1)) + .SetOnRetryBehavior() + .AsyncBuild(); - private static readonly AsyncRetryPolicy internalServerErrorRetryPolicy = Policy - .Handle(azureException => + private static readonly AsyncRetryHandlerPolicy internalServerErrorRetryPolicy = new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new CommonUtilities.Options.RetryPolicyOptions())) + .PolicyBuilder.OpinionatedRetryPolicy(Polly.Policy.Handle(azureException => (int)HttpStatusCode.OK == azureException.Status && - "InternalServerError".Equals(azureException.ErrorCode, StringComparison.OrdinalIgnoreCase)) - .WaitAndRetryAsync(3, retryAttempt => longRetryWaitTime); + "InternalServerError".Equals(azureException.ErrorCode, StringComparison.OrdinalIgnoreCase))) + .WithCustomizedRetryPolicyOptionsWait(3, (_, _) => longRetryWaitTime) + .SetOnRetryBehavior() + .AsyncBuild(); private static readonly TimeSpan longRetryWaitTime = TimeSpan.FromSeconds(15); @@ -252,6 +262,7 @@ await Execute("Connecting to Azure Services...", async () => { storageAccount = await GetExistingStorageAccountAsync(configuration.StorageAccountName) ?? 
throw new ValidationException($"Storage account {configuration.StorageAccountName} does not exist in region {configuration.RegionName} or is not accessible to the current user.", displayExample: false); + } storageAccountData = (await FetchResourceDataAsync(ct => storageAccount.GetAsync(cancellationToken: ct), cts.Token, account => storageAccount = account)).Data; @@ -1283,7 +1294,7 @@ private async Task EnableWorkloadIdentit var aksClusterCollection = resourceGroup.GetContainerServiceManagedClusters(); var cluster = await Execute("Updating AKS cluster...", - async () => await operationNotAllowedConflictRetryPolicy.ExecuteAsync(token => aksClusterCollection.CreateOrUpdateAsync(WaitUntil.Completed, aksCluster.Data.Name, aksCluster.Data, token), cts.Token)); + async () => await operationNotAllowedConflictRetryPolicy.ExecuteWithRetryAsync(token => aksClusterCollection.CreateOrUpdateAsync(WaitUntil.Completed, aksCluster.Data.Name, aksCluster.Data, token), cts.Token)); var aksOidcIssuer = cluster.Value.Data.OidcIssuerProfile.IssuerUriInfo; @@ -1299,7 +1310,7 @@ private async Task EnableWorkloadIdentit data.Audiences.Add("api://AzureADTokenExchange"); await Execute("Enabling workload identity...", - async () => _ = await operationNotAllowedConflictRetryPolicy.ExecuteAsync(token => federatedCredentialsCollection.CreateOrUpdateAsync(WaitUntil.Completed, "toaFederatedIdentity", data, token), cts.Token)); + async () => _ = await operationNotAllowedConflictRetryPolicy.ExecuteWithRetryAsync(token => federatedCredentialsCollection.CreateOrUpdateAsync(WaitUntil.Completed, "toaFederatedIdentity", data, token), cts.Token)); } return cluster.Value; @@ -1358,7 +1369,7 @@ private async Task BuildPushAcrAsync(Dictionary settings, string } var build = await Execute($"Building TES image on {acr.Id.Name}...", - () => buildPushAcrRetryPolicy.ExecuteAsync(async () => + () => buildPushAcrRetryPolicy.ExecuteWithRetryAsync(async () => { AcrBuild build = default; await Policy.Handle(ae => (int)HttpStatusCode.Unauthorized == ae.ResponseStatusCode) @@ -1422,7 +1433,7 @@ private async Task BuildPushAcrAsync(Dictionary settings, string return build; })); - var tesDigest = (await acrGetDigestRetryPolicy.ExecuteAsync(token => (client ??= GetClient()).GetArtifact("cromwellonazure/tes", build.Tag.ToString()).GetManifestPropertiesAsync(token), cts.Token)).Value.Digest; + var tesDigest = (await acrGetDigestRetryPolicy.ExecuteWithRetryAsync(token => (client ??= GetClient()).GetArtifact("cromwellonazure/tes", build.Tag.ToString()).GetManifestPropertiesAsync(token), cts.Token)).Value.Digest; settings["ActualTesImageName"] = $"{acr.Data.LoginServer}/cromwellonazure/tes@{tesDigest}"; Azure.Containers.ContainerRegistry.ContainerRegistryClient GetClient() @@ -1954,7 +1965,7 @@ private async Task CreatePostgreSqlServerAndDa var server = await Execute( $"Creating Azure Flexible Server for PostgreSQL: {configuration.PostgreSqlServerName}...", - async () => (await internalServerErrorRetryPolicy.ExecuteAsync(token => resourceGroup.GetPostgreSqlFlexibleServers().CreateOrUpdateAsync(WaitUntil.Completed, configuration.PostgreSqlServerName, data, token), cts.Token)).Value); + async () => (await internalServerErrorRetryPolicy.ExecuteWithRetryAsync(token => resourceGroup.GetPostgreSqlFlexibleServers().CreateOrUpdateAsync(WaitUntil.Completed, configuration.PostgreSqlServerName, data, token), cts.Token)).Value); await Execute( $"Creating PostgreSQL tes database: {configuration.PostgreSqlTesDatabaseName}...", @@ -2021,7 +2032,7 @@ private async 
Task AssignRoleToResourceAsync(IEnumerable principalId { try { - await roleAssignmentHashConflictRetryPolicy.ExecuteAsync(token => + await roleAssignmentHashConflictRetryPolicy.ExecuteWithRetryAsync(token => (Task)resource.GetRoleAssignments().CreateOrUpdateAsync(WaitUntil.Completed, Guid.NewGuid().ToString(), new(roleDefinitionId, principal) { @@ -2638,7 +2649,7 @@ private static void AddServiceEndpointsToSubnet(SubnetData subnet) private async Task ValidateVmAsync() { - var computeSkus = await generalRetryPolicy.ExecuteAsync(async ct => + var computeSkus = await generalRetryPolicy.ExecuteWithRetryAsync(async ct => await armSubscription.GetComputeResourceSkusAsync( filter: $"location eq '{configuration.RegionName}'", cancellationToken: ct) diff --git a/src/deploy-tes-on-azure/KubernetesManager.cs b/src/deploy-tes-on-azure/KubernetesManager.cs index 2d12c59ad..f6f837bbd 100644 --- a/src/deploy-tes-on-azure/KubernetesManager.cs +++ b/src/deploy-tes-on-azure/KubernetesManager.cs @@ -14,11 +14,11 @@ using Azure.ResourceManager.ContainerService; using Azure.ResourceManager.ManagedServiceIdentities; using Azure.Storage.Blobs; +using CommonUtilities; using CommonUtilities.AzureCloud; using k8s; using k8s.Models; -using Polly; -using Polly.Retry; +using static CommonUtilities.RetryHandler; namespace TesDeployer { @@ -27,13 +27,17 @@ namespace TesDeployer /// public class KubernetesManager { - private static readonly AsyncRetryPolicy WorkloadReadyRetryPolicy = Policy - .Handle() - .WaitAndRetryAsync(80, retryAttempt => TimeSpan.FromSeconds(15)); - - private static readonly AsyncRetryPolicy KubeExecRetryPolicy = Policy - .Handle(ex => ex.WebSocketErrorCode == WebSocketError.NotAWebSocket) - .WaitAndRetryAsync(200, retryAttempt => TimeSpan.FromSeconds(5)); + private static readonly AsyncRetryHandlerPolicy WorkloadReadyRetryPolicy = new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new CommonUtilities.Options.RetryPolicyOptions())) + .PolicyBuilder.OpinionatedRetryPolicy() + .WithCustomizedRetryPolicyOptionsWait(80, (_, _) => TimeSpan.FromSeconds(15)) + .SetOnRetryBehavior() + .AsyncBuild(); + + private static readonly AsyncRetryHandlerPolicy KubeExecRetryPolicy = new RetryPolicyBuilder(Microsoft.Extensions.Options.Options.Create(new CommonUtilities.Options.RetryPolicyOptions())) + .PolicyBuilder.OpinionatedRetryPolicy(Polly.Policy.Handle(ex => ex.WebSocketErrorCode == WebSocketError.NotAWebSocket)) + .WithCustomizedRetryPolicyOptionsWait(200, (_, _) => TimeSpan.FromSeconds(5)) + .SetOnRetryBehavior() + .AsyncBuild(); private const string NginxIngressRepo = "https://kubernetes.github.io/ingress-nginx"; private const string NginxIngressVersion = "4.7.1"; @@ -396,7 +400,7 @@ async Task StreamHandler(Stream stream) } }, cancellationToken); - if (result.Outcome != OutcomeType.Successful && result.FinalException is not null) + if (result.Outcome != Polly.OutcomeType.Successful && result.FinalException is not null) { throw result.FinalException; } @@ -675,7 +679,7 @@ private static async Task WaitForWorkloadAsync(IKubernetes client, string } }, cancellationToken); - return result.Outcome == OutcomeType.Successful; + return result.Outcome == Polly.OutcomeType.Successful; } public class HelmValues diff --git a/src/deploy-tes-on-azure/Properties/launchSettings.json b/src/deploy-tes-on-azure/Properties/launchSettings.json index 48d66dde6..f5d6f1e4b 100644 --- a/src/deploy-tes-on-azure/Properties/launchSettings.json +++ b/src/deploy-tes-on-azure/Properties/launchSettings.json @@ -6,4 +6,4 @@ 
"sqlDebugging": true } } -} \ No newline at end of file +}