Skip to content

Commit

Permalink
Add remote client
Browse files Browse the repository at this point in the history
  • Loading branch information
gmertes committed Feb 20, 2024
1 parent f23b5b8 commit 859cc34
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 1 deletion.
26 changes: 25 additions & 1 deletion ai_models/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .inputs import available_inputs
from .model import Timer, available_models, load_model
from .outputs import available_outputs
from .remote import RemoteRunner

LOG = logging.getLogger(__name__)

Expand Down Expand Up @@ -220,6 +221,18 @@ def _main(argv):
help="The model to run",
)

parser.add_argument(
"--remote-url",
default=os.getenv("AI_MODELS_REMOTE_URL"),
help="Remote endpoint URL",
)

parser.add_argument(
"--remote-token",
default=os.getenv("AI_MODELS_REMOTE_TOKEN"),
help="Remote endpoint auth token",
)

args, unknownargs = parser.parse_known_args(argv)

if args.models:
Expand Down Expand Up @@ -259,7 +272,18 @@ def _main(argv):
"You need to specify --retrieve-requests or --archive-requests"
)

run(vars(args), unknownargs)
if args.remote_url is not None:
if args.remote_token is None:
parser.error("You need to specify --remote-token")

RemoteRunner(
url=args.remote_url,
token=args.remote_token,
input_file=args.file,
output_file=args.path,
).run(vars(args), unknownargs)
else:
run(vars(args), unknownargs)


def run(cfg: dict, model_args: list):
Expand Down
74 changes: 74 additions & 0 deletions ai_models/remote.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import logging
import time
from urllib.parse import urljoin

import requests
from multiurl import download, robust

LOG = logging.getLogger(__name__)


class BearerAuth(requests.auth.AuthBase):
def __init__(self, token):
self.token = token

def __call__(self, r):
r.headers["authorization"] = "Bearer " + self.token
return r


class RemoteRunner:
def __init__(self, url: str, token: str, output_file: str, input_file: str = None):
self.url = url
self.auth = BearerAuth(token)
self.output_file = output_file
self.input_file = input_file
self._timeout = 10

def run(self, cfg: dict, model_args: list):
cfg.pop("remote_url", None)
cfg.pop("remote_token", None)
cfg["model_args"] = model_args

r = self._submit(cfg)
status = self._last_status
LOG.debug(r)
LOG.info("Job status: %s", self._last_status)

while not self._ready():
if status != self._last_status:
status = self._last_status
LOG.info("Job status: %s", status)
time.sleep(5)

download(urljoin(self.url, self._href), target=self.output_file)

LOG.info("Result written to %s", self.output_file)

def _submit(self, data):
r = robust(requests.post, retry_after=self._timeout)(
urljoin(self.url, "submit"),
json=data,
auth=self.auth,
timeout=self._timeout,
)
res = r.json()
self._uid = res["id"]
self._href = res["href"]
self._last_status = res["status"]
return res

def _status(self):
r = robust(requests.get, retry_after=self._timeout)(
urljoin(self.url, self._href), auth=self.auth, timeout=self._timeout
)

res = r.json()
LOG.debug(res)
self._href = res["href"]
self._last_status = res["status"]

return self._last_status

def _ready(self):
return self._status().lower() == "ready"

0 comments on commit 859cc34

Please sign in to comment.