Skip to content

Commit

Permalink
Merge pull request #240 from ecmwf-projects/COPDS-2419-filter-process…
Browse files Browse the repository at this point in the history
…es-by-portal

Filter processes by portals in all endpoints
  • Loading branch information
mcucchi9 authored Feb 7, 2025
2 parents 9e8d73b + 59c8d52 commit 8829abe
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 26 deletions.
5 changes: 3 additions & 2 deletions cads_processing_api_service/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import fastapi
import requests

from . import config, costing, exceptions, models
from . import config, costing, exceptions, models, utils

SETTINGS = config.settings

Expand Down Expand Up @@ -156,13 +156,14 @@ def get_auth_info(
auth_header = get_auth_header(pat, jwt)
user_uid, user_role, email = authenticate_user(auth_header, portal_header)
request_origin = REQUEST_ORIGIN[auth_header[0]]
portals = utils.get_portals(portal_header)
auth_info = models.AuthInfo(
user_uid=user_uid,
user_role=user_role,
email=email,
request_origin=request_origin,
auth_header=auth_header,
portal_header=portal_header,
portals=portals,
)
return auth_info

Expand Down
30 changes: 12 additions & 18 deletions cads_processing_api_service/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ def get_processes(
),
cursor: str | None = fastapi.Query(None, include_in_schema=False),
back: bool | None = fastapi.Query(None, include_in_schema=False),
portals: tuple[str] | None = fastapi.Depends(utils.get_portals),
) -> ogc_api_processes_fastapi.models.ProcessList:
"""Implement OGC API - Processes `GET /processes` endpoint.
Expand All @@ -108,9 +109,13 @@ def get_processes(
Hash string representing the reference to a particular process, used for pagination.
back : bool | None, optional
Specifies in which sense the list of processes should be traversed, used for pagination.
portals: tuple[str] | None
Portals
"""
statement = sqlalchemy.select(self.process_table)
sort_key, sort_dir = utils.parse_sortby(sortby.name)
if portals:
statement = statement.filter(self.process_table.portal.in_(portals))
if cursor:
statement = utils.apply_bookmark(
statement, self.process_table, cursor, back, sort_key, sort_dir
Expand Down Expand Up @@ -144,6 +149,7 @@ def get_process(
self,
response: fastapi.Response,
process_id: str = fastapi.Path(...),
portals: tuple[str] | None = fastapi.Depends(utils.get_portals),
) -> ogc_api_processes_fastapi.models.ProcessDescription:
"""Implement OGC API - Processes `GET /processes/{process_id}` endpoint.
Expand All @@ -155,6 +161,8 @@ def get_process(
fastapi.Response object.
process_id : str
Process identifier.
portals: tuple[str] | None
Portals
Returns
-------
Expand All @@ -169,6 +177,7 @@ def get_process(
resource_id=process_id,
table=self.process_table,
session=catalogue_session,
portals=portals,
)
process_description = serializers.serialize_process_description(resource)
process_description.outputs = {
Expand Down Expand Up @@ -221,11 +230,6 @@ def post_process_execution(
auth_info,
)
request_body = execution_content.model_dump()
portals = (
[p.strip() for p in auth_info.portal_header.split(",")]
if auth_info.portal_header
else None
)
catalogue_sessionmaker = db_utils.get_catalogue_sessionmaker(
db_utils.ConnectionMode.read
)
Expand All @@ -235,7 +239,7 @@ def post_process_execution(
table=self.process_table,
session=catalogue_session,
load_messages=True,
portals=tuple(portals),
portals=auth_info.portals,
)
auth.verify_if_disabled(dataset.disabled_reason, auth_info.user_role)
adaptor_properties = adaptors.get_adaptor_properties(dataset)
Expand Down Expand Up @@ -375,16 +379,11 @@ def get_jobs(
SETTINGS.rate_limits.jobs.get,
auth_info,
)
portals = (
[p.strip() for p in auth_info.portal_header.split(",")]
if auth_info.portal_header
else None
)
job_filters = {
"process_id": processID,
"status": status,
"user_uid": [auth_info.user_uid],
"portal": portals,
"portal": auth_info.portals,
}
sort_key, sort_dir = utils.parse_sortby(sortby.name)
statement = sqlalchemy.select(self.job_table)
Expand Down Expand Up @@ -495,11 +494,6 @@ def get_job(
SETTINGS.rate_limits.job.get,
auth_info,
)
portals = (
[p.strip() for p in auth_info.portal_header.split(",")]
if auth_info.portal_header
else None
)
compute_connection_mode = (
db_utils.ConnectionMode.write
if auth_info.request_origin == "ui"
Expand Down Expand Up @@ -554,7 +548,7 @@ def get_job(
"user_visible_log",
log_start_time,
)
if job.portal not in portals:
if job.portal not in auth_info.portals:
raise ogc_api_processes_fastapi.exceptions.NoSuchJob(
detail=f"job {job_id} not found"
)
Expand Down
6 changes: 5 additions & 1 deletion cads_processing_api_service/constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
def apply_constraints(
process_id: str = fastapi.Path(...),
execution_content: models.Execute = fastapi.Body(...),
portals: tuple[str] | None = fastapi.Depends(utils.get_portals),
) -> dict[str, Any]:
request = execution_content.model_dump()
table = cads_catalogue.database.Resource
Expand All @@ -21,7 +22,10 @@ def apply_constraints(
)
with catalogue_sessionmaker() as catalogue_session:
dataset = utils.lookup_resource_by_id(
resource_id=process_id, table=table, session=catalogue_session
resource_id=process_id,
table=table,
session=catalogue_session,
portals=portals,
)
adaptor: cads_adaptors.AbstractAdaptor = adaptors.instantiate_adaptor(dataset)
try:
Expand Down
12 changes: 8 additions & 4 deletions cads_processing_api_service/costing.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import cads_catalogue
import fastapi

from . import adaptors, costing, db_utils, exceptions, models, utils
from . import adaptors, db_utils, exceptions, models, utils

COST_THRESHOLDS = {"api": "max_costs", "ui": "max_costs_portal"}

Expand All @@ -38,6 +38,7 @@ def estimate_cost(
request_origin: RequestOrigin = fastapi.Query("api"),
mandatory_inputs: bool = fastapi.Query(False),
execution_content: models.Execute = fastapi.Body(...),
portals: tuple[str] | None = fastapi.Depends(utils.get_portals),
) -> models.RequestCost:
"""
Estimate the cost with the highest cost/limit ratio of the request.
Expand All @@ -61,18 +62,21 @@ def estimate_cost(
)
with catalogue_sessionmaker() as catalogue_session:
dataset = utils.lookup_resource_by_id(
resource_id=process_id, table=table, session=catalogue_session
resource_id=process_id,
table=table,
session=catalogue_session,
portals=portals,
)
adaptor_properties = adaptors.get_adaptor_properties(dataset)
request_is_valid = check_request_validity(
request=request,
mandatory_inputs=mandatory_inputs,
adaptor_properties=adaptor_properties,
)
costing_info = costing.compute_costing(
costing_info = compute_costing(
request.get("inputs", {}), adaptor_properties, request_origin
)
cost = costing.compute_highest_cost_limit_ratio(costing_info)
cost = compute_highest_cost_limit_ratio(costing_info)
if costing_info.cost_bar_steps:
cost.cost_bar_steps = costing_info.cost_bar_steps
costing_info.request_is_valid = request_is_valid
Expand Down
2 changes: 1 addition & 1 deletion cads_processing_api_service/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class AuthInfo(pydantic.BaseModel):
email: str | None = None
request_origin: str
auth_header: tuple[str, str]
portal_header: str | None = None
portals: tuple[str, ...] | None = None


class StatusCode(str, enum.Enum):
Expand Down
21 changes: 21 additions & 0 deletions cads_processing_api_service/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -659,3 +659,24 @@ def make_status_info(
log=log,
)
return status_info


def get_portals(
portal_header: str | None = fastapi.Header(None, alias=SETTINGS.portal_header_name),
) -> tuple[str, ...] | None:
"""Get the list of portals from the incoming HTTP request's header.
Parameters
----------
portal_header : str | None, optional
Portal header
Returns
-------
tuple[str] | None
List of portals.
"""
portals = (
tuple([p.strip() for p in portal_header.split(",")]) if portal_header else None
)
return portals
7 changes: 7 additions & 0 deletions tests/test_30_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,10 @@ def test_make_status_info() -> None:
metadata={"origin": "api"},
)
assert status_info == exp_status_info


def test_get_portals() -> None:
portal_header = "portal1,portal2"
result = utils.get_portals(portal_header)
expected = ("portal1", "portal2")
assert result == expected

0 comments on commit 8829abe

Please sign in to comment.