Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[qob] Add ability to attach to existing batch #14829

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions hail/python/hail/backend/service_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,7 @@ async def create(
driver_memory: Optional[str] = None,
worker_cores: Optional[Union[int, str]] = None,
worker_memory: Optional[str] = None,
batch_id: Optional[int] = None,
name_prefix: Optional[str] = None,
credentials_token: Optional[str] = None,
regions: Optional[List[str]] = None,
Expand All @@ -197,7 +198,7 @@ async def create(
if batch_client is None:
batch_client = await BatchClient.create(billing_project, _token=credentials_token)
async_exit_stack.push_async_callback(batch_client.close)
batch_attributes: Dict[str, str] = dict()
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This code was unused.


remote_tmpdir = get_remote_tmpdir('ServiceBackend', remote_tmpdir=remote_tmpdir)

name_prefix = configuration_of(ConfigVariable.QUERY_NAME_PREFIX, name_prefix, '')
Expand Down Expand Up @@ -244,6 +245,11 @@ async def create(
flags['gcs_requester_pays_project'] = gcs_requester_pays_configuration[0]
flags['gcs_requester_pays_buckets'] = ','.join(gcs_requester_pays_configuration[1])

if batch_id is not None:
batch = await batch_client.get_batch(batch_id)
else:
batch = None

sb = ServiceBackend(
billing_project=billing_project,
sync_fs=sync_fs,
Expand All @@ -256,6 +262,7 @@ async def create(
driver_memory=driver_memory,
worker_cores=worker_cores,
worker_memory=worker_memory,
batch=batch,
regions=regions,
async_exit_stack=async_exit_stack,
)
Expand All @@ -276,6 +283,7 @@ def __init__(
driver_memory: Optional[str],
worker_cores: Optional[Union[int, str]],
worker_memory: Optional[str],
batch: Optional[Batch],
regions: List[str],
async_exit_stack: AsyncExitStack,
):
Expand All @@ -297,7 +305,7 @@ def __init__(
self.worker_memory = worker_memory
self.regions = regions

self._batch: Batch = self._create_batch()
self._batch: Batch = self._create_batch() if batch is None else batch
self._async_exit_stack = async_exit_stack

def _create_batch(self) -> Batch:
Expand Down
8 changes: 8 additions & 0 deletions hail/python/hail/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,7 @@ def stop(self):
driver_memory=nullable(str),
worker_cores=nullable(oneof(str, int)),
worker_memory=nullable(str),
batch_id=nullable(int),
gcs_requester_pays_configuration=nullable(oneof(str, sized_tupleof(str, sequenceof(str)))),
regions=nullable(sequenceof(str)),
gcs_bucket_allow_list=nullable(dictof(str, sequenceof(str))),
Expand Down Expand Up @@ -209,6 +210,7 @@ def init(
driver_memory=None,
worker_cores=None,
worker_memory=None,
batch_id=None,
gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None,
regions: Optional[List[str]] = None,
gcs_bucket_allow_list: Optional[Dict[str, List[str]]] = None,
Expand Down Expand Up @@ -322,6 +324,8 @@ def init(
worker_memory : :class:`str`, optional
Batch backend only. Memory tier to use for the worker processes. May be standard or
highmem. Default is standard.
batch_id : :class:`int`, optional
Batch backend only. An existing batch id to add jobs to.
gcs_requester_pays_configuration : either :class:`str` or :class:`tuple` of :class:`str` and :class:`list` of :class:`str`, optional
If a string is provided, configure the Google Cloud Storage file system to bill usage to the
project identified by that string. If a tuple is provided, configure the Google Cloud
Expand Down Expand Up @@ -379,6 +383,7 @@ def init(
driver_memory=driver_memory,
worker_cores=worker_cores,
worker_memory=worker_memory,
batch_id=batch_id,
name_prefix=app_name,
gcs_requester_pays_configuration=gcs_requester_pays_configuration,
regions=regions,
Expand Down Expand Up @@ -523,6 +528,7 @@ def init_spark(
driver_memory=nullable(str),
worker_cores=nullable(oneof(str, int)),
worker_memory=nullable(str),
batch_id=nullable(int),
name_prefix=nullable(str),
token=nullable(str),
gcs_requester_pays_configuration=nullable(oneof(str, sized_tupleof(str, sequenceof(str)))),
Expand All @@ -545,6 +551,7 @@ async def init_batch(
driver_memory: Optional[str] = None,
worker_cores: Optional[Union[str, int]] = None,
worker_memory: Optional[str] = None,
batch_id: Optional[int] = None,
name_prefix: Optional[str] = None,
token: Optional[str] = None,
gcs_requester_pays_configuration: Optional[GCSRequesterPaysConfiguration] = None,
Expand All @@ -562,6 +569,7 @@ async def init_batch(
driver_memory=driver_memory,
worker_cores=worker_cores,
worker_memory=worker_memory,
batch_id=batch_id,
name_prefix=name_prefix,
credentials_token=token,
regions=regions,
Expand Down
22 changes: 21 additions & 1 deletion hail/python/test/hail/backend/test_service_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import hail as hl
from hail.backend.service_backend import ServiceBackend
from hailtop.batch_client.client import Batch, Job, JobGroup
from hailtop.batch_client.client import Batch, BatchClient, Job, JobGroup


@dataclass
Expand Down Expand Up @@ -129,3 +129,23 @@ def test_driver_and_worker_job_groups():
assert len(worker_jobs) == n_partitions
for i, partition in enumerate(worker_jobs):
assert partition['name'] == f'execute(...)_stage0_table_force_count_job{i}'


@pytest.mark.backend('batch')
def test_attach_to_existing_batch():
    """Verify that ``hl.init(batch_id=...)`` attaches query jobs to a pre-existing batch."""
    current = hl.current_backend()
    assert isinstance(current, ServiceBackend)

    # Create and submit an empty batch out-of-band, remembering its id.
    sync_client = BatchClient.from_async(current._batch_client)
    created = sync_client.create_batch()
    created.submit()
    existing_id = created.id

    # Re-initialize Hail, pointing the batch backend at the batch just created.
    hl.stop()
    hl.init(backend='batch', batch_id=existing_id)

    # Run a trivial query so the backend submits driver/worker jobs.
    hl.utils.range_table(2)._force_count()

    # The previously-empty batch should now contain the query's jobs.
    attached = sync_client.get_batch(existing_id)
    assert attached.status()['n_jobs'] > 0, str(attached.debug_info())