Skip to content

Commit

Permalink
Merge pull request #5477 from opsmill/lgu-community-cli-context
Browse files Browse the repository at this point in the history
Cleanup around CliContext
  • Loading branch information
LucasG0 authored Jan 15, 2025
2 parents 971fce8 + 3a2624b commit 9ec036d
Show file tree
Hide file tree
Showing 4 changed files with 13 additions and 16 deletions.
2 changes: 1 addition & 1 deletion backend/infrahub/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
@app.callback()
def common(ctx: typer.Context) -> None:
"""Infrahub CLI"""
ctx.obj = CliContext(database_class=InfrahubDatabase)
ctx.obj = CliContext()


app.add_typer(server_app, name="server")
Expand Down
13 changes: 5 additions & 8 deletions backend/infrahub/cli/context.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,15 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from infrahub.database import get_db

if TYPE_CHECKING:
from infrahub.database import InfrahubDatabase
from infrahub.database import InfrahubDatabase, get_db


@dataclass
class CliContext:
database_class: type[InfrahubDatabase]
application: str = "infrahub.server:app"

async def get_db(self, retry: int = 0) -> InfrahubDatabase:
return self.database_class(driver=await get_db(retry=retry))
# This method is inherited for Infrahub Enterprise.
@staticmethod
async def init_db(retry: int) -> InfrahubDatabase:
return InfrahubDatabase(driver=await get_db(retry=retry))
12 changes: 6 additions & 6 deletions backend/infrahub/cli/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ async def init(
config.load_and_exit(config_file_name=config_file)

context: CliContext = ctx.obj
dbdriver = await context.get_db(retry=1)
dbdriver = await context.init_db(retry=1)
async with dbdriver.start_transaction() as db:
log.info("Delete All Nodes")
await delete_all_nodes(db=db)
Expand All @@ -120,7 +120,7 @@ async def load_test_data(
config.load_and_exit(config_file_name=config_file)

context: CliContext = ctx.obj
dbdriver = await context.get_db(retry=1)
dbdriver = await context.init_db(retry=1)
async with dbdriver.start_session() as db:
await initialization(db=db)

Expand Down Expand Up @@ -152,7 +152,7 @@ async def migrate(
config.load_and_exit(config_file_name=config_file)

context: CliContext = ctx.obj
dbdriver = await context.get_db(retry=1)
dbdriver = await context.init_db(retry=1)
async with dbdriver.start_session() as db:
rprint("Checking current state of the Database")

Expand Down Expand Up @@ -207,7 +207,7 @@ async def update_core_schema( # pylint: disable=too-many-statements
config.load_and_exit(config_file_name=config_file)

context: CliContext = ctx.obj
dbdriver = await context.get_db(retry=1)
dbdriver = await context.init_db(retry=1)

error_badge = "[bold red]ERROR[/bold red]"

Expand Down Expand Up @@ -332,7 +332,7 @@ async def constraint(
config.load_and_exit(config_file_name=config_file)

context: CliContext = ctx.obj
dbdriver = await context.get_db(retry=1)
dbdriver = await context.init_db(retry=1)

manager: Optional[ConstraintManagerBase] = None
if dbdriver.db_type == DatabaseType.NEO4J:
Expand Down Expand Up @@ -376,7 +376,7 @@ async def index(
config.load_and_exit(config_file_name=config_file)

context: CliContext = ctx.obj
dbdriver = await context.get_db(retry=1)
dbdriver = await context.init_db(retry=1)
dbdriver.manager.index.init(nodes=node_indexes, rels=rel_indexes)

if action == IndexAction.ADD:
Expand Down
2 changes: 1 addition & 1 deletion backend/infrahub/cli/git_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ async def start(
exporter_protocol=config.SETTINGS.trace.exporter_protocol,
)

database = await context.get_db(retry=1)
database = await context.init_db(retry=1)

workflow = config.OVERRIDE.workflow or (
WorkflowWorkerExecution()
Expand Down

0 comments on commit 9ec036d

Please sign in to comment.