Skip to content

Commit

Permalink
IFC-1245 Initial implementation for object templates (#5610)
Browse files Browse the repository at this point in the history
  • Loading branch information
gmazoyer authored Feb 17, 2025
1 parent ff71a78 commit 5544ada
Show file tree
Hide file tree
Showing 41 changed files with 1,066 additions and 122 deletions.
20 changes: 17 additions & 3 deletions backend/infrahub/api/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
SchemaDiff,
SchemaUpdateValidationResult,
)
from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, SchemaRoot
from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, SchemaRoot, TemplateSchema
from infrahub.core.schema.constants import SchemaNamespace # noqa: TC001
from infrahub.core.validators.models.validate_migration import (
SchemaValidateMigrationData,
Expand Down Expand Up @@ -87,11 +87,17 @@ class APIProfileSchema(ProfileSchema, APISchemaMixin):
hash: str


class APITemplateSchema(TemplateSchema, APISchemaMixin):
api_kind: str | None = Field(default=None, alias="kind", validate_default=True)
hash: str


class SchemaReadAPI(BaseModel):
main: str = Field(description="Main hash for the entire schema")
nodes: list[APINodeSchema] = Field(default_factory=list)
generics: list[APIGenericSchema] = Field(default_factory=list)
profiles: list[APIProfileSchema] = Field(default_factory=list)
templates: list[APITemplateSchema] = Field(default_factory=list)
namespaces: list[SchemaNamespace] = Field(default_factory=list)


Expand Down Expand Up @@ -191,6 +197,11 @@ async def get_schema(
for value in all_schemas
if isinstance(value, ProfileSchema) and value.namespace != "Internal"
],
templates=[
APITemplateSchema.from_schema(value)
for value in all_schemas
if isinstance(value, TemplateSchema) and value.namespace != "Internal"
],
namespaces=schema_branch.get_namespaces(),
)

Expand All @@ -207,15 +218,16 @@ async def get_schema_summary(
@router.get("/{schema_kind}")
async def get_schema_by_kind(
schema_kind: str, branch: Branch = Depends(get_branch_dep), _: AccountSession = Depends(get_current_user)
) -> APIProfileSchema | APINodeSchema | APIGenericSchema:
) -> APIProfileSchema | APINodeSchema | APIGenericSchema | APITemplateSchema:
log.debug("schema_kind_request", branch=branch.name)

schema = registry.schema.get(name=schema_kind, branch=branch, duplicate=False)

api_schema: dict[str, type[APIProfileSchema | APINodeSchema | APIGenericSchema]] = {
api_schema: dict[str, type[APIProfileSchema | APINodeSchema | APIGenericSchema | APITemplateSchema]] = {
"profile": APIProfileSchema,
"node": APINodeSchema,
"generic": APIGenericSchema,
"template": APITemplateSchema,
}
key = ""

Expand All @@ -225,6 +237,8 @@ async def get_schema_by_kind(
key = "node"
if isinstance(schema, GenericSchema):
key = "generic"
if isinstance(schema, TemplateSchema):
key = "template"

return api_schema[key].from_schema(schema=schema)

Expand Down
5 changes: 5 additions & 0 deletions backend/infrahub/core/constants/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,7 @@ class RelationshipKind(InfrahubStringEnum):
GROUP = "Group"
HIERARCHY = "Hierarchy"
PROFILE = "Profile"
TEMPLATE = "Template"


class RelationshipStatus(InfrahubStringEnum):
Expand Down Expand Up @@ -301,6 +302,7 @@ class AttributeDBNodeType(InfrahubStringEnum):
"Lineage",
"Schema",
"Profile",
"Template",
]

NODE_NAME_REGEX = r"^[A-Z][a-zA-Z0-9]+$"
Expand All @@ -315,3 +317,6 @@ class AttributeDBNodeType(InfrahubStringEnum):
NAMESPACE_REGEX = r"^[A-Z][a-z0-9]+$"
NODE_KIND_REGEX = r"^[A-Z][a-zA-Z0-9]+$"
DEFAULT_REL_IDENTIFIER_LENGTH = 128

OBJECT_TEMPLATE_RELATIONSHIP_NAME = "object_template"
OBJECT_TEMPLATE_NAME_ATTR = "template_name"
1 change: 1 addition & 0 deletions backend/infrahub/core/constants/infrahubkind.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
LINEAGEOWNER = "LineageOwner"
LINEAGESOURCE = "LineageSource"
OBJECTPERMISSION = "CoreObjectPermission"
OBJECTTEMPLATE = "CoreObjectTemplate"
OBJECTTHREAD = "CoreObjectThread"
PASSWORDCREDENTIAL = "CorePasswordCredential"
PROFILE = "CoreProfile"
Expand Down
4 changes: 2 additions & 2 deletions backend/infrahub/core/diff/enricher/hierarchy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from infrahub.core.constants.database import DatabaseEdgeType
from infrahub.core.query.node import NodeGetHierarchyQuery
from infrahub.core.query.relationship import RelationshipGetPeerQuery, RelationshipPeerData
from infrahub.core.schema import ProfileSchema
from infrahub.core.schema import ProfileSchema, TemplateSchema
from infrahub.database import InfrahubDatabase

from ..model.path import (
Expand Down Expand Up @@ -38,7 +38,7 @@ async def enrich(
name=node.kind, branch=enriched_diff_root.diff_branch_name, duplicate=False
)

if isinstance(schema_node, ProfileSchema):
if isinstance(schema_node, ProfileSchema | TemplateSchema):
continue

if schema_node.has_parent_relationship:
Expand Down
13 changes: 10 additions & 3 deletions backend/infrahub/core/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,14 @@
from infrahub.core.query.relationship import RelationshipGetPeerQuery
from infrahub.core.registry import registry
from infrahub.core.relationship import Relationship, RelationshipManager
from infrahub.core.schema import GenericSchema, MainSchemaTypes, NodeSchema, ProfileSchema, RelationshipSchema
from infrahub.core.schema import (
GenericSchema,
MainSchemaTypes,
NodeSchema,
ProfileSchema,
RelationshipSchema,
TemplateSchema,
)
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import NodeNotFoundError, ProcessingError, SchemaNotFoundError
from infrahub.graphql.models import OrderModel
Expand Down Expand Up @@ -127,7 +134,7 @@ class NodeManager:
async def query(
cls,
db: InfrahubDatabase,
schema: Union[NodeSchema, GenericSchema, ProfileSchema, str],
schema: Union[NodeSchema, GenericSchema, ProfileSchema, TemplateSchema, str],
filters: dict | None = ...,
fields: dict | None = ...,
offset: int | None = ...,
Expand Down Expand Up @@ -265,7 +272,7 @@ async def query(
async def count(
cls,
db: InfrahubDatabase,
schema: Union[type[SchemaProtocol], NodeSchema, GenericSchema, ProfileSchema, str],
schema: Union[type[SchemaProtocol], NodeSchema, GenericSchema, ProfileSchema, TemplateSchema, str],
filters: Optional[dict] = None,
at: Optional[Union[Timestamp, str]] = None,
branch: Optional[Union[Branch, str]] = None,
Expand Down
80 changes: 69 additions & 11 deletions backend/infrahub/core/node/__init__.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,26 @@
from __future__ import annotations

from enum import Enum
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Optional, Sequence, TypeVar, Union, overload

from infrahub_sdk.utils import is_valid_uuid
from infrahub_sdk.uuidt import UUIDT

from infrahub.core import registry
from infrahub.core.changelog.models import NodeChangelog
from infrahub.core.constants import (
OBJECT_TEMPLATE_NAME_ATTR,
OBJECT_TEMPLATE_RELATIONSHIP_NAME,
BranchSupportType,
ComputedAttributeKind,
InfrahubKind,
RelationshipCardinality,
RelationshipKind,
)
from infrahub.core.constants.schema import SchemaElementPathType
from infrahub.core.protocols import CoreNumberPool
from infrahub.core.protocols import CoreNumberPool, CoreObjectTemplate
from infrahub.core.query.node import NodeCheckIDQuery, NodeCreateAllQuery, NodeDeleteQuery, NodeGetListQuery
from infrahub.core.schema import AttributeSchema, NodeSchema, ProfileSchema, RelationshipSchema
from infrahub.core.schema import AttributeSchema, NodeSchema, ProfileSchema, RelationshipSchema, TemplateSchema
from infrahub.core.timestamp import Timestamp
from infrahub.exceptions import InitializationError, NodeNotFoundError, PoolExhaustedError, ValidationError
from infrahub.support.macro import MacroDefinition
Expand Down Expand Up @@ -59,7 +61,7 @@ def __init_subclass_with_meta__(cls, _meta=None, default_filter=None, **options)
_meta.default_filter = default_filter
super().__init_subclass_with_meta__(_meta=_meta, **options)

def get_schema(self) -> Union[NodeSchema, ProfileSchema]:
def get_schema(self) -> Union[NodeSchema, ProfileSchema, TemplateSchema]:
return self._schema

def get_kind(self) -> str:
Expand Down Expand Up @@ -133,7 +135,7 @@ def get_labels(self) -> list[str]:
labels.append(InfrahubKind.NODE)
return labels

if isinstance(self._schema, ProfileSchema):
if isinstance(self._schema, ProfileSchema | TemplateSchema):
labels = [self.get_kind()] + self._schema.inherit_from
return labels

Expand All @@ -156,8 +158,8 @@ def __repr__(self) -> str:

return f"{self.get_kind()}(ID: {str(self.id)})"

def __init__(self, schema: Union[NodeSchema, ProfileSchema], branch: Branch, at: Timestamp):
self._schema: Union[NodeSchema, ProfileSchema] = schema
def __init__(self, schema: Union[NodeSchema, ProfileSchema, TemplateSchema], branch: Branch, at: Timestamp):
self._schema: Union[NodeSchema, ProfileSchema, TemplateSchema] = schema
self._branch: Branch = branch
self._at: Timestamp = at
self._existing: bool = False
Expand Down Expand Up @@ -187,7 +189,7 @@ def node_changelog(self) -> NodeChangelog:
@classmethod
async def init(
cls,
schema: Union[NodeSchema, ProfileSchema, str],
schema: Union[NodeSchema, ProfileSchema, TemplateSchema, str],
db: InfrahubDatabase,
branch: Optional[Union[Branch, str]] = ...,
at: Optional[Union[Timestamp, str]] = ...,
Expand All @@ -206,7 +208,7 @@ async def init(
@classmethod
async def init(
cls,
schema: Union[NodeSchema, ProfileSchema, str, type[SchemaProtocol]],
schema: Union[NodeSchema, ProfileSchema, TemplateSchema, str, type[SchemaProtocol]],
db: InfrahubDatabase,
branch: Optional[Union[Branch, str]] = None,
at: Optional[Union[Timestamp, str]] = None,
Expand All @@ -215,15 +217,17 @@ async def init(

branch = await registry.get_branch(branch=branch, db=db)

if isinstance(schema, NodeSchema | ProfileSchema):
if isinstance(schema, NodeSchema | ProfileSchema | TemplateSchema):
attrs["schema"] = schema
elif isinstance(schema, str):
# TODO need to raise a proper exception for this, right now it will raise a generic ValueError
attrs["schema"] = db.schema.get(name=schema, branch=branch)
elif hasattr(schema, "_is_runtime_protocol") and schema._is_runtime_protocol:
attrs["schema"] = db.schema.get(name=schema.__name__, branch=branch)
else:
raise ValueError(f"Invalid schema provided {type(schema)}, expected NodeSchema or ProfileSchema")
raise ValueError(
f"Invalid schema provided {type(schema)}, expected NodeSchema, ProfileSchema or TemplateSchema"
)

attrs["branch"] = branch
attrs["at"] = Timestamp(at)
Expand Down Expand Up @@ -272,6 +276,40 @@ async def handle_pool(self, db: InfrahubDatabase, attribute: BaseAttribute, erro
)
)

async def handle_object_template(self, fields: dict, db: InfrahubDatabase, errors: list) -> None:
"""Fill the `fields` parameters with values from an object template if one is in use."""
object_template_field = fields.get(OBJECT_TEMPLATE_RELATIONSHIP_NAME)
if not object_template_field:
return

try:
template: CoreObjectTemplate = await registry.manager.find_object(
db=db,
kind=self._schema.get_relationship(name=OBJECT_TEMPLATE_RELATIONSHIP_NAME).peer,
id=object_template_field.get("id"),
hfid=object_template_field.get("hfid"),
branch=self.get_branch_based_on_support_type(),
)
except NodeNotFoundError:
errors.append(
ValidationError(
{
f"{OBJECT_TEMPLATE_RELATIONSHIP_NAME}": (
"Unable to find the object template in the database "
f"'{object_template_field.get('id') or object_template_field.get('hfid')}'"
)
}
)
)
return

# Handle attributes, copy values from template
# Relationships handling in performed in GraphQL mutation to create nodes for relationships
for attribute in template._attributes:
if attribute in list(fields) + [OBJECT_TEMPLATE_NAME_ATTR]:
continue
fields[attribute] = {"value": getattr(template, attribute).value}

async def _process_fields(self, fields: dict, db: InfrahubDatabase) -> None:
errors = []

Expand All @@ -290,6 +328,9 @@ async def _process_fields(self, fields: dict, db: InfrahubDatabase) -> None:
if field_name not in self._schema.valid_input_names:
errors.append(ValidationError({field_name: f"{field_name} is not a valid input for {self.get_kind()}"}))

# Backfill fields with the ones from the template if there's one
await self.handle_object_template(fields=fields, db=db, errors=errors)

# If the object is new, we need to ensure that all mandatory attributes and relationships have been provided
if not self._existing:
for mandatory_attr in self._schema.mandatory_attribute_names:
Expand Down Expand Up @@ -804,3 +845,20 @@ def _get_parent_relationship_name(self) -> str | None:
for relationship in self._schema.relationships:
if relationship.kind == RelationshipKind.PARENT:
return relationship.name

async def get_object_template(self, db: InfrahubDatabase) -> Node | None:
object_template: RelationshipManager = getattr(self, OBJECT_TEMPLATE_RELATIONSHIP_NAME, None)
return None if not object_template else await object_template.get_peer(db=db)

def get_relationships(
self, kind: RelationshipKind, exclude: Sequence[str] | None = None
) -> list[RelationshipSchema]:
"""Return relationships of a given kind with the possiblity to exclude some of them by name."""
if exclude is None:
exclude = []

return [
relationship
for relationship in self.get_schema().relationships
if relationship.name not in exclude and relationship.kind == kind
]
8 changes: 4 additions & 4 deletions backend/infrahub/core/node/delete_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
RelationshipGetByIdentifierQuery,
RelationshipPeersData,
)
from infrahub.core.schema import MainSchemaTypes, NodeSchema, ProfileSchema
from infrahub.core.schema import MainSchemaTypes, NodeSchema, ProfileSchema, TemplateSchema
from infrahub.core.timestamp import Timestamp
from infrahub.database import InfrahubDatabase
from infrahub.exceptions import ValidationError
Expand All @@ -28,7 +28,7 @@ def __init__(self, all_schemas_map: dict[str, MainSchemaTypes]) -> None:
# {node_schema: {DeleteRelationshipType: {relationship_identifier: peer_node_schema}}}
self._dependency_graph: dict[str, dict[DeleteRelationshipType, dict[str, set[str]]]] = {}

def index(self, start_schemas: Iterable[NodeSchema | ProfileSchema]) -> None:
def index(self, start_schemas: Iterable[NodeSchema | ProfileSchema | TemplateSchema]) -> None:
self._index_cascading_deletes(start_schemas=start_schemas)
self._index_dependent_schema(start_schemas=start_schemas)

Expand All @@ -50,7 +50,7 @@ def _add_to_dependency_graph(
self._dependency_graph[kind][relationship_type] = defaultdict(set)
self._dependency_graph[kind][relationship_type][relationship_identifier].update(peer_kinds)

def _index_cascading_deletes(self, start_schemas: Iterable[NodeSchema | ProfileSchema]) -> None:
def _index_cascading_deletes(self, start_schemas: Iterable[NodeSchema | ProfileSchema | TemplateSchema]) -> None:
kinds_to_check: set[str] = {schema.kind for schema in start_schemas}
while True:
try:
Expand All @@ -72,7 +72,7 @@ def _index_cascading_deletes(self, start_schemas: Iterable[NodeSchema | ProfileS
if peer_kind not in self._dependency_graph:
kinds_to_check.add(peer_kind)

def _index_dependent_schema(self, start_schemas: Iterable[NodeSchema | ProfileSchema]) -> None:
def _index_dependent_schema(self, start_schemas: Iterable[NodeSchema | ProfileSchema | TemplateSchema]) -> None:
start_schema_kinds: set[str] = set()
for start_schema in start_schemas:
start_schema_kinds.add(start_schema.kind)
Expand Down
4 changes: 4 additions & 0 deletions backend/infrahub/core/protocols.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,10 @@ class CoreMenu(CoreNode):
children: RelationshipManager


class CoreObjectTemplate(CoreNode):
template_name: String


class CoreProfile(CoreNode):
profile_name: String
profile_priority: IntegerOptional
Expand Down
4 changes: 4 additions & 0 deletions backend/infrahub/core/protocols_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,10 @@ class NodeSchema(Protocol): ...
class ProfileSchema(Protocol): ...


@runtime_checkable
class TemplateSchema(Protocol): ...


@runtime_checkable
class Branch(Protocol): ...

Expand Down
3 changes: 2 additions & 1 deletion backend/infrahub/core/query/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,13 @@
from infrahub.core.schema import GenericSchema, NodeSchema
from infrahub.core.schema.profile_schema import ProfileSchema
from infrahub.core.schema.relationship_schema import RelationshipSchema
from infrahub.core.schema.template_schema import TemplateSchema
from infrahub.database import InfrahubDatabase


@dataclass
class NodeToProcess:
schema: Optional[Union[NodeSchema, ProfileSchema]]
schema: Optional[Union[NodeSchema, ProfileSchema, TemplateSchema]]

node_id: str
node_uuid: str
Expand Down
Loading

0 comments on commit 5544ada

Please sign in to comment.