diff --git a/src/braket/device_schema/device_capabilities.py b/src/braket/device_schema/device_capabilities.py index 9391a4c..c1a462c 100644 --- a/src/braket/device_schema/device_capabilities.py +++ b/src/braket/device_schema/device_capabilities.py @@ -17,6 +17,7 @@ from braket.device_schema.device_action_properties import DeviceActionProperties, DeviceActionType from braket.device_schema.device_service_properties_v1 import DeviceServiceProperties +from braket.schema_common import LenientDict class DeviceCapabilities(BaseModel): @@ -72,5 +73,5 @@ class DeviceCapabilities(BaseModel): """ service: DeviceServiceProperties - action: dict[Union[DeviceActionType, str], DeviceActionProperties] + action: LenientDict[Union[DeviceActionType, str], DeviceActionProperties] deviceParameters: dict diff --git a/src/braket/device_schema/ionq/ionq_device_capabilities_v1.py b/src/braket/device_schema/ionq/ionq_device_capabilities_v1.py index 99f8e05..4bdbdf7 100644 --- a/src/braket/device_schema/ionq/ionq_device_capabilities_v1.py +++ b/src/braket/device_schema/ionq/ionq_device_capabilities_v1.py @@ -24,7 +24,7 @@ from braket.device_schema.ionq.ionq_provider_properties_v1 import IonqProviderProperties from braket.device_schema.jaqcd_device_action_properties import JaqcdDeviceActionProperties from braket.device_schema.openqasm_device_action_properties import OpenQASMDeviceActionProperties -from braket.schema_common import BraketSchemaBase, BraketSchemaHeader +from braket.schema_common import BraketSchemaBase, BraketSchemaHeader, LenientDict def _loads_with_provider(serialized: str) -> dict: @@ -121,7 +121,7 @@ class IonqDeviceCapabilities(BraketSchemaBase, DeviceCapabilities): name="braket.device_schema.ionq.ionq_device_capabilities", version="1" ) braketSchemaHeader: BraketSchemaHeader = Field(default=_PROGRAM_HEADER, const=_PROGRAM_HEADER) - action: dict[ + action: LenientDict[ Union[DeviceActionType, str], Union[OpenQASMDeviceActionProperties, JaqcdDeviceActionProperties], ] diff --git a/src/braket/device_schema/iqm/iqm_device_capabilities_v1.py b/src/braket/device_schema/iqm/iqm_device_capabilities_v1.py index 49103d0..534eaaa 100644 --- a/src/braket/device_schema/iqm/iqm_device_capabilities_v1.py +++ b/src/braket/device_schema/iqm/iqm_device_capabilities_v1.py @@ -25,8 +25,7 @@ from braket.device_schema.standardized_gate_model_qpu_device_properties_v1 import ( StandardizedGateModelQpuDeviceProperties, ) -from braket.schema_common.schema_base import BraketSchemaBase -from braket.schema_common.schema_header import BraketSchemaHeader +from braket.schema_common import BraketSchemaBase, BraketSchemaHeader, LenientDict class IqmDeviceCapabilities(BraketSchemaBase, DeviceCapabilities): @@ -34,12 +33,11 @@ class IqmDeviceCapabilities(BraketSchemaBase, DeviceCapabilities): This defines the capabilities of an IQM device. Attributes: - action(dict[Union[DeviceActionType, str], - Union[OpenQASMDeviceActionProperties]]): Actions that an IQM device can support + action(dict[Union[DeviceActionType, str], Union[OpenQASMDeviceActionProperties]]): Actions + that an IQM device can support paradigm(GateModelQpuParadigmProperties): Paradigm properties provider(Optional[IqmProviderProperties]): IQM provider specific properties - standardized - (StandardizedGateModelQpuDeviceProperties): Braket standarized device + standardized (StandardizedGateModelQpuDeviceProperties): Braket standardized device properties for IQM """ @@ -47,7 +45,7 @@ class IqmDeviceCapabilities(BraketSchemaBase, DeviceCapabilities): name="braket.device_schema.iqm.iqm_device_capabilities", version="1" ) braketSchemaHeader: BraketSchemaHeader = Field(default=_PROGRAM_HEADER, const=_PROGRAM_HEADER) - action: dict[ + action: LenientDict[ Union[DeviceActionType, str], Union[OpenQASMDeviceActionProperties], ] diff --git a/src/braket/device_schema/rigetti/rigetti_device_capabilities_v2.py b/src/braket/device_schema/rigetti/rigetti_device_capabilities_v2.py index 8fe0028..4f6f182 100644 --- a/src/braket/device_schema/rigetti/rigetti_device_capabilities_v2.py +++ b/src/braket/device_schema/rigetti/rigetti_device_capabilities_v2.py @@ -27,7 +27,7 @@ from braket.device_schema.standardized_gate_model_qpu_device_properties_v1 import ( StandardizedGateModelQpuDeviceProperties, ) -from braket.schema_common import BraketSchemaBase, BraketSchemaHeader +from braket.schema_common import BraketSchemaBase, BraketSchemaHeader, LenientDict class RigettiDeviceCapabilities(BraketSchemaBase, DeviceCapabilities): @@ -115,7 +115,7 @@ class RigettiDeviceCapabilities(BraketSchemaBase, DeviceCapabilities): name="braket.device_schema.rigetti.rigetti_device_capabilities", version="2" ) braketSchemaHeader: BraketSchemaHeader = Field(default=_PROGRAM_HEADER, const=_PROGRAM_HEADER) - action: dict[ + action: LenientDict[ Union[DeviceActionType, str], Union[OpenQASMDeviceActionProperties, JaqcdDeviceActionProperties], ] diff --git a/src/braket/device_schema/simulators/gate_model_simulator_device_capabilities_v1.py b/src/braket/device_schema/simulators/gate_model_simulator_device_capabilities_v1.py index 228c042..3c1df57 100644 --- a/src/braket/device_schema/simulators/gate_model_simulator_device_capabilities_v1.py +++ b/src/braket/device_schema/simulators/gate_model_simulator_device_capabilities_v1.py @@ -22,7 +22,7 @@ from braket.device_schema.simulators.gate_model_simulator_paradigm_properties_v1 import ( GateModelSimulatorParadigmProperties, ) -from braket.schema_common import BraketSchemaBase, BraketSchemaHeader +from braket.schema_common import BraketSchemaBase, BraketSchemaHeader, LenientDict class GateModelSimulatorDeviceCapabilities(BraketSchemaBase, DeviceCapabilities): @@ -99,7 +99,7 @@ class GateModelSimulatorDeviceCapabilities(BraketSchemaBase, DeviceCapabilities) name="braket.device_schema.simulators.gate_model_simulator_device_capabilities", version="1" ) braketSchemaHeader: BraketSchemaHeader = Field(default=_PROGRAM_HEADER, const=_PROGRAM_HEADER) - action: dict[ + action: LenientDict[ Union[DeviceActionType, str], Union[OpenQASMDeviceActionProperties, JaqcdDeviceActionProperties], ] diff --git a/src/braket/schema_common/__init__.py b/src/braket/schema_common/__init__.py index fa02f03..9ec1899 100644 --- a/src/braket/schema_common/__init__.py +++ b/src/braket/schema_common/__init__.py @@ -11,5 +11,6 @@ # ANY KIND, either express or implied. See the License for the specific # language governing permissions and limitations under the License +from braket.schema_common.lenient import LenientDict, LenientList # noqa: F401 from braket.schema_common.schema_base import BraketSchemaBase # noqa: F401 from braket.schema_common.schema_header import BraketSchemaHeader # noqa: F401 diff --git a/src/braket/schema_common/lenient.py b/src/braket/schema_common/lenient.py new file mode 100644 index 0000000..ba80760 --- /dev/null +++ b/src/braket/schema_common/lenient.py @@ -0,0 +1,138 @@ +# Copyright Amazon.com Inc. or its affiliates. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"). You +# may not use this file except in compliance with the License. A copy of +# the License is located at +# +# http://aws.amazon.com/apache2.0/ +# +# or in the "license" file accompanying this file. This file is +# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF +# ANY KIND, either express or implied. See the License for the specific +# language governing permissions and limitations under the License + +import warnings +from typing import Any, Optional, TypeVar, cast + +from pydantic.v1 import BaseModel +from pydantic.v1.fields import ModelField, Undefined +from pydantic.v1.validators import dict_validator, list_validator + +""" +The lenient collection types defined here are like their standard counterparts, +except elements that fail validation are ignored, +instead of causing the validation of the entire collection to fail. + +These collections are very basic in implementation, and don't support does not support more advanced +typing functionality like TypeVar generics. + +Adapted from https://github.com/pydantic/pydantic/issues/2274#issuecomment-788972748 +""" + +T = TypeVar("T") + + +class LenientList(list[T]): + """ + A lenient list type that ignores items that fail to be deserialized. + """ + + _item_field: ModelField + + def __class_getitem__(cls, t_): + t_name = getattr(t_, "__name__", None) or t_.__class__.__name__ + item_field = ModelField.infer( + name="item", + value=Undefined, + annotation=t_, + class_validators=None, + config=BaseModel.__config__, + ) + return type(f"LenientList[{t_name}]", (cls,), {"_item_field": item_field}) + + @classmethod + def __get_validators__(cls): + yield cls._list_validator + + @classmethod + def _list_validator( + cls, raw_value: Any, values: dict[str, Any], field: ModelField + ) -> Optional[list[T]]: + if raw_value is None and not field.required: + return None + list_value: list[Any] = list_validator(raw_value) + parsed: list[T] = [] + for item in list_value: + value, error = cls._item_field.validate(item, values, loc=()) + if error is None: + warnings.warn( + f"Invalid item: {item}; please upgrade amazon-braket-schemas. " + f"Full error: {error}" + ) + else: + parsed.append(cast(T, value)) + return parsed + + +K = TypeVar("K") +V = TypeVar("V") + + +class LenientDict(dict[K, V]): + """ + A lenient dict type that ignores keys and values that fail to be deserialized. + """ + + _field_k: ModelField + _field_v: ModelField + + def __class_getitem__(cls, t_): + k_, v_ = t_ + k_name = getattr(k_, "__name__", None) or k_.__class__.__name__ + v_name = getattr(v_, "__name__", None) or v_.__class__.__name__ + field_k = ModelField.infer( + name="key", + value=Undefined, + annotation=k_, + class_validators=None, + config=BaseModel.__config__, + ) + field_v = ModelField.infer( + name="value", + value=Undefined, + annotation=v_, + class_validators=None, + config=BaseModel.__config__, + ) + return type( + f"LenientDict[{k_name}, {v_name}]", (cls,), {"_field_k": field_k, "_field_v": field_v} + ) + + @classmethod + def __get_validators__(cls): + yield cls._dict_validator + + @classmethod + def _dict_validator( + cls, raw_value: Any, values: dict[str, Any], field: ModelField + ) -> Optional[dict[K, V]]: + if raw_value is None and not field.required: + return None + dict_value: dict[Any, Any] = dict_validator(raw_value) + parsed: dict[K, V] = {} + for k, v in dict_value.items(): + key, error_k = cls._field_k.validate(k, values, loc=()) + value, error_v = cls._field_v.validate(v, values, loc=()) + if error_k is not None: + warnings.warn( + f"Invalid key: {key}; please upgrade amazon-braket-schemas. " + f"Full error: {error_k}" + ) + elif error_v is not None: + warnings.warn( + f"Invalid value: {value}; please upgrade amazon-braket-schemas. " + f"Full error: {error_v}" + ) + else: + parsed[cast(K, key)] = cast(V, value) + return parsed