Skip to content

Commit

Permalink
Enable unit testing versioned models (#9421)
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank authored Jan 29, 2024
1 parent 1cbc6d3 commit 5ae8f6a
Show file tree
Hide file tree
Showing 12 changed files with 662 additions and 41 deletions.
6 changes: 6 additions & 0 deletions .changes/unreleased/Features-20240122-145854.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Changie changelog entry: announces support for unit-testing versioned
# models (issue #9344).
kind: Features
body: Enable unit testing versioned models
time: 2024-01-22T14:58:54.251484-05:00
custom:
  Author: gshank
  Issue: "9344"
1 change: 1 addition & 0 deletions core/dbt/contracts/graph/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -980,6 +980,7 @@ def build_parent_and_child_maps(self):
self.metrics.values(),
self.semantic_models.values(),
self.saved_queries.values(),
self.unit_tests.values(),
)
)
forward_edges, backward_edges = build_node_edges(edge_members)
Expand Down
6 changes: 5 additions & 1 deletion core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
UnitTestOverrides,
UnitTestInputFixture,
UnitTestOutputFixture,
UnitTestNodeVersions,
)
from dbt.contracts.graph.node_args import ModelNodeArgs
from dbt.contracts.graph.semantic_layer_common import WhereFilterIntersection
Expand Down Expand Up @@ -1067,6 +1068,9 @@ class UnitTestDefinition(NodeInfoMixin, GraphNode, UnitTestDefinitionMandatory):
config: UnitTestConfig = field(default_factory=UnitTestConfig)
checksum: Optional[str] = None
schema: Optional[str] = None
created_at: float = field(default_factory=lambda: time.time())
versions: Optional[UnitTestNodeVersions] = None
version: Optional[NodeVersion] = None

@property
def build_path(self):
Expand All @@ -1089,7 +1093,7 @@ def tags(self) -> List[str]:

def build_unit_test_checksum(self):
# everything except 'description'
data = f"{self.model}-{self.given}-{self.expect}-{self.overrides}"
data = f"{self.model}-{self.versions}-{self.given}-{self.expect}-{self.overrides}"

# include underlying fixture data
for input in self.given:
Expand Down
14 changes: 14 additions & 0 deletions core/dbt/contracts/graph/unparsed.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,12 @@ class UnitTestOverrides(dbtClassMixin):
env_vars: Dict[str, Any] = field(default_factory=dict)


@dataclass
class UnitTestNodeVersions(dbtClassMixin):
    # Version filter for a unit test on a versioned model.
    # `include` and `exclude` are mutually exclusive — enforced by
    # UnparsedUnitTest.validate. When both are None, presumably the test
    # applies to every version of the model — confirm in the parser.
    include: Optional[List[NodeVersion]] = None
    exclude: Optional[List[NodeVersion]] = None

@dataclass
class UnparsedUnitTest(dbtClassMixin):
name: str
Expand All @@ -812,3 +818,11 @@ class UnparsedUnitTest(dbtClassMixin):
description: str = ""
overrides: Optional[UnitTestOverrides] = None
config: Dict[str, Any] = field(default_factory=dict)
versions: Optional[UnitTestNodeVersions] = None

@classmethod
def validate(cls, data):
    """Validate the raw unit-test dict.

    Runs the base-class schema validation first, then rejects a
    ``versions`` filter that sets both ``include`` and ``exclude``.
    """
    super(UnparsedUnitTest, cls).validate(data)
    versions = data.get("versions", None)
    # Only one of include/exclude may be supplied for a single unit test.
    if versions and versions.get("include") and versions.get("exclude"):
        raise ValidationError("Unit tests can not both include and exclude versions.")
40 changes: 39 additions & 1 deletion core/dbt/parser/manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,7 @@
from dbt.parser.seeds import SeedParser
from dbt.parser.snapshots import SnapshotParser
from dbt.parser.sources import SourcePatcher
from dbt.parser.unit_tests import process_models_for_unit_test
from dbt.version import __version__

from dbt_common.dataclass_schema import StrEnum, dbtClassMixin
Expand Down Expand Up @@ -534,6 +535,7 @@ def load(self) -> Manifest:
start_process = time.perf_counter()
self.process_sources(self.root_project.project_name)
self.process_refs(self.root_project.project_name, self.root_project.dependencies)
self.process_unit_tests(self.root_project.project_name)
self.process_docs(self.root_project)
self.process_metrics(self.root_project)
self.process_saved_queries(self.root_project)
Expand Down Expand Up @@ -1227,6 +1229,27 @@ def process_sources(self, current_project: str):
continue
_process_sources_for_exposure(self.manifest, current_project, exposure)

# Resolve each unit test's target model (including versioned variants)
# and record the resolved unique ids on the unit test's depends_on.
def process_unit_tests(self, current_project: str):
    versions_lookup = None
    # Snapshot the keys up front: processing may remove unit tests from
    # the manifest while we iterate.
    for unique_id in list(self.manifest.unit_tests.keys()):
        unit_test = self.manifest.unit_tests.get(unique_id)
        if unit_test is None:
            # Removed by earlier processing in this same loop.
            continue
        if unit_test.created_at < self.started_at:
            # Untouched by this parse run (partial parsing) — already resolved.
            continue
        # Build the model-name -> versioned-unique-ids map lazily, only
        # if at least one unit test actually needs processing.
        if not versions_lookup:
            versions_lookup = _build_model_names_to_versions(self.manifest)
        process_models_for_unit_test(
            self.manifest, current_project, unit_test, versions_lookup
        )

def cleanup_disabled(self):
# make sure the nodes are in the manifest.nodes or the disabled dict,
# correctly now that the schema files are also parsed
Expand Down Expand Up @@ -1343,6 +1366,21 @@ def invalid_target_fail_unless_test(
)


def _build_model_names_to_versions(manifest: Manifest) -> Dict[str, Dict]:
    """Index the manifest's versioned models for unit-test resolution.

    Returns a mapping of package name -> model name -> list of unique ids,
    covering only versioned models; unversioned models are omitted.
    """
    versions_by_package: Dict[str, Dict] = {}
    for node in manifest.nodes.values():
        if node.resource_type == NodeType.Model and node.is_versioned:
            package_models = versions_by_package.setdefault(node.package_name, {})
            package_models.setdefault(node.name, []).append(node.unique_id)
    return versions_by_package


def _check_resource_uniqueness(
manifest: Manifest,
config: RuntimeConfig,
Expand Down Expand Up @@ -1754,7 +1792,7 @@ def _process_sources_for_node(manifest: Manifest, current_project: str, node: Ma
)

if target_source is None or isinstance(target_source, Disabled):
# this folows the same pattern as refs
# this follows the same pattern as refs
node.config.enabled = False
invalid_target_fail_unless_test(
node=node,
Expand Down
10 changes: 8 additions & 2 deletions core/dbt/parser/partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -456,6 +456,11 @@ def schedule_nodes_for_parsing(self, unique_ids):
self.delete_macro_file(source_file)
self.saved_files[file_id] = deepcopy(self.new_files[file_id])
self.add_to_pp_files(self.saved_files[file_id])
elif unique_id in self.saved_manifest.unit_tests:
unit_test = self.saved_manifest.unit_tests[unique_id]
self._schedule_for_parsing(
"unit_tests", unit_test, unit_test.name, self.delete_schema_unit_test
)

def _schedule_for_parsing(self, dict_key: str, element, name, delete: Callable) -> None:
file_id = element.file_id
Expand Down Expand Up @@ -839,8 +844,9 @@ def delete_schema_mssa_links(self, schema_file, dict_key, elem):
# if the node's group has changed - need to reparse all referencing nodes to ensure valid ref access
if node.group != elem.get("group"):
self.schedule_referencing_nodes_for_parsing(node.unique_id)
# if the node's latest version has changed - need to reparse all referencing nodes to ensure correct ref resolution
if node.is_versioned and node.latest_version != elem.get("latest_version"):
# If the latest version has changed or a version has been removed we need to
# reparse referencing nodes.
if node.is_versioned:
self.schedule_referencing_nodes_for_parsing(node.unique_id)
# remove from patches
schema_file.node_patches.remove(elem_unique_id)
Expand Down
Loading

0 comments on commit 5ae8f6a

Please sign in to comment.