[MySQL] az mysql flexible-server import create: Add estimated completion time to operation progress when importing a physical backup from Azure Blob storage to a flexible server #28243

Merged 4 commits on Jan 25, 2024 (the diff below shows changes from 1 commit).
129 changes: 125 additions & 4 deletions src/azure-cli/azure/cli/command_modules/mysql/_util.py
@@ -4,6 +4,9 @@
# --------------------------------------------------------------------------------------------

# pylint: disable=unused-argument, line-too-long, import-outside-toplevel, raise-missing-from
from enum import Enum
import json
import math
import os
import random
import subprocess
@@ -19,17 +22,19 @@
from msrestazure.tools import parse_resource_id
from msrestazure.azure_exceptions import CloudError
from azure.cli.core.commands.client_factory import get_subscription_id
from azure.cli.core.commands.progress import IndeterminateProgressBar
from azure.cli.core.util import CLIError
from azure.core.exceptions import HttpResponseError
from azure.core.paging import ItemPaged
from azure.core.rest import HttpRequest
from azure.cli.core.commands import LongRunningOperation, AzArgumentContext, _is_poller
from azure.cli.core.azclierror import RequiredArgumentMissingError, InvalidArgumentValueError, AuthenticationError
from azure.cli.command_modules.role.custom import create_service_principal_for_rbac
from azure.mgmt.rdbms import mysql_flexibleservers, postgresql_flexibleservers
from azure.mgmt.resource.resources.models import ResourceGroup
from ._client_factory import resource_client_factory, cf_mysql_flexible_location_capabilities
from ._client_factory import resource_client_factory, cf_mysql_flexible_location_capabilities, get_mysql_flexible_management_client
from azure.cli.core.commands.validators import get_default_location_from_resource_group, validate_tags

from urllib.parse import urlencode, urlparse, parse_qsl

logger = get_logger(__name__)

@@ -105,9 +110,9 @@ def call(*args, **kwargs):
    return decorate


def resolve_poller(result, cli_ctx, name):
def resolve_poller(result, cli_ctx, name, progress_bar=None):
    if _is_poller(result):
        return LongRunningOperation(cli_ctx, 'Starting {}'.format(name))(result)
        return LongRunningOperation(cli_ctx, 'Starting {}'.format(name), progress_bar=progress_bar)(result)
    return result


@@ -570,3 +575,119 @@ def get_single_to_flex_sku_mapping(source_single_server_sku, tier, sku_name):

def get_firewall_rules_from_paged_response(firewall_rules):
    return list(firewall_rules) if isinstance(firewall_rules, ItemPaged) else firewall_rules


def get_current_utc_time():
    return datetime.utcnow().replace(tzinfo=dt.timezone.utc)


class ImportFromStorageState(Enum):
    STARTING = "Starting"
    PROVISIONING = "Provisioning Server"
    IMPORTING = "Importing"
    DEFAULT = "Running"


class ImportFromStorageProgressHook:

    def __init__(self):
        self._import_started = False
        self._import_state = ImportFromStorageState.STARTING
        self._import_estimated_completion_time = None

    def update_progress(self, operation_progress_response):
        if operation_progress_response is not None:
            try:
                jsonresp = json.loads(operation_progress_response.text())
                self._update_import_from_storage_progress_status(jsonresp)
            except Exception:
                # Progress reporting is best-effort; ignore malformed or missing payloads.
                pass

    def get_progress_message(self):
        msg = self._import_state.value
        if self._import_estimated_completion_time is not None:
            msg = msg + " " + self._get_eta_time_duration_in_user_readable_string()
        elif self._import_state == ImportFromStorageState.IMPORTING:
            msg = msg + " " + "Preparing (This might take a few minutes)"

        return msg

    def _get_eta_time_duration_in_user_readable_string(self):
        time_remaining = datetime.fromisoformat(self._import_estimated_completion_time) - get_current_utc_time()
        msg = " ETA : "

        if time_remaining.total_seconds() < 60:
            return msg + "Few seconds remaining"

        days = time_remaining.days
        hours, remainder = divmod(time_remaining.seconds, 3600)
        minutes = math.ceil(remainder/60.0)

        if days > 0:
            msg = msg + str(days) + " days "
        if hours > 0:
            msg = msg + str(hours) + " hours "
        if minutes > 0:
            msg = msg + str(minutes) + " minutes "

        return msg + " remaining"

    def _update_import_from_storage_progress_status(self, progress_resp_json):
        if "status" in progress_resp_json:
            progress_status = progress_resp_json["status"]
            previous_import_state = self._import_state

            # Updating the import state
            if progress_status == "Importing":
                self._import_started = True
                self._import_state = ImportFromStorageState.IMPORTING
            elif progress_status == "InProgress" and not self._import_started:
                self._import_state = ImportFromStorageState.PROVISIONING
            else:
                self._import_state = ImportFromStorageState.DEFAULT

            # Updating the estimated completion time
            is_state_same = self._import_state == previous_import_state
            if not is_state_same:
                self._import_estimated_completion_time = None
            if "properties" in progress_resp_json and "estimatedCompletionTime" in progress_resp_json["properties"]:
                self._import_estimated_completion_time = str(progress_resp_json["properties"]["estimatedCompletionTime"])


class OperationProgressBar(IndeterminateProgressBar):
    """ Define progress bar update view for operation progress """

    def __init__(self, cli_ctx, poller, operation_progress_hook, progress_message_update_interval_in_sec=60.0):
        self._poller = poller
        self._operation_progress_hook = operation_progress_hook
        self._operation_progress_request = self._get_operation_progress_request()
        self._client = get_mysql_flexible_management_client(cli_ctx)
        self._progress_message_update_interval_in_sec = progress_message_update_interval_in_sec
        self._progress_message_last_updated = None
        super().__init__(cli_ctx)

    def update_progress(self):
        self._safe_update_progress_message()
        super().update_progress()

    def _safe_update_progress_message(self):
        try:
            if self._should_update_progress_message():
                operation_progress_resp = self._client._send_request(self._operation_progress_request)
                self._operation_progress_hook.update_progress(operation_progress_resp)
                self.message = self._operation_progress_hook.get_progress_message()
                self._progress_message_last_updated = get_current_utc_time()
        except Exception:
            # Best-effort: a failed progress lookup must never interrupt the long-running import.
            pass

    def _should_update_progress_message(self):
        return (self._progress_message_last_updated is None) or ((get_current_utc_time() - self._progress_message_last_updated).total_seconds() > self._progress_message_update_interval_in_sec)

    def _get_operation_progress_request(self):
        location_url = self._poller._polling_method._initial_response.http_response.headers["Location"]
        operation_progress_url = location_url.replace('operationResults', 'operationProgress')
        operation_progress_url_parsed = urlparse(operation_progress_url)
        query_params = dict(parse_qsl(operation_progress_url_parsed.query))
        query_params['api-version'] = "2023-12-01-preview"
        updated_operation_progress_url = operation_progress_url_parsed._replace(query=urlencode(query_params)).geturl()
        return HttpRequest('GET', updated_operation_progress_url)
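
To illustrate how the new hook turns an operationProgress payload into the message shown on the progress bar, here is a minimal sketch (not part of the change). The JSON shape is inferred from _update_import_from_storage_progress_status above; the real operationProgress response may carry additional fields, and the snippet assumes an azure-cli installation that includes this change so ImportFromStorageProgressHook is importable.

import json
from datetime import datetime, timedelta, timezone

from azure.cli.command_modules.mysql._util import ImportFromStorageProgressHook


class FakeProgressResponse:
    """Stands in for the HTTP response object; only .text() is used by the hook."""

    def text(self):
        # Shape assumed from the parsing code above; real payloads may differ.
        payload = {
            "status": "Importing",
            "properties": {
                "estimatedCompletionTime":
                    (datetime.now(timezone.utc) + timedelta(hours=1, minutes=30)).isoformat()
            }
        }
        return json.dumps(payload)


hook = ImportFromStorageProgressHook()
hook.update_progress(FakeProgressResponse())
print(hook.get_progress_message())
# Prints something like: Importing  ETA : 1 hours 30 minutes  remaining
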
21 changes: 16 additions & 5 deletions src/azure-cli/azure/cli/command_modules/mysql/custom.py
@@ -15,6 +15,7 @@
from msrestazure.azure_exceptions import CloudError
from msrestazure.tools import resource_id, is_valid_resource_id, parse_resource_id
from azure.core.exceptions import ResourceNotFoundError
from azure.core.rest import HttpRequest
from azure.cli.core.commands.client_factory import get_subscription_id
from azure.cli.command_modules.mysql.random.generate import generate_username
from azure.cli.core.util import CLIError, sdk_no_wait, user_confirmation
@@ -27,7 +28,8 @@
cf_mysql_firewall_rules
from ._util import resolve_poller, generate_missing_parameters, get_mysql_list_skus_info, generate_password, parse_maintenance_window, \
replace_memory_optimized_tier, build_identity_and_data_encryption, get_identity_and_data_encryption, get_tenant_id, run_subprocess, \
run_subprocess_get_output, fill_action_template, get_git_root_dir, get_single_to_flex_sku_mapping, get_firewall_rules_from_paged_response, GITHUB_ACTION_PATH
run_subprocess_get_output, fill_action_template, get_git_root_dir, get_single_to_flex_sku_mapping, get_firewall_rules_from_paged_response, \
ImportFromStorageProgressHook, OperationProgressBar, GITHUB_ACTION_PATH
from ._network import prepare_mysql_exist_private_dns_zone, prepare_mysql_exist_private_network, prepare_private_network, prepare_private_dns_zone, prepare_public_network
from ._validators import mysql_arguments_validator, mysql_auto_grow_validator, mysql_georedundant_backup_validator, mysql_restore_tier_validator, \
mysql_retention_validator, mysql_sku_name_validator, mysql_storage_validator, validate_mysql_replica, validate_server_name, \
@@ -626,7 +628,8 @@ def flexible_server_import_create(cmd, client,
availability_zone=zone,
data_encryption=data_encryption,
source_server_id=source_server_id,
import_source_properties=import_source_properties)
import_source_properties=import_source_properties,
data_source_type=data_source_type)

# Adding firewall rule
if start_ip != -1 and end_ip != -1:
@@ -1375,7 +1378,7 @@ def _create_server(db_context, cmd, resource_group_name, server_name, tags, loca


def _import_create_server(db_context, cmd, resource_group_name, server_name, create_mode, source_server_id, tags, location, sku, administrator_login, administrator_login_password,
storage, backup, network, version, high_availability, availability_zone, identity, data_encryption, import_source_properties):
storage, backup, network, version, high_availability, availability_zone, identity, data_encryption, import_source_properties, data_source_type):
logging_name, server_client = db_context.logging_name, db_context.server_client
logger.warning('Creating %s Server \'%s\' in group \'%s\'...', logging_name, server_name, resource_group_name)

@@ -1400,10 +1403,18 @@ def _import_create_server(db_context, cmd, resource_group_name, server_name, cre
source_server_resource_id=source_server_id,
create_mode=create_mode,
import_source_properties=import_source_properties)

import_poller = server_client.begin_create(resource_group_name, server_name, parameters)

import_progress_bar = None

if data_source_type.lower() == "azure_blob":
import_progress_bar = OperationProgressBar(cmd.cli_ctx, import_poller, ImportFromStorageProgressHook())

return resolve_poller(
server_client.begin_create(resource_group_name, server_name, parameters), cmd.cli_ctx,
'{} Server Import Create'.format(logging_name))
import_poller, cmd.cli_ctx,
'{} Server Import Create'.format(logging_name),
progress_bar=import_progress_bar)


def flexible_server_connection_string(
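
For reference, a standalone sketch of the Location-header rewrite performed by OperationProgressBar._get_operation_progress_request in _util.py above. The sample URL is made up; the api-version value matches the one hard-coded in this change.

from urllib.parse import urlencode, urlparse, parse_qsl

# Hypothetical Location header returned by the create poller's initial response.
location_url = ("https://management.azure.com/subscriptions/00000000-0000-0000-0000-000000000000"
                "/providers/Microsoft.DBforMySQL/locations/eastus/operationResults/abcd1234"
                "?api-version=2023-06-01-preview")

# Swap operationResults for operationProgress and pin the preview API version.
operation_progress_url = location_url.replace('operationResults', 'operationProgress')
parsed = urlparse(operation_progress_url)
query_params = dict(parse_qsl(parsed.query))
query_params['api-version'] = "2023-12-01-preview"
print(parsed._replace(query=urlencode(query_params)).geturl())
# .../locations/eastus/operationProgress/abcd1234?api-version=2023-12-01-preview
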