[qob] Add ability to attach to existing batch
jigold committed Mar 4, 2025
1 parent 51ca2b1 commit 2937ed8
Showing 3 changed files with 39 additions and 3 deletions.
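
This commit lets a Query-on-Batch session submit its driver and worker jobs into an already-existing batch rather than always creating a new one. A minimal usage sketch, based on the test added in this commit; the batch id `12345` is a placeholder assumption:

```python
import hail as hl

# Attach this Query-on-Batch session to an already-existing batch instead of
# letting the backend create a fresh one.  12345 is a placeholder batch id.
hl.init(backend='batch', batch_id=12345)

# Query jobs are now added to batch 12345 rather than to a newly created batch.
hl.utils.range_table(2)._force_count()
```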
12 changes: 10 additions & 2 deletions hail/python/hail/backend/service_backend.py
@@ -171,6 +171,7 @@ async def create(
driver_memory: Optional[str] = None,
worker_cores: Optional[Union[int, str]] = None,
worker_memory: Optional[str] = None,
batch_id: Optional[int] = None,
name_prefix: Optional[str] = None,
credentials_token: Optional[str] = None,
regions: Optional[List[str]] = None,
@@ -197,7 +198,7 @@ async def create(
if batch_client is None:
batch_client = await BatchClient.create(billing_project, _token=credentials_token)
async_exit_stack.push_async_callback(batch_client.close)
batch_attributes: Dict[str, str] = dict()

remote_tmpdir = get_remote_tmpdir('ServiceBackend', remote_tmpdir=remote_tmpdir)

name_prefix = configuration_of(ConfigVariable.QUERY_NAME_PREFIX, name_prefix, '')
@@ -244,6 +245,11 @@ async def create(
flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration[0]
flags['gcs_requester_pays_buckets'] = ','.join(gcs_requester_pays_configuration[1])

if batch_id is not None:
batch = await batch_client.get_batch(batch_id)
else:
batch = None

sb = ServiceBackend(
billing_project=billing_project,
sync_fs=sync_fs,
@@ -256,6 +262,7 @@
driver_memory=driver_memory,
worker_cores=worker_cores,
worker_memory=worker_memory,
batch=batch,
regions=regions,
async_exit_stack=async_exit_stack,
)
@@ -276,6 +283,7 @@ def __init__(
driver_memory: Optional[str],
worker_cores: Optional[Union[int, str]],
worker_memory: Optional[str],
batch: Optional[Batch],
regions: List[str],
async_exit_stack: AsyncExitStack,
):
@@ -297,7 +305,7 @@ def __init__(
self.worker_memory = worker_memory
self.regions = regions

self._batch: Batch = self._create_batch()
self._batch: Batch = self._create_batch() if batch is None else batch
self._async_exit_stack = async_exit_stack

def _create_batch(self) -> Batch:
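Taken together, the `service_backend.py` hunks implement a simple attach-or-create rule: `create` resolves an existing batch when `batch_id` is given and passes it down, and the constructor only creates a new batch when none was supplied. A condensed, self-contained sketch of that flow (signatures trimmed to the relevant parameters; the real methods take many more arguments):

```python
# Condensed sketch of the attach-or-create flow added to ServiceBackend;
# the real signatures carry many more parameters.
from typing import Any, Optional


class ServiceBackendSketch:
    @staticmethod
    async def create(batch_client: Any, batch_id: Optional[int] = None) -> 'ServiceBackendSketch':
        # New in this commit: look up the existing batch when a batch_id is given.
        if batch_id is not None:
            batch = await batch_client.get_batch(batch_id)
        else:
            batch = None
        return ServiceBackendSketch(batch=batch)

    def __init__(self, *, batch: Optional[Any] = None):
        # Reuse the supplied batch; fall back to creating one only when absent.
        self._batch = self._create_batch() if batch is None else batch

    def _create_batch(self) -> Any:
        # Unchanged behavior: before this commit it was called unconditionally.
        ...
```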
8 changes: 8 additions & 0 deletions hail/python/hail/context.py
@@ -180,6 +180,7 @@ def stop(self):
driver_memory=nullable(str),
worker_cores=nullable(oneof(str, int)),
worker_memory=nullable(str),
batch_id=nullable(int),
gcs_requester_pays_configuration=nullable(oneof(str, sized_tupleof(str, sequenceof(str)))),
regions=nullable(sequenceof(str)),
gcs_bucket_allow_list=nullable(dictof(str, sequenceof(str))),
@@ -209,6 +210,7 @@ def init(
driver_memory=None,
worker_cores=None,
worker_memory=None,
batch_id=None,
gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None,
regions: Optional[List[str]] = None,
gcs_bucket_allow_list: Optional[Dict[str, List[str]]] = None,
@@ -322,6 +324,8 @@ def init(
worker_memory : :class:`str`, optional
Batch backend only. Memory tier to use for the worker processes. May be standard or
highmem. Default is standard.
batch_id : :class:`int`, optional
Batch backend only. An existing batch id to add jobs to.
gcs_requester_pays_configuration : either :class:`str` or :class:`tuple` of :class:`str` and :class:`list` of :class:`str`, optional
If a string is provided, configure the Google Cloud Storage file system to bill usage to the
project identified by that string. If a tuple is provided, configure the Google Cloud
@@ -379,6 +383,7 @@ def init(
driver_memory=driver_memory,
worker_cores=worker_cores,
worker_memory=worker_memory,
batch_id=batch_id,
name_prefix=app_name,
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
regions=regions,
@@ -523,6 +528,7 @@ def init_spark(
driver_memory=nullable(str),
worker_cores=nullable(oneof(str, int)),
worker_memory=nullable(str),
batch_id=nullable(int),
name_prefix=nullable(str),
token=nullable(str),
gcs_requester_pays_configuration=nullable(oneof(str, sized_tupleof(str, sequenceof(str)))),
@@ -545,6 +551,7 @@ async def init_batch(
driver_memory: Optional[str] = None,
worker_cores: Optional[Union[str, int]] = None,
worker_memory: Optional[str] = None,
batch_id: Optional[int] = None,
name_prefix: Optional[str] = None,
token: Optional[str] = None,
gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None,
@@ -562,6 +569,7 @@ async def init_batch(
driver_memory=driver_memory,
worker_cores=worker_cores,
worker_memory=worker_memory,
batch_id=batch_id,
name_prefix=name_prefix,
credentials_token=token,
regions=regions,
22 changes: 21 additions & 1 deletion hail/python/test/hail/backend/test_service_backend.py
@@ -8,7 +8,7 @@

import hail as hl
from hail.backend.service_backend import ServiceBackend
from hailtop.batch_client.client import Batch, Job, JobGroup
from hailtop.batch_client.client import Batch, BatchClient, Job, JobGroup


@dataclass
@@ -129,3 +129,23 @@ def test_driver_and_worker_job_groups():
assert len(worker_jobs) == n_partitions
for i, partition in enumerate(worker_jobs):
assert partition['name'] == f'execute(...)_stage0_table_force_count_job{i}'


@pytest.mark.backend('batch')
def test_attach_to_existing_batch():
backend = hl.current_backend()
assert isinstance(backend, ServiceBackend)

batch_client = BatchClient.from_async(backend._batch_client)
b = batch_client.create_batch()
b.submit()
batch_id = b.id

hl.stop()
hl.init(backend='batch', batch_id=batch_id)

hl.utils.range_table(2)._force_count()

b = batch_client.get_batch(batch_id)
status = b.status()
assert status['n_jobs'] > 0, str(b.debug_info())
