Skip to content

Commit

Permalink
Refactor model run to new function
Browse files Browse the repository at this point in the history
  • Loading branch information
gmertes committed Feb 15, 2024
1 parent 922b793 commit f23b5b8
Showing 1 changed file with 17 additions and 11 deletions.
28 changes: 17 additions & 11 deletions ai_models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,24 +253,30 @@ def _main(argv):
if args.class_ is not None:
args.metadata["class"] = args.class_

model = load_model(args.model, **vars(args), model_args=unknownargs)

if args.fields:
model.print_fields()
sys.exit(0)

if args.requests_extra:
if not args.retrieve_requests and not args.archive_requests:
parser.error(
"You need to specify --retrieve-requests or --archive-requests"
)

run(vars(args), unknownargs)


def run(cfg: dict, model_args: list):
model = load_model(cfg["model"], **cfg, model_args=model_args)

if cfg["fields"]:
model.print_fields()
sys.exit(0)

# This logic is a bit convoluted, but it is for backwards compatibility.
if args.retrieve_requests or (args.requests_extra and not args.archive_requests):
if cfg["retrieve_requests"] or (
cfg["requests_extra"] and not cfg["archive_requests"]
):
model.print_requests()
sys.exit(0)

if args.assets_list:
if cfg["assets_list"]:
model.print_assets_list()
sys.exit(0)

Expand All @@ -280,7 +286,7 @@ def _main(argv):
LOG.exception(e)
LOG.error(
"It is possible that some files requited by %s are missing.",
args.model,
cfg["model"],
)
LOG.error("Rerun the command as:")
LOG.error(
Expand All @@ -291,9 +297,9 @@ def _main(argv):

model.finalise()

if args.dump_provenance:
if cfg["dump_provenance"]:
with Timer("Collect provenance information"):
with open(args.dump_provenance, "w") as f:
with open(cfg["dump_provenance"], "w") as f:
prov = model.provenance()
import json # import here so it is not listed in provenance

Expand Down

0 comments on commit f23b5b8

Please sign in to comment.