From 78c1f1d9871801265841887c603756dfe1c43b74 Mon Sep 17 00:00:00 2001 From: Quigley Malcolm Date: Mon, 3 Mar 2025 15:21:24 -0600 Subject: [PATCH] Rewrite execution of microbatch models to avoid blocking the main thread (#11332) * Push orchestration of batches previously in the `RunTask` into `MicrobatchModelRunner` * Split `MicrobatchModelRunner` into two separate runners: `MicrobatchModelRunner` is now an orchestrator of `MicrobatchBatchRunner`s, the latter handling actual batch execution * Introduce new `DbtThreadPool` that knows if it's been closed * Enable `MicrobatchModelRunner` to shut down gracefully when it detects the thread pool has been closed --- .../unreleased/Fixes-20250303-131440.yaml | 6 + core/dbt/graph/thread_pool.py | 18 + .../incremental/microbatch.py | 9 +- core/dbt/task/build.py | 12 +- core/dbt/task/run.py | 691 ++++++++++-------- core/dbt/task/runnable.py | 6 +- .../functional/microbatch/test_microbatch.py | 38 +- .../incremental/test_microbatch.py | 12 +- tests/unit/task/test_run.py | 32 +- 9 files changed, 445 insertions(+), 379 deletions(-) create mode 100644 .changes/unreleased/Fixes-20250303-131440.yaml create mode 100644 core/dbt/graph/thread_pool.py diff --git a/.changes/unreleased/Fixes-20250303-131440.yaml b/.changes/unreleased/Fixes-20250303-131440.yaml new file mode 100644 index 00000000000..68b466a4936 --- /dev/null +++ b/.changes/unreleased/Fixes-20250303-131440.yaml @@ -0,0 +1,6 @@ +kind: Fixes +body: Fix microbatch execution so it neither blocks the main thread nor hangs +time: 2025-03-03T13:14:40.432874-06:00 +custom: + Author: QMalcolm + Issue: 11243 11306 diff --git a/core/dbt/graph/thread_pool.py b/core/dbt/graph/thread_pool.py new file mode 100644 index 00000000000..8a8fd755c7c --- /dev/null +++ b/core/dbt/graph/thread_pool.py @@ -0,0 +1,18 @@ +from __future__ import annotations + +from multiprocessing.pool import ThreadPool + + +class DbtThreadPool(ThreadPool): + """A ThreadPool that tracks whether or not it's been closed""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.closed = False + + def close(self): + self.closed = True + super().close() + + def is_closed(self): + return self.closed diff --git a/core/dbt/materializations/incremental/microbatch.py b/core/dbt/materializations/incremental/microbatch.py index 6de6945704c..24d5b9cec7a 100644 --- a/core/dbt/materializations/incremental/microbatch.py +++ b/core/dbt/materializations/incremental/microbatch.py @@ -100,7 +100,8 @@ def build_batches(self, start: datetime, end: datetime) -> List[BatchType]: return batches - def build_jinja_context_for_batch(self, incremental_batch: bool) -> Dict[str, Any]: + @staticmethod + def build_jinja_context_for_batch(model: ModelNode, incremental_batch: bool) -> Dict[str, Any]: """ Create context with entries that reflect microbatch model + incremental execution state @@ -109,9 +110,9 @@ def build_jinja_context_for_batch(self, incremental_batch: bool) -> Dict[str, An jinja_context: Dict[str, Any] = {} # Microbatch model properties - jinja_context["model"] = self.model.to_dict() - jinja_context["sql"] = self.model.compiled_code - jinja_context["compiled_code"] = self.model.compiled_code + jinja_context["model"] = model.to_dict() + jinja_context["sql"] = model.compiled_code + jinja_context["compiled_code"] = model.compiled_code # Add incremental context variables for batches running incrementally if incremental_batch: diff --git a/core/dbt/task/build.py b/core/dbt/task/build.py index ff68d976744..9940c1c5468 100644 ---
a/core/dbt/task/build.py +++ b/core/dbt/task/build.py @@ -169,7 +169,8 @@ def call_model_and_unit_tests_runner(self, node, pool) -> RunResult: runner.do_skip(cause=cause) if isinstance(runner, MicrobatchModelRunner): - return self.handle_microbatch_model(runner, pool) + runner.set_parent_task(self) + runner.set_pool(pool) return self.call_runner(runner) @@ -184,10 +185,11 @@ def handle_job_queue_node(self, node, pool, callback): runner.do_skip(cause=cause) if isinstance(runner, MicrobatchModelRunner): - callback(self.handle_microbatch_model(runner, pool)) - else: - args = [runner] - self._submit(pool, args, callback) + runner.set_parent_task(self) + runner.set_pool(pool) + + args = [runner] + self._submit(pool, args, callback) # Make a map of model unique_ids to selected unit test unique_ids, # for processing before the model. diff --git a/core/dbt/task/run.py b/core/dbt/task/run.py index 44d45272cfb..67718318144 100644 --- a/core/dbt/task/run.py +++ b/core/dbt/task/run.py @@ -1,10 +1,11 @@ +from __future__ import annotations + import functools import threading import time from copy import deepcopy from dataclasses import asdict from datetime import datetime -from multiprocessing.pool import ThreadPool from typing import AbstractSet, Any, Dict, Iterable, List, Optional, Set, Tuple, Type from dbt import tracking, utils @@ -40,6 +41,7 @@ ) from dbt.exceptions import CompilationError, DbtInternalError, DbtRuntimeError from dbt.graph import ResourceTypeSelector +from dbt.graph.thread_pool import DbtThreadPool from dbt.hooks import get_hook_dict from dbt.materializations.incremental.microbatch import MicrobatchBuilder from dbt.node_types import NodeType, RunHookType @@ -332,80 +334,36 @@ def execute(self, model, manifest): return self._execute_model(hook_ctx, context_config, model, context, materialization_macro) -class MicrobatchModelRunner(ModelRunner): - def __init__(self, config, adapter, node, node_index: int, num_nodes: int): - super().__init__(config, adapter, node, node_index, num_nodes) - - self.batch_idx: Optional[int] = None - self.batches: Dict[int, BatchType] = {} - self.relation_exists: bool = False +class MicrobatchBatchRunner(ModelRunner): + """Handles the running of individual batches""" - def compile(self, manifest: Manifest): - if self.batch_idx is not None: - batch = self.batches[self.batch_idx] - - # LEGACY: Set start/end in context prior to re-compiling (Will be removed for 1.10+) - # TODO: REMOVE before 1.10 GA - self.node.config["__dbt_internal_microbatch_event_time_start"] = batch[0] - self.node.config["__dbt_internal_microbatch_event_time_end"] = batch[1] - # Create batch context on model node prior to re-compiling - self.node.batch = BatchContext( - id=MicrobatchBuilder.batch_id(batch[0], self.node.config.batch_size), - event_time_start=batch[0], - event_time_end=batch[1], - ) - # Recompile node to re-resolve refs with event time filters rendered, update context - self.compiler.compile_node( - self.node, - manifest, - {}, - split_suffix=MicrobatchBuilder.format_batch_start( - batch[0], self.node.config.batch_size - ), - ) - - # Skips compilation for non-batch runs - return self.node + def __init__( + self, + config, + adapter, + node, + node_index: int, + num_nodes: int, + batch_idx: int, + batches: Dict[int, BatchType], + relation_exists: bool, + incremental_batch: bool, + ): + super().__init__(config, adapter, node, node_index, num_nodes) - def set_batch_idx(self, batch_idx: int) -> None: self.batch_idx = batch_idx - - def set_relation_exists(self, relation_exists: 
bool) -> None: - self.relation_exists = relation_exists - - def set_batches(self, batches: Dict[int, BatchType]) -> None: self.batches = batches - - @property - def batch_start(self) -> Optional[datetime]: - if self.batch_idx is None: - return None - else: - return self.batches[self.batch_idx][0] - - def describe_node(self) -> str: - return f"{self.node.language} microbatch model {self.get_node_representation()}" + self.relation_exists = relation_exists + self.incremental_batch = incremental_batch def describe_batch(self) -> str: - batch_start = self.batch_start - if batch_start is None: - return "" - - # Only visualize date if batch_start year/month/day + batch_start = self.batches[self.batch_idx][0] formatted_batch_start = MicrobatchBuilder.format_batch_start( batch_start, self.node.config.batch_size ) return f"batch {formatted_batch_start} of {self.get_node_representation()}" - def print_batch_result_line( - self, - result: RunResult, - ): - if self.batch_idx is None: - return - - description = self.describe_batch() - group = group_lookup.get(self.node.unique_id) + def print_result_line(self, result: RunResult): if result.status == NodeStatus.Error: status = result.status level = EventLevel.ERROR @@ -415,96 +373,85 @@ def print_batch_result_line( else: status = result.message level = EventLevel.INFO + fire_event( LogBatchResult( - description=description, + description=self.describe_batch(), status=status, batch_index=self.batch_idx + 1, total_batches=len(self.batches), execution_time=result.execution_time, node_info=self.node.node_info, - group=group, + group=group_lookup.get(self.node.unique_id), ), level=level, ) - def print_batch_start_line(self) -> None: - if self.batch_idx is None: - return - - batch_start = self.batches[self.batch_idx][0] - if batch_start is None: - return - - batch_description = self.describe_batch() + def print_start_line(self) -> None: fire_event( LogStartBatch( - description=batch_description, + description=self.describe_batch(), batch_index=self.batch_idx + 1, total_batches=len(self.batches), node_info=self.node.node_info, ) ) - def before_execute(self) -> None: - if self.batch_idx is None: - self.print_start_line() + def should_run_in_parallel(self) -> bool: + if not self.adapter.supports(Capability.MicrobatchConcurrency): + run_in_parallel = False + elif not self.relation_exists: + # If the relation doesn't exist, we can't run in parallel + run_in_parallel = False + elif self.node.config.concurrent_batches is not None: + # If the relation exists and the `concurrent_batches` config isn't None, use the config value + run_in_parallel = self.node.config.concurrent_batches else: - self.print_batch_start_line() - - def after_execute(self, result) -> None: - if self.batch_idx is not None: - self.print_batch_result_line(result) - - def merge_batch_results(self, result: RunResult, batch_results: List[RunResult]): - """merge batch_results into result""" - if result.batch_results is None: - result.batch_results = BatchResults() - - for batch_result in batch_results: - if batch_result.batch_results is not None: - result.batch_results += batch_result.batch_results - result.execution_time += batch_result.execution_time + # If the relation exists, the `concurrent_batches` config is None, check if the model self references `this`. 
+ # If the model self references `this` then we assume the model batches _can't_ be run in parallel + run_in_parallel = not self.node.has_this - num_successes = len(result.batch_results.successful) - num_failures = len(result.batch_results.failed) - if num_failures == 0: - status = RunStatus.Success - msg = "SUCCESS" - elif num_successes == 0: - status = RunStatus.Error - msg = "ERROR" - else: - status = RunStatus.PartialSuccess - msg = f"PARTIAL SUCCESS ({num_successes}/{num_successes + num_failures})" - result.status = status - result.message = msg + return run_in_parallel - result.batch_results.successful = sorted(result.batch_results.successful) - result.batch_results.failed = sorted(result.batch_results.failed) + def on_skip(self): + result = RunResult( + node=self.node, + status=RunStatus.Skipped, + timing=[], + thread_id=threading.current_thread().name, + execution_time=0.0, + message="SKIPPED", + adapter_response={}, + failures=1, + batch_results=BatchResults(failed=[self.batches[self.batch_idx]]), + ) + self.print_result_line(result=result) + return result - # # If retrying, propagate previously successful batches into final result, even thoguh they were not run in this invocation - if self.node.previous_batch_results is not None: - result.batch_results.successful += self.node.previous_batch_results.successful + def compile(self, manifest: Manifest): + batch = self.batches[self.batch_idx] + + # LEGACY: Set start/end in context prior to re-compiling (Will be removed for 1.10+) + # TODO: REMOVE before 1.10 GA + self.node.config["__dbt_internal_microbatch_event_time_start"] = batch[0] + self.node.config["__dbt_internal_microbatch_event_time_end"] = batch[1] + # Create batch context on model node prior to re-compiling + self.node.batch = BatchContext( + id=MicrobatchBuilder.batch_id(batch[0], self.node.config.batch_size), + event_time_start=batch[0], + event_time_end=batch[1], + ) + # Recompile node to re-resolve refs with event time filters rendered, update context + self.compiler.compile_node( + self.node, + manifest, + {}, + split_suffix=MicrobatchBuilder.format_batch_start( + batch[0], self.node.config.batch_size + ), + ) - def on_skip(self): - # If node.batch is None, then we're dealing with skipping of the entire node - if self.batch_idx is None: - return super().on_skip() - else: - result = RunResult( - node=self.node, - status=RunStatus.Skipped, - timing=[], - thread_id=threading.current_thread().name, - execution_time=0.0, - message="SKIPPED", - adapter_response={}, - failures=1, - batch_results=BatchResults(failed=[self.batches[self.batch_idx]]), - ) - self.print_batch_result_line(result=result) - return result + return self.node def _build_succesful_run_batch_result( self, @@ -535,125 +482,124 @@ def _build_failed_run_batch_result( batch_results=BatchResults(failed=[batch]), ) - def _build_run_microbatch_model_result(self, model: ModelNode) -> RunResult: - return RunResult( - node=model, - status=RunStatus.Success, - timing=[], - thread_id=threading.current_thread().name, - # The execution_time here doesn't get propagated to logs because - # `safe_run_hooks` handles the elapsed time at the node level - execution_time=0, - message="", - adapter_response={}, - failures=0, - batch_results=BatchResults(), - ) - def _execute_microbatch_materialization( self, model: ModelNode, context: Dict[str, Any], materialization_macro: MacroProtocol, ) -> RunResult: - microbatch_builder = MicrobatchBuilder( - model=model, - is_incremental=self._is_incremental(model), - 
event_time_start=getattr(self.config.args, "EVENT_TIME_START", None), - event_time_end=getattr(self.config.args, "EVENT_TIME_END", None), - default_end_time=self.config.invoked_at, - ) - if self.batch_idx is None: - # Note currently (9/30/2024) model.previous_batch_results is only ever _not_ `None` - # IFF `dbt retry` is being run and the microbatch model had batches which - # failed on the run of the model (which is being retried) - if model.previous_batch_results is None: - end = microbatch_builder.build_end_time() - start = microbatch_builder.build_start_time(end) - batches = microbatch_builder.build_batches(start, end) - else: - batches = model.previous_batch_results.failed - # If there is batch info, then don't run as full_refresh and do force is_incremental - # not doing this risks blowing away the work that has already been done - if self._has_relation(model=model): - self.relation_exists = True + batch = self.batches[self.batch_idx] + # call materialization_macro to get a batch-level run result + start_time = time.perf_counter() + try: + # Update jinja context with batch context members + jinja_context = MicrobatchBuilder.build_jinja_context_for_batch( + model=model, + incremental_batch=self.incremental_batch, + ) + context.update(jinja_context) - batch_result = self._build_run_microbatch_model_result(model) - self.batches = {batch_idx: batches[batch_idx] for batch_idx in range(len(batches))} + # Materialize batch and cache any materialized relations + result = MacroGenerator( + materialization_macro, context, stack=context["context_macro_stack"] + )() + for relation in self._materialization_relations(result, model): + self.adapter.cache_added(relation.incorporate(dbt_created=True)) - else: - batch = self.batches[self.batch_idx] - # call materialization_macro to get a batch-level run result - start_time = time.perf_counter() - try: - # Update jinja context with batch context members - jinja_context = microbatch_builder.build_jinja_context_for_batch( - incremental_batch=self.relation_exists - ) - context.update(jinja_context) - - # Materialize batch and cache any materialized relations - result = MacroGenerator( - materialization_macro, context, stack=context["context_macro_stack"] - )() - for relation in self._materialization_relations(result, model): - self.adapter.cache_added(relation.incorporate(dbt_created=True)) - - # Build result of executed batch - batch_run_result = self._build_succesful_run_batch_result( - model, context, batch, time.perf_counter() - start_time - ) - batch_result = batch_run_result + # Build result of executed batch + batch_run_result = self._build_succesful_run_batch_result( + model, context, batch, time.perf_counter() - start_time + ) + batch_result = batch_run_result - # At least one batch has been inserted successfully! - # Can proceed incrementally + in parallel - self.relation_exists = True + # At least one batch has been inserted successfully! + # Can proceed incrementally + in parallel + self.relation_exists = True - except (KeyboardInterrupt, SystemExit): - # reraise it for GraphRunnableTask.execute_nodes to handle - raise - except Exception as e: - fire_event( - GenericExceptionOnRun( - unique_id=self.node.unique_id, - exc=f"Exception on worker thread. 
{str(e)}", - node_info=self.node.node_info, - ) - ) - batch_run_result = self._build_failed_run_batch_result( - model, batch, time.perf_counter() - start_time + except (KeyboardInterrupt, SystemExit): + # reraise it for GraphRunnableTask.execute_nodes to handle + raise + except Exception as e: + fire_event( + GenericExceptionOnRun( + unique_id=self.node.unique_id, + exc=f"Exception on worker thread. {str(e)}", + node_info=self.node.node_info, ) + ) + batch_run_result = self._build_failed_run_batch_result( + model, batch, time.perf_counter() - start_time + ) - batch_result = batch_run_result + batch_result = batch_run_result return batch_result - def _has_relation(self, model) -> bool: + def _execute_model( + self, + hook_ctx: Any, + context_config: Any, + model: ModelNode, + context: Dict[str, Any], + materialization_macro: MacroProtocol, + ) -> RunResult: + try: + batch_result = self._execute_microbatch_materialization( + model, context, materialization_macro + ) + finally: + self.adapter.post_model_hook(context_config, hook_ctx) + + return batch_result + + +class MicrobatchModelRunner(ModelRunner): + """Handles the orchestration of batches to run for a given microbatch model""" + + def __init__(self, config, adapter, node, node_index: int, num_nodes: int): + super().__init__(config, adapter, node, node_index, num_nodes) + + # The parent task is necessary because we need access to the `_submit_batch` and `submit` methods + self._parent_task: Optional[RunTask] = None + # The pool is necessary because we need to batches to be executed within the same thread pool + self._pool: Optional[DbtThreadPool] = None + + def set_parent_task(self, parent_task: RunTask) -> None: + self._parent_task = parent_task + + def set_pool(self, pool: DbtThreadPool) -> None: + self._pool = pool + + @property + def parent_task(self) -> RunTask: + if self._parent_task is None: + raise DbtInternalError( + msg="Tried to access `parent_task` of `MicrobatchModelRunner` before it was set" + ) + + return self._parent_task + + @property + def pool(self) -> DbtThreadPool: + if self._pool is None: + raise DbtInternalError( + msg="Tried to access `pool` of `MicrobatchModelRunner` before it was set" + ) + + return self._pool + + def _has_relation(self, model: ModelNode) -> bool: + """Check whether the relation for the model exists in the data warehouse""" relation_info = self.adapter.Relation.create_from(self.config, model) relation = self.adapter.get_relation( relation_info.database, relation_info.schema, relation_info.name ) return relation is not None - def should_run_in_parallel(self) -> bool: - if not self.adapter.supports(Capability.MicrobatchConcurrency): - run_in_parallel = False - elif not self.relation_exists: - # If the relation doesn't exist, we can't run in parallel - run_in_parallel = False - elif self.node.config.concurrent_batches is not None: - # If the relation exists and the `concurrent_batches` config isn't None, use the config value - run_in_parallel = self.node.config.concurrent_batches - else: - # If the relation exists, the `concurrent_batches` config is None, check if the model self references `this`. - # If the model self references `this` then we assume the model batches _can't_ be run in parallel - run_in_parallel = not self.node.has_this - - return run_in_parallel - def _is_incremental(self, model) -> bool: - # TODO: Remove. This is a temporary method. 
We're working with adapters on + """Check whether the model should be run `incrementally` or as `full refresh`""" + # TODO: Remove this whole function. This should be a temporary method. We're working with adapters on # a strategy to ensure we can access the `is_incremental` logic without drift relation_info = self.adapter.Relation.create_from(self.config, model) relation = self.adapter.get_relation( @@ -671,136 +617,230 @@ def _is_incremental(self, model) -> bool: else: return False - def _execute_model( - self, - hook_ctx: Any, - context_config: Any, - model: ModelNode, - context: Dict[str, Any], - materialization_macro: MacroProtocol, - ) -> RunResult: - try: - batch_result = self._execute_microbatch_materialization( - model, context, materialization_macro - ) - finally: - self.adapter.post_model_hook(context_config, hook_ctx) + def _initial_run_microbatch_model_result(self, model: ModelNode) -> RunResult: + return RunResult( + node=model, + status=RunStatus.Success, + timing=[], + thread_id=threading.current_thread().name, + # The execution_time here doesn't get propagated to logs because + # `safe_run_hooks` handles the elapsed time at the node level + execution_time=0, + message="", + adapter_response={}, + failures=0, + batch_results=BatchResults(), + ) - return batch_result + def describe_node(self) -> str: + return f"{self.node.language} microbatch model {self.get_node_representation()}" + def merge_batch_results(self, result: RunResult, batch_results: List[RunResult]): + """merge batch_results into result""" + if result.batch_results is None: + result.batch_results = BatchResults() -class RunTask(CompileTask): - def __init__( - self, - args: Flags, - config: RuntimeConfig, - manifest: Manifest, - batch_map: Optional[Dict[str, BatchResults]] = None, + for batch_result in batch_results: + if batch_result.batch_results is not None: + result.batch_results += batch_result.batch_results + result.execution_time += batch_result.execution_time + + num_successes = len(result.batch_results.successful) + num_failures = len(result.batch_results.failed) + if num_failures == 0: + status = RunStatus.Success + msg = "SUCCESS" + elif num_successes == 0: + status = RunStatus.Error + msg = "ERROR" + else: + status = RunStatus.PartialSuccess + msg = f"PARTIAL SUCCESS ({num_successes}/{num_successes + num_failures})" + result.status = status + result.message = msg + + result.batch_results.successful = sorted(result.batch_results.successful) + result.batch_results.failed = sorted(result.batch_results.failed) + + # # If retrying, propagate previously successful batches into final result, even thoguh they were not run in this invocation + if self.node.previous_batch_results is not None: + result.batch_results.successful += self.node.previous_batch_results.successful + + def _update_result_with_unfinished_batches( + self, result: RunResult, batches: Dict[int, BatchType] ) -> None: - super().__init__(args, config, manifest) - self.batch_map = batch_map + """This method is really only to be used when the execution of a microbatch model is halted before all batches have had a chance to run""" + batches_finished: Set[BatchType] = set() - def raise_on_first_error(self) -> bool: - return False + if result.batch_results: + # build list of finished batches + batches_finished = batches_finished.union(set(result.batch_results.successful)) + batches_finished = batches_finished.union(set(result.batch_results.failed)) + else: + # instantiate `batch_results` if it was `None` + result.batch_results = BatchResults() - 
def get_hook_sql(self, adapter, hook, idx, num_hooks, extra_context) -> str: - if self.manifest is None: - raise DbtInternalError("compile_node called before manifest was loaded") + # skipped batches are any batch that was expected but didn't finish + batches_expected = {batch for _, batch in batches.items()} + skipped_batches = batches_expected.difference(batches_finished) - compiled = self.compiler.compile_node(hook, self.manifest, extra_context) - statement = compiled.compiled_code - hook_index = hook.index or num_hooks - hook_obj = get_hook(statement, index=hook_index) - return hook_obj.sql or "" + result.batch_results.failed.extend(list(skipped_batches)) - def handle_job_queue(self, pool, callback): - node = self.job_queue.get() - self._raise_set_error() - runner = self.get_runner(node) - # we finally know what we're running! Make sure we haven't decided - # to skip it due to upstream failures - if runner.node.unique_id in self._skipped_children: - cause = self._skipped_children.pop(runner.node.unique_id) - runner.do_skip(cause=cause) + # We call this method, even though we are merging no new results, as it updates + # the result with the appropriate status (Success/Partial/Failed) + self.merge_batch_results(result, []) - if isinstance(runner, MicrobatchModelRunner): - callback(self.handle_microbatch_model(runner, pool)) + def get_microbatch_builder(self, model: ModelNode) -> MicrobatchBuilder: + return MicrobatchBuilder( + model=model, + is_incremental=self._is_incremental(model), + event_time_start=getattr(self.config.args, "EVENT_TIME_START", None), + event_time_end=getattr(self.config.args, "EVENT_TIME_END", None), + default_end_time=self.config.invoked_at, + ) + + def get_batches(self, model: ModelNode) -> Dict[int, BatchType]: + """Get the batches that should be run for the model""" + + # Note currently (02/23/2025) model.previous_batch_results is only ever _not_ `None` + # IFF `dbt retry` is being run and the microbatch model had batches which + # failed on the run of the model (which is being retried) + if model.previous_batch_results is None: + microbatch_builder = self.get_microbatch_builder(model) + end = microbatch_builder.build_end_time() + start = microbatch_builder.build_start_time(end) + batches = microbatch_builder.build_batches(start, end) else: - args = [runner] - self._submit(pool, args, callback) + batches = model.previous_batch_results.failed - def handle_microbatch_model( - self, - runner: MicrobatchModelRunner, - pool: ThreadPool, - ) -> RunResult: - # Initial run computes batch metadata - result = self.call_runner(runner) - batches, node, relation_exists = runner.batches, runner.node, runner.relation_exists + return {batch_idx: batches[batch_idx] for batch_idx in range(len(batches))} - # Return early if model should be skipped, or there are no batches to execute - if result.status == RunStatus.Skipped: - return result - elif len(runner.batches) == 0: + def compile(self, manifest: Manifest): + """Don't do anything here because this runner doesn't need to compile anything""" + return self.node + + def execute(self, model: ModelNode, manifest: Manifest) -> RunResult: + # Execution really means orchestration in this case + + batches = self.get_batches(model=model) + relation_exists = self._has_relation(model=model) + result = self._initial_run_microbatch_model_result(model=model) + + # No batches to run, so return initial result + if len(batches) == 0: return result batch_results: List[RunResult] = [] batch_idx = 0 # Run first batch not in parallel - relation_exists = self._submit_batch( - node=node, - adapter=runner.adapter, + relation_exists = self.parent_task._submit_batch( + node=model, + adapter=self.adapter, relation_exists=relation_exists, batches=batches, batch_idx=batch_idx, batch_results=batch_results, - pool=pool, + pool=self.pool, force_sequential_run=True, + incremental_batch=self._is_incremental(model=model), ) batch_idx += 1 skip_batches = batch_results[0].status != RunStatus.Success # Run all batches except first and last batch, in parallel if possible - while batch_idx < len(runner.batches) - 1: - relation_exists = self._submit_batch( - node=node, - adapter=runner.adapter, + while batch_idx < len(batches) - 1: + relation_exists = self.parent_task._submit_batch( + node=model, + adapter=self.adapter, relation_exists=relation_exists, batches=batches, batch_idx=batch_idx, batch_results=batch_results, - pool=pool, + pool=self.pool, skip=skip_batches, ) batch_idx += 1 # Wait until all submitted batches have completed while len(batch_results) != batch_idx: - pass + # Check if the pool was closed, because if it was, then the main thread is trying to exit. + # If the main thread is trying to exit, we need to shut down. If we _don't_ shut down, then + # batches will continue to execute and we'll delay the run from stopping + if self.pool.is_closed(): + # It's technically possible for more results to come in while we clean up; + # instead we're going to say they didn't finish, regardless of whether they finished + # or not. Thus, let's get a copy of the results as they exist right "now". + frozen_batch_results = deepcopy(batch_results) + self.merge_batch_results(result, frozen_batch_results) + self._update_result_with_unfinished_batches(result, batches) + return result + + # briefly sleep so that this thread doesn't spin while waiting + time.sleep(0.1) # Only run "last" batch if there is more than one batch if len(batches) != 1: # Final batch runs once all others complete to ensure post_hook runs at the end - self._submit_batch( - node=node, - adapter=runner.adapter, + self.parent_task._submit_batch( + node=model, + adapter=self.adapter, relation_exists=relation_exists, batches=batches, batch_idx=batch_idx, batch_results=batch_results, - pool=pool, + pool=self.pool, force_sequential_run=True, skip=skip_batches, ) # Finalize run: merge results, track model run, and print final result line - runner.merge_batch_results(result, batch_results) - track_model_run(runner.node_index, runner.num_nodes, result, adapter=runner.adapter) - runner.print_result_line(result) + self.merge_batch_results(result, batch_results) return result + +class RunTask(CompileTask): + def __init__( + self, + args: Flags, + config: RuntimeConfig, + manifest: Manifest, + batch_map: Optional[Dict[str, BatchResults]] = None, + ) -> None: + super().__init__(args, config, manifest) + self.batch_map = batch_map + + def raise_on_first_error(self) -> bool: + return False + + def get_hook_sql(self, adapter, hook, idx, num_hooks, extra_context) -> str: + if self.manifest is None: + raise DbtInternalError("compile_node called before manifest was loaded") + + compiled = self.compiler.compile_node(hook, self.manifest, extra_context) + statement = compiled.compiled_code + hook_index = hook.index or num_hooks + hook_obj = get_hook(statement, index=hook_index) + return hook_obj.sql or "" + + def handle_job_queue(self, pool, callback): + node = self.job_queue.get() + self._raise_set_error() + runner = self.get_runner(node) + # we finally know what we're running!
Make sure we haven't decided + # to skip it due to upstream failures + if runner.node.unique_id in self._skipped_children: + cause = self._skipped_children.pop(runner.node.unique_id) + runner.do_skip(cause=cause) + + if isinstance(runner, MicrobatchModelRunner): + runner.set_parent_task(self) + runner.set_pool(pool) + + args = [runner] + self._submit(pool, args, callback) + def _submit_batch( self, node: ModelNode, @@ -809,9 +849,10 @@ def _submit_batch( batches: Dict[int, BatchType], batch_idx: int, batch_results: List[RunResult], - pool: ThreadPool, + pool: DbtThreadPool, force_sequential_run: bool = False, skip: bool = False, + incremental_batch: bool = True, ): node_copy = deepcopy(node) # Only run pre_hook(s) for first batch @@ -825,31 +866,41 @@ def _submit_batch( # TODO: We should be doing self.get_runner, however doing so # currently causes the tracking of how many nodes there are to # increment when we don't want it to - batch_runner = MicrobatchModelRunner( - self.config, adapter, node_copy, self.run_count, self.num_nodes + batch_runner = MicrobatchBatchRunner( + self.config, + adapter, + node_copy, + self.run_count, + self.num_nodes, + batch_idx, + batches, + relation_exists, + incremental_batch, ) - batch_runner.set_batch_idx(batch_idx) - batch_runner.set_relation_exists(relation_exists) - batch_runner.set_batches(batches) if skip: batch_runner.do_skip() - if not force_sequential_run and batch_runner.should_run_in_parallel(): - fire_event( - MicrobatchExecutionDebug( - msg=f"{batch_runner.describe_batch()} is being run concurrently" + if not pool.is_closed(): + if not force_sequential_run and batch_runner.should_run_in_parallel(): + fire_event( + MicrobatchExecutionDebug( + msg=f"{batch_runner.describe_batch()} is being run concurrently" + ) ) - ) - self._submit(pool, [batch_runner], batch_results.append) - else: - fire_event( - MicrobatchExecutionDebug( - msg=f"{batch_runner.describe_batch()} is being run sequentially" + self._submit(pool, [batch_runner], batch_results.append) + else: + fire_event( + MicrobatchExecutionDebug( + msg=f"{batch_runner.describe_batch()} is being run sequentially" + ) ) + batch_results.append(self.call_runner(batch_runner)) + relation_exists = batch_runner.relation_exists + else: + batch_results.append( + batch_runner._build_failed_run_batch_result(node_copy, batches[batch_idx]) ) - batch_results.append(self.call_runner(batch_runner)) - relation_exists = batch_runner.relation_exists return relation_exists diff --git a/core/dbt/task/runnable.py b/core/dbt/task/runnable.py index 55342cafbbc..57e7e49d862 100644 --- a/core/dbt/task/runnable.py +++ b/core/dbt/task/runnable.py @@ -3,7 +3,6 @@ from abc import abstractmethod from concurrent.futures import as_completed from datetime import datetime -from multiprocessing.dummy import Pool as ThreadPool from pathlib import Path from typing import AbstractSet, Dict, Iterable, List, Optional, Set, Tuple, Type, Union @@ -48,6 +47,7 @@ UniqueId, parse_difference, ) +from dbt.graph.thread_pool import DbtThreadPool from dbt.parser.manifest import write_manifest from dbt.task.base import BaseRunner, ConfiguredTask from dbt_common.context import _INVOCATION_CONTEXT_VAR, get_invocation_context @@ -406,7 +406,9 @@ def _cancel_connections(self, pool): def execute_nodes(self): num_threads = self.config.threads - pool = ThreadPool(num_threads, self._pool_thread_initializer, [get_invocation_context()]) + pool = DbtThreadPool( + num_threads, self._pool_thread_initializer, [get_invocation_context()] + ) try: 
self.run_queue(pool) except FailFastError as failure: diff --git a/tests/functional/microbatch/test_microbatch.py b/tests/functional/microbatch/test_microbatch.py index fdbd0a219a5..d00f44a250b 100644 --- a/tests/functional/microbatch/test_microbatch.py +++ b/tests/functional/microbatch/test_microbatch.py @@ -5,8 +5,6 @@ from pytest_mock import MockerFixture from dbt.events.types import ( - ArtifactWritten, - EndOfRunSummary, GenericExceptionOnRun, InvalidConcurrentBatchesConfig, JinjaLogDebug, @@ -909,38 +907,6 @@ def test_microbatch(self, project) -> None: assert microbatch_model_last_batch == second_microbatch_model_last_batch -class TestMicrobatchModelStoppedByKeyboardInterrupt(BaseMicrobatchTest): - @pytest.fixture - def catch_eors(self) -> EventCatcher: - return EventCatcher(EndOfRunSummary) - - @pytest.fixture - def catch_aw(self) -> EventCatcher: - return EventCatcher( - event_to_catch=ArtifactWritten, - predicate=lambda event: event.data.artifact_type == "RunExecutionResult", - ) - - def test_microbatch( - self, - mocker: MockerFixture, - project, - catch_eors: EventCatcher, - catch_aw: EventCatcher, - ) -> None: - mocked_fbs = mocker.patch( - "dbt.materializations.incremental.microbatch.MicrobatchBuilder.format_batch_start" - ) - mocked_fbs.side_effect = KeyboardInterrupt - try: - run_dbt(["run"], callbacks=[catch_eors.catch, catch_aw.catch]) - assert False, "KeyboardInterrupt failed to stop batch execution" - except KeyboardInterrupt: - assert len(catch_eors.caught_events) == 1 - assert "Exited because of keyboard interrupt" in catch_eors.caught_events[0].info.msg - assert len(catch_aw.caught_events) == 1 - - class TestMicrobatchModelSkipped(BaseMicrobatchTest): @pytest.fixture(scope="class") def models(self): @@ -967,7 +933,7 @@ def batch_exc_catcher(self) -> EventCatcher: def test_microbatch( self, mocker: MockerFixture, project, batch_exc_catcher: EventCatcher ) -> None: - mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner.should_run_in_parallel") + mocked_srip = mocker.patch("dbt.task.run.MicrobatchBatchRunner.should_run_in_parallel") # Should be run in parallel mocked_srip.return_value = True @@ -1007,7 +973,7 @@ def batch_exc_catcher(self) -> EventCatcher: def test_microbatch( self, mocker: MockerFixture, project, batch_exc_catcher: EventCatcher ) -> None: - mocked_srip = mocker.patch("dbt.task.run.MicrobatchModelRunner.should_run_in_parallel") + mocked_srip = mocker.patch("dbt.task.run.MicrobatchBatchRunner.should_run_in_parallel") # Should be run in parallel mocked_srip.return_value = True diff --git a/tests/unit/materializations/incremental/test_microbatch.py b/tests/unit/materializations/incremental/test_microbatch.py index 3d827a79975..7ea2b986b1f 100644 --- a/tests/unit/materializations/incremental/test_microbatch.py +++ b/tests/unit/materializations/incremental/test_microbatch.py @@ -490,10 +490,10 @@ def test_build_batches(self, microbatch_model, start, end, batch_size, expected_ assert actual_batches == expected_batches def test_build_jinja_context_for_incremental_batch(self, microbatch_model): - microbatch_builder = MicrobatchBuilder( - model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None + context = MicrobatchBuilder.build_jinja_context_for_batch( + model=microbatch_model, + incremental_batch=True, ) - context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=True) assert context["model"] == microbatch_model.to_dict() assert context["sql"] == microbatch_model.compiled_code @@ -503,10 +503,10 @@ 
def test_build_jinja_context_for_incremental_batch(self, microbatch_model): assert context["should_full_refresh"]() is False def test_build_jinja_context_for_incremental_batch_false(self, microbatch_model): - microbatch_builder = MicrobatchBuilder( - model=microbatch_model, is_incremental=True, event_time_start=None, event_time_end=None + context = MicrobatchBuilder.build_jinja_context_for_batch( + model=microbatch_model, + incremental_batch=False, ) - context = microbatch_builder.build_jinja_context_for_batch(incremental_batch=False) assert context["model"] == microbatch_model.to_dict() assert context["sql"] == microbatch_model.compiled_code diff --git a/tests/unit/task/test_run.py b/tests/unit/task/test_run.py index b28ac505a7f..3378009fade 100644 --- a/tests/unit/task/test_run.py +++ b/tests/unit/task/test_run.py @@ -9,6 +9,7 @@ from psycopg2 import DatabaseError from pytest_mock import MockerFixture +from core.dbt.task.run import MicrobatchBatchRunner from dbt.adapters.contracts.connection import AdapterResponse from dbt.adapters.postgres import PostgresAdapter from dbt.artifacts.resources.base import FileHash @@ -174,6 +175,25 @@ def model_runner( num_nodes=1, ) + @pytest.fixture + def batch_runner( + self, + postgres_adapter: PostgresAdapter, + table_model: ModelNode, + runtime_config: RuntimeConfig, + ) -> MicrobatchBatchRunner: + return MicrobatchBatchRunner( + config=runtime_config, + adapter=postgres_adapter, + node=table_model, + node_index=1, + num_nodes=1, + batch_idx=0, + batches=[], + relation_exists=False, + incremental_batch=False, + ) + @pytest.mark.parametrize( "has_relation,relation_type,materialized,full_refresh_config,full_refresh_flag,expectation", [ @@ -267,22 +287,22 @@ class Relation: def test_should_run_in_parallel( self, mocker: MockerFixture, - model_runner: MicrobatchModelRunner, + batch_runner: MicrobatchBatchRunner, adapter_microbatch_concurrency: bool, has_relation: bool, concurrent_batches: Optional[bool], has_this: bool, expectation: bool, ) -> None: - model_runner.node._has_this = has_this - model_runner.node.config = ModelConfig(concurrent_batches=concurrent_batches) - model_runner.set_relation_exists(has_relation) + batch_runner.node._has_this = has_this + batch_runner.node.config = ModelConfig(concurrent_batches=concurrent_batches) + batch_runner.relation_exists = has_relation - mocked_supports = mocker.patch.object(model_runner.adapter, "supports") + mocked_supports = mocker.patch.object(batch_runner.adapter, "supports") mocked_supports.return_value = adapter_microbatch_concurrency # Assert result of should_run_in_parallel - assert model_runner.should_run_in_parallel() == expectation + assert batch_runner.should_run_in_parallel() == expectation class TestRunTask:
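A note on the `core/dbt/graph/thread_pool.py` addition: the stock `multiprocessing.pool.ThreadPool` gives workers no way to ask whether `close()` has already been called, which is why in-flight batches previously kept running while the main thread was trying to exit. The sketch below shows the polling pattern the new orchestrator relies on. It reuses `DbtThreadPool` exactly as added in the patch; `wait_for_work` is a hypothetical worker invented purely for illustration.

import time
from multiprocessing.pool import ThreadPool


class DbtThreadPool(ThreadPool):
    """A ThreadPool that tracks whether or not it's been closed (as added in this patch)"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.closed = False

    def close(self):
        self.closed = True
        super().close()

    def is_closed(self):
        return self.closed


pool = DbtThreadPool(2)


def wait_for_work(task_id: int) -> str:
    # Hypothetical long-running worker: it polls the pool so a shutdown
    # request is noticed instead of being ignored until the work finishes
    for _ in range(50):
        if pool.is_closed():
            return f"task {task_id}: aborted"  # bail out rather than delay pool.join()
        time.sleep(0.05)
    return f"task {task_id}: finished"


async_results = [pool.apply_async(wait_for_work, (i,)) for i in range(4)]
pool.close()  # flips the flag; workers observe it and exit early
pool.join()
print([r.get() for r in async_results])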
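The heart of the change is that `MicrobatchModelRunner.execute` itself runs on a pool worker and resubmits individual batches to the same pool, so batch fan-out never occupies the main thread. Stripped of hooks, skip handling, and dbt's result types, the orchestration has roughly the shape sketched below, assuming batches are plain ints and `run_batch` is a hypothetical stand-in for `RunTask._submit_batch`.

import time
from multiprocessing.pool import ThreadPool
from typing import List


def run_batch(batch_idx: int) -> str:
    time.sleep(0.05)  # pretend to materialize one event-time batch
    return f"batch {batch_idx}: SUCCESS"


def execute(batches: List[int], pool: ThreadPool) -> List[str]:
    results: List[str] = []
    # First batch runs sequentially so pre-hooks (and the initial creation
    # of the relation) complete before any parallelism starts
    results.append(run_batch(batches[0]))
    # Middle batches go back into the shared pool and may run concurrently
    for idx in batches[1:-1]:
        pool.apply_async(run_batch, (idx,), callback=results.append)
    # Wait for the middle batches, sleeping briefly instead of spinning
    expected = 1 + max(0, len(batches) - 2)
    while len(results) != expected:
        time.sleep(0.1)
    # Last batch runs sequentially so post-hooks fire after everything else
    if len(batches) > 1:
        results.append(run_batch(batches[-1]))
    return results


pool = ThreadPool(4)
print(execute(list(range(5)), pool))
pool.close()
pool.join()

In the real runner the wait loop additionally checks `pool.is_closed()` on every iteration, which is the graceful-shutdown path illustrated next.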
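When that wait loop sees `pool.is_closed()`, `_update_result_with_unfinished_batches` marks every batch that was expected but never reported back as failed, which comes down to plain set arithmetic. A tiny self-contained illustration of that bookkeeping, with batch tuples abbreviated to ints:

from typing import Dict, List, Set


def unfinished_batches(
    batches: Dict[int, int], successful: List[int], failed: List[int]
) -> Set[int]:
    """Batches that were expected but never produced a result"""
    finished = set(successful) | set(failed)
    expected = {batch for _, batch in batches.items()}
    return expected - finished


# Five batches were planned; two succeeded and one failed before the pool closed
planned = {idx: 100 + idx for idx in range(5)}
print(unfinished_batches(planned, successful=[100, 101], failed=[102]))  # {103, 104}

Appending those unfinished batches to the failed list is what lets a later `dbt retry` pick them up through `previous_batch_results`.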
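Finally, `should_run_in_parallel` moved from `MicrobatchModelRunner` to `MicrobatchBatchRunner` with its logic intact. Rewritten as a pure function over plain booleans (the parameter names are mine, mirroring the runner's attributes), the decision tree reads:

from typing import Optional


def should_run_in_parallel(
    adapter_supports_concurrency: bool,
    relation_exists: bool,
    concurrent_batches: Optional[bool],
    model_refs_this: bool,
) -> bool:
    if not adapter_supports_concurrency:
        return False  # adapter lacks the MicrobatchConcurrency capability
    if not relation_exists:
        return False  # the relation must exist before batches can run in parallel
    if concurrent_batches is not None:
        return concurrent_batches  # an explicit config value always wins
    # A model that references {{ this }} reads its own earlier batches, so its
    # batches are assumed to be order-dependent and must run sequentially
    return not model_refs_this


assert should_run_in_parallel(True, True, None, False) is True
assert should_run_in_parallel(True, True, None, True) is False
assert should_run_in_parallel(True, False, True, False) is False  # relation check precedes config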