diff --git a/backend/infrahub/cli/__init__.py b/backend/infrahub/cli/__init__.py index 5320e89803..2e91fcd3ad 100644 --- a/backend/infrahub/cli/__init__.py +++ b/backend/infrahub/cli/__init__.py @@ -20,7 +20,7 @@ @app.callback() def common(ctx: typer.Context) -> None: """Infrahub CLI""" - ctx.obj = CliContext(database_class=InfrahubDatabase) + ctx.obj = CliContext() app.add_typer(server_app, name="server") diff --git a/backend/infrahub/cli/context.py b/backend/infrahub/cli/context.py index 6f4e6a83e5..4785166875 100644 --- a/backend/infrahub/cli/context.py +++ b/backend/infrahub/cli/context.py @@ -1,18 +1,15 @@ from __future__ import annotations from dataclasses import dataclass -from typing import TYPE_CHECKING -from infrahub.database import get_db - -if TYPE_CHECKING: - from infrahub.database import InfrahubDatabase +from infrahub.database import InfrahubDatabase, get_db @dataclass class CliContext: - database_class: type[InfrahubDatabase] application: str = "infrahub.server:app" - async def get_db(self, retry: int = 0) -> InfrahubDatabase: - return self.database_class(driver=await get_db(retry=retry)) + # This method is inherited for Infrahub Enterprise. + @staticmethod + async def init_db(retry: int) -> InfrahubDatabase: + return InfrahubDatabase(driver=await get_db(retry=retry)) diff --git a/backend/infrahub/cli/db.py b/backend/infrahub/cli/db.py index 6712cec8f8..c970eab456 100644 --- a/backend/infrahub/cli/db.py +++ b/backend/infrahub/cli/db.py @@ -97,7 +97,7 @@ async def init( config.load_and_exit(config_file_name=config_file) context: CliContext = ctx.obj - dbdriver = await context.get_db(retry=1) + dbdriver = await context.init_db(retry=1) async with dbdriver.start_transaction() as db: log.info("Delete All Nodes") await delete_all_nodes(db=db) @@ -120,7 +120,7 @@ async def load_test_data( config.load_and_exit(config_file_name=config_file) context: CliContext = ctx.obj - dbdriver = await context.get_db(retry=1) + dbdriver = await context.init_db(retry=1) async with dbdriver.start_session() as db: await initialization(db=db) @@ -152,7 +152,7 @@ async def migrate( config.load_and_exit(config_file_name=config_file) context: CliContext = ctx.obj - dbdriver = await context.get_db(retry=1) + dbdriver = await context.init_db(retry=1) async with dbdriver.start_session() as db: rprint("Checking current state of the Database") @@ -207,7 +207,7 @@ async def update_core_schema( # pylint: disable=too-many-statements config.load_and_exit(config_file_name=config_file) context: CliContext = ctx.obj - dbdriver = await context.get_db(retry=1) + dbdriver = await context.init_db(retry=1) error_badge = "[bold red]ERROR[/bold red]" @@ -332,7 +332,7 @@ async def constraint( config.load_and_exit(config_file_name=config_file) context: CliContext = ctx.obj - dbdriver = await context.get_db(retry=1) + dbdriver = await context.init_db(retry=1) manager: Optional[ConstraintManagerBase] = None if dbdriver.db_type == DatabaseType.NEO4J: @@ -376,7 +376,7 @@ async def index( config.load_and_exit(config_file_name=config_file) context: CliContext = ctx.obj - dbdriver = await context.get_db(retry=1) + dbdriver = await context.init_db(retry=1) dbdriver.manager.index.init(nodes=node_indexes, rels=rel_indexes) if action == IndexAction.ADD: diff --git a/backend/infrahub/cli/git_agent.py b/backend/infrahub/cli/git_agent.py index e10cf8fd79..d68fc303b2 100644 --- a/backend/infrahub/cli/git_agent.py +++ b/backend/infrahub/cli/git_agent.py @@ -111,7 +111,7 @@ async def start( exporter_protocol=config.SETTINGS.trace.exporter_protocol, ) - database = await context.get_db(retry=1) + database = await context.init_db(retry=1) workflow = config.OVERRIDE.workflow or ( WorkflowWorkerExecution()