Merge pull request #5141 from opsmill/lgu-merge-stable-release-1-1
Merge stable into release-1.1
LucasG0 authored Dec 10, 2024
2 parents d47f07c + df68791 commit c1e5ca3
Showing 67 changed files with 1,250 additions and 223 deletions.
3 changes: 3 additions & 0 deletions backend/infrahub/api/oauth2.py
@@ -99,6 +99,9 @@ async def token(
_validate_response(response=userinfo_response)
user_info = userinfo_response.json()
sso_groups = user_info.get("groups", [])
if not sso_groups and config.SETTINGS.security.sso_user_default_group:
sso_groups = [config.SETTINGS.security.sso_user_default_group]

user_token = await signin_sso_account(db=db, account_name=user_info["name"], sso_groups=sso_groups)

response.set_cookie(
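The new lines give SSO sign-ins a fallback group when the identity provider does not return any. A minimal standalone sketch of that behaviour; the resolve_sso_groups helper and its default_group parameter are illustrative, not part of Infrahub's API:

def resolve_sso_groups(user_info: dict, default_group: str | None) -> list[str]:
    """Return the groups reported by the identity provider, falling back to a
    single configured default group when the provider reports none."""
    sso_groups = user_info.get("groups", [])
    if not sso_groups and default_group:
        sso_groups = [default_group]
    return sso_groups


# The fallback only kicks in when the IdP payload carries no groups at all.
print(resolve_sso_groups({"name": "alice"}, default_group="sso-users"))                       # ['sso-users']
print(resolve_sso_groups({"name": "bob", "groups": ["admins"]}, default_group="sso-users"))   # ['admins']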
4 changes: 3 additions & 1 deletion backend/infrahub/api/oidc.py
@@ -137,8 +137,10 @@ async def token(

_validate_response(response=userinfo_response)
user_info = userinfo_response.json()

sso_groups = user_info.get("groups", [])
if not sso_groups and config.SETTINGS.security.sso_user_default_group:
sso_groups = [config.SETTINGS.security.sso_user_default_group]

user_token = await signin_sso_account(db=db, account_name=user_info["name"], sso_groups=sso_groups)

response.set_cookie(
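The OIDC handler receives the same fallback as the OAuth2 one. A quick self-contained check of the intended behaviour (pure Python, no Infrahub imports; the helper name is made up):

def apply_default_group(groups: list[str], default: str | None) -> list[str]:
    # Mirror of the change above: only substitute when the provider sent nothing.
    return groups if groups or not default else [default]


assert apply_default_group([], "sso-users") == ["sso-users"]        # empty -> default applied
assert apply_default_group(["admins"], "sso-users") == ["admins"]   # provider-supplied groups win
assert apply_default_group([], None) == []                          # no default configured -> unchanged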
5 changes: 4 additions & 1 deletion backend/infrahub/api/schema.py
@@ -135,10 +135,13 @@ def evaluate_candidate_schemas(
for schema in schemas_to_evaluate.schemas:
candidate_schema.load_schema(schema=schema)
candidate_schema.process()

schema_diff = branch_schema.diff(other=candidate_schema)
candidate_schema.validate_node_deletions(diff=schema_diff)
except ValueError as exc:
raise SchemaNotValidError(message=str(exc)) from exc

result = branch_schema.validate_update(other=candidate_schema)
result = branch_schema.validate_update(other=candidate_schema, diff=schema_diff)

if result.errors:
raise SchemaNotValidError(message=", ".join([error.to_string() for error in result.errors]))
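The endpoint now computes the schema diff once, screens it for node deletions, and passes the same diff into validate_update instead of letting it be recomputed. A toy, self-contained sketch of that call order; only the method names diff, validate_node_deletions, and validate_update are taken from the change, everything else is invented for illustration:

from dataclasses import dataclass, field


@dataclass
class SchemaDiff:
    removed: list[str] = field(default_factory=list)
    changed: list[str] = field(default_factory=list)


@dataclass
class ToySchemaBranch:
    nodes: dict[str, int]  # node name -> fake version number

    def diff(self, other: "ToySchemaBranch") -> SchemaDiff:
        return SchemaDiff(
            removed=[name for name in self.nodes if name not in other.nodes],
            changed=[name for name, version in self.nodes.items() if name in other.nodes and other.nodes[name] != version],
        )

    def validate_node_deletions(self, diff: SchemaDiff) -> None:
        if diff.removed:
            raise ValueError(f"deleting node(s) {diff.removed} is not supported here")

    def validate_update(self, other: "ToySchemaBranch", diff: SchemaDiff) -> list[str]:
        # Only the entries listed in the precomputed diff need constraint checks.
        return [f"validate constraints on {name}" for name in diff.changed]


current = ToySchemaBranch(nodes={"Device": 1, "Interface": 1})
candidate = ToySchemaBranch(nodes={"Device": 2, "Interface": 1})
schema_diff = current.diff(other=candidate)
current.validate_node_deletions(diff=schema_diff)   # raises ValueError if a node was dropped
print(current.validate_update(other=candidate, diff=schema_diff))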
6 changes: 5 additions & 1 deletion backend/infrahub/cli/db.py
@@ -235,7 +235,11 @@ async def update_core_schema( # pylint: disable=too-many-statements
candidate_schema.load_schema(schema=SchemaRoot(**deprecated_models))
candidate_schema.process()

result = branch_schema.validate_update(other=candidate_schema, enforce_update_support=False)
schema_diff = branch_schema.diff(other=candidate_schema)
branch_schema.validate_node_deletions(diff=schema_diff)
result = branch_schema.validate_update(
other=candidate_schema, diff=schema_diff, enforce_update_support=False
)
if result.errors:
rprint(f"{error_badge} | Unable to update the schema, due to failed validations")
for error in result.errors:
12 changes: 12 additions & 0 deletions backend/infrahub/config.py
@@ -44,6 +44,10 @@ def default_cors_allow_headers() -> list[str]:
return ["accept", "authorization", "content-type", "user-agent", "x-csrftoken", "x-requested-with"]


def default_append_git_suffix_domains() -> list[str]:
return ["github.com", "gitlab.com"]


class UserInfoMethod(str, Enum):
POST = "post"
GET = "get"
@@ -355,6 +359,10 @@ class GitSettings(BaseSettings):
description="Time (in seconds) between git repositories synchronizations",
deprecated="This setting is deprecated and not currently in use.",
)
append_git_suffix: list[str] = Field(
default_factory=default_append_git_suffix_domains,
description="Automatically append '.git' to HTTP URLs if for these domains.",
)


class HTTPSettings(BaseSettings):
@@ -582,6 +590,10 @@ class SecuritySettings(BaseSettings):
oidc_provider_settings: SecurityOIDCProviderSettings = Field(default_factory=SecurityOIDCProviderSettings)
_oauth2_settings: dict[str, SecurityOAuth2Settings] = PrivateAttr(default_factory=dict)
_oidc_settings: dict[str, SecurityOIDCSettings] = PrivateAttr(default_factory=dict)
sso_user_default_group: str | None = Field(
default=None,
description="Name of the group to which users authenticated via SSO will belong if not provided by identity provider",
)

@model_validator(mode="after")
def check_oauth2_provider_settings(self) -> Self:
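Two new settings appear in this file: git.append_git_suffix, a list of domains for which '.git' is appended to HTTP(S) repository URLs, and security.sso_user_default_group, used by the SSO changes above. The code that consumes append_git_suffix is not part of this diff, so the sketch below only illustrates the behaviour the field description implies; maybe_append_git_suffix is a hypothetical helper:

from urllib.parse import urlparse


def maybe_append_git_suffix(url: str, domains: list[str]) -> str:
    """Append '.git' to an HTTP(S) repository URL whose host is in `domains`,
    leaving other URLs (and URLs that already end in '.git') untouched."""
    parsed = urlparse(url)
    if parsed.scheme in ("http", "https") and parsed.hostname in domains and not parsed.path.endswith(".git"):
        return url + ".git"
    return url


defaults = ["github.com", "gitlab.com"]
print(maybe_append_git_suffix("https://github.com/opsmill/infrahub", defaults))   # ...infrahub.git
print(maybe_append_git_suffix("git@github.com:opsmill/infrahub.git", defaults))   # unchanged (not HTTP)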
8 changes: 4 additions & 4 deletions backend/infrahub/core/account.py
@@ -71,7 +71,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.params["account_id"] = self.account_id

branch_filter, branch_params = self.branch.get_query_filter_path(
at=self.at.to_string(), branch_agnostic=self.branch_agnostic
at=self.at.to_string(), branch_agnostic=self.branch_agnostic, is_isolated=False
)
self.params.update(branch_params)

@@ -185,7 +185,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.params["account_id"] = self.account_id

branch_filter, branch_params = self.branch.get_query_filter_path(
at=self.at.to_string(), branch_agnostic=self.branch_agnostic
at=self.at.to_string(), branch_agnostic=self.branch_agnostic, is_isolated=False
)
self.params.update(branch_params)

@@ -336,7 +336,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.params["role_id"] = self.role_id

branch_filter, branch_params = self.branch.get_query_filter_path(
at=self.at.to_string(), branch_agnostic=self.branch_agnostic
at=self.at.to_string(), branch_agnostic=self.branch_agnostic, is_isolated=False
)
self.params.update(branch_params)

@@ -425,7 +425,7 @@ async def query_init(self, db: InfrahubDatabase, **kwargs: Any) -> None:
self.params["role_id"] = self.role_id

branch_filter, branch_params = self.branch.get_query_filter_path(
at=self.at.to_string(), branch_agnostic=self.branch_agnostic
at=self.at.to_string(), branch_agnostic=self.branch_agnostic, is_isolated=False
)
self.params.update(branch_params)

178 changes: 89 additions & 89 deletions backend/infrahub/core/branch/tasks.py
@@ -41,101 +41,101 @@
@flow(name="branch-rebase", flow_run_name="Rebase branch {branch}")
async def rebase_branch(branch: str) -> None:
service = services.service
log = get_run_logger()
await add_branch_tag(branch_name=branch)

obj = await Branch.get_by_name(db=service.database, name=branch)
base_branch = await Branch.get_by_name(db=service.database, name=registry.default_branch)
component_registry = get_component_registry()
diff_repository = await component_registry.get_component(DiffRepository, db=service.database, branch=obj)
diff_coordinator = await component_registry.get_component(DiffCoordinator, db=service.database, branch=obj)
diff_merger = await component_registry.get_component(DiffMerger, db=service.database, branch=obj)
merger = BranchMerger(
db=service.database,
diff_coordinator=diff_coordinator,
diff_merger=diff_merger,
diff_repository=diff_repository,
source_branch=obj,
service=service,
)
diff_repository = await component_registry.get_component(DiffRepository, db=service.database, branch=obj)
enriched_diff = await diff_coordinator.update_branch_diff(base_branch=base_branch, diff_branch=obj)
if enriched_diff.get_all_conflicts():
raise ValidationError(
f"Branch {obj.name} contains conflicts with the default branch that must be addressed."
" Please review the diff for details and manually update the conflicts before rebasing."
async with service.database.start_session() as db:
log = get_run_logger()
await add_branch_tag(branch_name=branch)
obj = await Branch.get_by_name(db=db, name=branch)
base_branch = await Branch.get_by_name(db=db, name=registry.default_branch)
component_registry = get_component_registry()
diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj)
diff_coordinator = await component_registry.get_component(DiffCoordinator, db=db, branch=obj)
diff_merger = await component_registry.get_component(DiffMerger, db=db, branch=obj)
merger = BranchMerger(
db=db,
diff_coordinator=diff_coordinator,
diff_merger=diff_merger,
diff_repository=diff_repository,
source_branch=obj,
service=service,
)
node_diff_field_summaries = await diff_repository.get_node_field_summaries(
diff_branch_name=enriched_diff.diff_branch_name, diff_id=enriched_diff.uuid
)

candidate_schema = merger.get_candidate_schema()
determiner = ConstraintValidatorDeterminer(schema_branch=candidate_schema)
constraints = await determiner.get_constraints(node_diffs=node_diff_field_summaries)

# If there are some changes related to the schema between this branch and main, we need to
# - Run all the validations to ensure everything is correct before rebasing the branch
# - Run all the migrations after the rebase
if obj.has_schema_changes:
constraints += await merger.calculate_validations(target_schema=candidate_schema)
if constraints:
responses = await schema_validate_migrations(
message=SchemaValidateMigrationData(branch=obj, schema_branch=candidate_schema, constraints=constraints)
diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj)
enriched_diff = await diff_coordinator.update_branch_diff(base_branch=base_branch, diff_branch=obj)
if enriched_diff.get_all_conflicts():
raise ValidationError(
f"Branch {obj.name} contains conflicts with the default branch that must be addressed."
" Please review the diff for details and manually update the conflicts before rebasing."
)
node_diff_field_summaries = await diff_repository.get_node_field_summaries(
diff_branch_name=enriched_diff.diff_branch_name, diff_id=enriched_diff.uuid
)

error_messages = [violation.message for response in responses for violation in response.violations]
if error_messages:
raise ValidationError(",\n".join(error_messages))

schema_in_main_before = merger.destination_schema.duplicate()

async with lock.registry.global_graph_lock():
async with service.database.start_transaction() as dbt:
await obj.rebase(db=dbt)
log.info("Branch successfully rebased")
candidate_schema = merger.get_candidate_schema()
determiner = ConstraintValidatorDeterminer(schema_branch=candidate_schema)
constraints = await determiner.get_constraints(node_diffs=node_diff_field_summaries)

# If there are some changes related to the schema between this branch and main, we need to
# - Run all the validations to ensure everything is correct before rebasing the branch
# - Run all the migrations after the rebase
if obj.has_schema_changes:
# NOTE there is a bit additional work in order to calculate a proper diff that will
# allow us to pull only the part of the schema that has changed, for now the safest option is to pull
# Everything
# schema_diff = await merger.has_schema_changes()
# TODO Would be good to convert this part to a Prefect Task in order to track it properly
updated_schema = await registry.schema.load_schema_from_db(
db=service.database,
branch=obj,
# schema=merger.source_schema.duplicate(),
# schema_diff=schema_diff,
constraints += await merger.calculate_validations(target_schema=candidate_schema)
if constraints:
responses = await schema_validate_migrations(
message=SchemaValidateMigrationData(branch=obj, schema_branch=candidate_schema, constraints=constraints)
)
registry.schema.set_schema_branch(name=obj.name, schema=updated_schema)
obj.update_schema_hash()
await obj.save(db=service.database)
error_messages = [violation.message for response in responses for violation in response.violations]
if error_messages:
raise ValidationError(",\n".join(error_messages))

# Execute the migrations
migrations = await merger.calculate_migrations(target_schema=updated_schema)
schema_in_main_before = merger.destination_schema.duplicate()

errors = await schema_apply_migrations(
message=SchemaApplyMigrationData(
branch=merger.source_branch,
new_schema=candidate_schema,
previous_schema=schema_in_main_before,
migrations=migrations,
async with lock.registry.global_graph_lock():
async with db.start_transaction() as dbt:
await obj.rebase(db=dbt)
log.info("Branch successfully rebased")

if obj.has_schema_changes:
# NOTE there is a bit additional work in order to calculate a proper diff that will
# allow us to pull only the part of the schema that has changed, for now the safest option is to pull
# Everything
# schema_diff = await merger.has_schema_changes()
# TODO Would be good to convert this part to a Prefect Task in order to track it properly
updated_schema = await registry.schema.load_schema_from_db(
db=db,
branch=obj,
# schema=merger.source_schema.duplicate(),
# schema_diff=schema_diff,
)
)
for error in errors:
log.error(error)
registry.schema.set_schema_branch(name=obj.name, schema=updated_schema)
obj.update_schema_hash()
await obj.save(db=db)

# Execute the migrations
migrations = await merger.calculate_migrations(target_schema=updated_schema)

errors = await schema_apply_migrations(
message=SchemaApplyMigrationData(
branch=merger.source_branch,
new_schema=candidate_schema,
previous_schema=schema_in_main_before,
migrations=migrations,
)
)
for error in errors:
log.error(error)

# -------------------------------------------------------------
# Trigger the reconciliation of IPAM data after the rebase
# -------------------------------------------------------------
diff_parser = await component_registry.get_component(IpamDiffParser, db=service.database, branch=obj)
ipam_node_details = await diff_parser.get_changed_ipam_node_details(
source_branch_name=obj.name,
target_branch_name=registry.default_branch,
)
if ipam_node_details:
await service.workflow.submit_workflow(
workflow=IPAM_RECONCILIATION, parameters={"branch": obj.name, "ipam_node_details": ipam_node_details}
# -------------------------------------------------------------
# Trigger the reconciliation of IPAM data after the rebase
# -------------------------------------------------------------
diff_parser = await component_registry.get_component(IpamDiffParser, db=db, branch=obj)
ipam_node_details = await diff_parser.get_changed_ipam_node_details(
source_branch_name=obj.name,
target_branch_name=registry.default_branch,
)
if ipam_node_details:
await service.workflow.submit_workflow(
workflow=IPAM_RECONCILIATION, parameters={"branch": obj.name, "ipam_node_details": ipam_node_details}
)

await service.workflow.submit_workflow(workflow=DIFF_REFRESH_ALL, parameters={"branch_name": obj.name})

@@ -148,12 +148,12 @@ async def rebase_branch(branch: str) -> None:
@flow(name="branch-merge", flow_run_name="Merge branch {branch} into main")
async def merge_branch(branch: str) -> None:
service = services.service
log = get_run_logger()
async with service.database.start_session() as db:
log = get_run_logger()

await add_branch_tag(branch_name=branch)
await add_branch_tag(branch_name=registry.default_branch)
await add_branch_tag(branch_name=branch)
await add_branch_tag(branch_name=registry.default_branch)

async with service.database.start_session() as db:
obj = await Branch.get_by_name(db=db, name=branch)
component_registry = get_component_registry()

@@ -194,7 +194,7 @@ async def merge_branch(branch: str) -> None:
# -------------------------------------------------------------
# Trigger the reconciliation of IPAM data after the merge
# -------------------------------------------------------------
diff_parser = await component_registry.get_component(IpamDiffParser, db=service.database, branch=obj)
diff_parser = await component_registry.get_component(IpamDiffParser, db=db, branch=obj)
ipam_node_details = await diff_parser.get_changed_ipam_node_details(
source_branch_name=obj.name,
target_branch_name=registry.default_branch,
@@ -207,7 +207,7 @@ async def merge_branch(branch: str) -> None:
# -------------------------------------------------------------
# remove tracking ID from the diff because there is no diff after the merge
# -------------------------------------------------------------
diff_repository = await component_registry.get_component(DiffRepository, db=service.database, branch=obj)
diff_repository = await component_registry.get_component(DiffRepository, db=db, branch=obj)
await diff_repository.drop_tracking_ids(tracking_ids=[BranchTrackingId(name=obj.name)])

# -------------------------------------------------------------
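The bulk of this file's change is mechanical: the body of each Prefect flow now runs inside a single async with service.database.start_session() as db: block, and every component and query receives that db handle instead of reaching for service.database repeatedly. A self-contained toy of the pattern; the ToyDatabase class and the steps are invented, only the session-scoping idea mirrors the diff:

import asyncio
from contextlib import asynccontextmanager


class ToyDatabase:
    """Stand-in for service.database: hands out sessions scoped to a context manager."""

    @asynccontextmanager
    async def start_session(self):
        session = {"open": True}
        try:
            yield session            # every step of the flow reuses this single session
        finally:
            session["open"] = False  # released exactly once, when the flow body exits


async def toy_step(name: str, session: dict) -> None:
    assert session["open"], "the session must stay open for the whole flow run"
    print(f"{name}: using the shared session")


async def toy_flow(db: ToyDatabase) -> None:
    # Mirrors the refactor: one session for the whole flow, passed down explicitly.
    async with db.start_session() as session:
        await toy_step("load branch", session)
        await toy_step("apply migrations", session)


asyncio.run(toy_flow(ToyDatabase()))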
35 changes: 18 additions & 17 deletions backend/infrahub/core/ipam/tasks.py
@@ -22,20 +22,21 @@
)
async def ipam_reconciliation(branch: str, ipam_node_details: list[IpamNodeDetails]) -> None:
service = services.service
branch_obj = await registry.get_branch(db=service.database, branch=branch)

await add_branch_tag(branch_name=branch_obj.name)

ipam_reconciler = IpamReconciler(db=service.database, branch=branch_obj)

for ipam_node_detail_item in ipam_node_details:
if ipam_node_detail_item.is_address:
ip_value: AllIPTypes = ipaddress.ip_interface(ipam_node_detail_item.ip_value)
else:
ip_value = ipaddress.ip_network(ipam_node_detail_item.ip_value)
await ipam_reconciler.reconcile(
ip_value=ip_value,
namespace=ipam_node_detail_item.namespace_id,
node_uuid=ipam_node_detail_item.node_uuid,
is_delete=ipam_node_detail_item.is_delete,
)
async with service.database.start_session() as db:
branch_obj = await registry.get_branch(db=db, branch=branch)

await add_branch_tag(branch_name=branch_obj.name)

ipam_reconciler = IpamReconciler(db=db, branch=branch_obj)

for ipam_node_detail_item in ipam_node_details:
if ipam_node_detail_item.is_address:
ip_value: AllIPTypes = ipaddress.ip_interface(ipam_node_detail_item.ip_value)
else:
ip_value = ipaddress.ip_network(ipam_node_detail_item.ip_value)
await ipam_reconciler.reconcile(
ip_value=ip_value,
namespace=ipam_node_detail_item.namespace_id,
node_uuid=ipam_node_detail_item.node_uuid,
is_delete=ipam_node_detail_item.is_delete,
)
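Beyond the same session-scoping change, the reconciliation loop distinguishes addresses from prefixes before handing them to the reconciler. A short standalone illustration of that branch; only the ipaddress calls come from the diff, the helper itself is made up:

import ipaddress


def parse_ipam_value(ip_value: str, is_address: bool):
    # Addresses keep their prefix length as ip_interface objects; prefixes become ip_network objects.
    return ipaddress.ip_interface(ip_value) if is_address else ipaddress.ip_network(ip_value)


print(parse_ipam_value("10.0.0.5/24", is_address=True))    # 10.0.0.5/24 (IPv4Interface)
print(parse_ipam_value("10.0.0.0/24", is_address=False))   # 10.0.0.0/24 (IPv4Network)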