Skip to content

Commit

Permalink
support IP Namespaces on branches (#5369)
Browse files Browse the repository at this point in the history
* support branch during ip_namespace validation

* add changelog
  • Loading branch information
ajtmccarty authored Jan 3, 2025
1 parent e536909 commit 145e69f
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 15 deletions.
21 changes: 13 additions & 8 deletions backend/infrahub/graphql/mutations/ipam.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,20 @@


async def validate_namespace(
db: InfrahubDatabase, data: InputObjectType, existing_namespace_id: Optional[str] = None
db: InfrahubDatabase,
branch: Branch | str | None,
data: InputObjectType,
existing_namespace_id: Optional[str] = None,
) -> str:
"""Validate or set (if not present) the namespace to pass to the mutation and return its ID."""
namespace_id: Optional[str] = None
if "ip_namespace" not in data or not data["ip_namespace"]:
namespace_id = existing_namespace_id or registry.default_ipnamespace
data["ip_namespace"] = {"id": namespace_id}
elif "id" in data["ip_namespace"]:
namespace = await registry.manager.get_one(db=db, kind=InfrahubKind.IPNAMESPACE, id=data["ip_namespace"]["id"])
namespace = await registry.manager.get_one(
db=db, branch=branch, kind=InfrahubKind.IPNAMESPACE, id=data["ip_namespace"]["id"]
)
namespace_id = namespace.id
else:
raise ValidationError(
Expand Down Expand Up @@ -130,7 +135,7 @@ async def mutate_create(
context: GraphqlContext = info.context
db = database or context.db
ip_address = ipaddress.ip_interface(data["address"]["value"])
namespace_id = await validate_namespace(db=db, data=data)
namespace_id = await validate_namespace(db=db, branch=branch, data=data)

async with db.start_transaction() as dbt:
if lock_name := cls._get_lock_name(namespace_id, branch):
Expand Down Expand Up @@ -186,7 +191,7 @@ async def mutate_update(
include_source=True,
)
namespace = await address.ip_namespace.get_peer(db)
namespace_id = await validate_namespace(db=db, data=data, existing_namespace_id=namespace.id)
namespace_id = await validate_namespace(db=db, branch=branch, data=data, existing_namespace_id=namespace.id)
try:
async with db.start_transaction() as dbt:
if lock_name := cls._get_lock_name(namespace_id, branch):
Expand Down Expand Up @@ -217,7 +222,7 @@ async def mutate_upsert(
context: GraphqlContext = info.context
db = database or context.db

await validate_namespace(db=db, data=data)
await validate_namespace(db=db, branch=branch, data=data)
prefix, result, created = await super().mutate_upsert(
info=info, data=data, branch=branch, node_getters=node_getters, database=db
)
Expand Down Expand Up @@ -283,7 +288,7 @@ async def mutate_create(
) -> tuple[Node, Self]:
context: GraphqlContext = info.context
db = database or context.db
namespace_id = await validate_namespace(db=db, data=data)
namespace_id = await validate_namespace(db=db, branch=branch, data=data)

async with db.start_transaction() as dbt:
if lock_name := cls._get_lock_name(namespace_id, branch):
Expand Down Expand Up @@ -337,7 +342,7 @@ async def mutate_update(
include_source=True,
)
namespace = await prefix.ip_namespace.get_peer(db)
namespace_id = await validate_namespace(db=db, data=data, existing_namespace_id=namespace.id)
namespace_id = await validate_namespace(db=db, branch=branch, data=data, existing_namespace_id=namespace.id)
try:
async with db.start_transaction() as dbt:
if lock_name := cls._get_lock_name(namespace_id, branch):
Expand Down Expand Up @@ -367,7 +372,7 @@ async def mutate_upsert(
context: GraphqlContext = info.context
db = database or context.db

await validate_namespace(db=db, data=data)
await validate_namespace(db=db, branch=branch, data=data)
prefix, result, created = await super().mutate_upsert(
info=info, data=data, branch=branch, node_getters=node_getters, database=db
)
Expand Down
16 changes: 9 additions & 7 deletions backend/tests/unit/graphql/mutations/test_ipam.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,12 +320,12 @@ async def test_ipprefix_create(

async def test_ipprefix_create_with_ipnamespace(
db: InfrahubDatabase,
default_branch: Branch,
default_ipnamespace: Node,
register_core_models_schema: SchemaBranch,
register_ipam_schema: SchemaBranch,
branch: Branch,
):
ns = await Node.init(db=db, schema=InfrahubKind.NAMESPACE, branch=default_branch)
ns = await Node.init(db=db, schema=InfrahubKind.NAMESPACE, branch=branch)
await ns.new(db=db, name="ns1")
await ns.save(db=db)

Expand All @@ -349,7 +349,7 @@ async def test_ipprefix_create_with_ipnamespace(
}
"""

gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=default_branch)
gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=branch)

supernet = ipaddress.ip_network("2001:db8::/32")
result = await graphql(
Expand All @@ -363,7 +363,9 @@ async def test_ipprefix_create_with_ipnamespace(
assert result.data["IpamIPPrefixCreate"]["ok"]
assert result.data["IpamIPPrefixCreate"]["object"]["id"]

ip_prefix = await registry.manager.get_one(id=result.data["IpamIPPrefixCreate"]["object"]["id"], db=db)
ip_prefix = await registry.manager.get_one(
id=result.data["IpamIPPrefixCreate"]["object"]["id"], db=db, branch=branch
)
ip_namespace = await ip_prefix.ip_namespace.get_peer(db=db)
assert ip_namespace.id == ns.id

Expand Down Expand Up @@ -448,17 +450,17 @@ async def test_ipprefix_update(

async def test_ipprefix_update_within_namespace(
db: InfrahubDatabase,
default_branch: Branch,
default_ipnamespace: Node,
register_core_models_schema: SchemaBranch,
register_ipam_schema: SchemaBranch,
branch: Branch,
):
"""Make sure a prefix can be updated within a namespace."""
test_ns = await Node.init(db=db, schema=InfrahubKind.NAMESPACE)
test_ns = await Node.init(db=db, branch=branch, schema=InfrahubKind.NAMESPACE)
await test_ns.new(db=db, name="test")
await test_ns.save(db=db)

gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=default_branch)
gql_params = prepare_graphql_params(db=db, include_subscription=False, branch=branch)

subnet = ipaddress.ip_network("2001:db8::/48")
result = await graphql(
Expand Down
1 change: 1 addition & 0 deletions changelog/+ipnamespace-branch.fixed.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed an issue that prevented using an IP Namespace on a branch

0 comments on commit 145e69f

Please sign in to comment.