Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add method to infer primary key of model #9650

Merged
merged 1 commit into from
Feb 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions .changes/unreleased/Under the Hood-20240223-115021.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
kind: Under the Hood
body: Implement primary key inference for model nodes
time: 2024-02-23T11:50:21.257494-08:00
custom:
Author: aliceliu
Issue: "9652"
54 changes: 54 additions & 0 deletions core/dbt/contracts/graph/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,60 @@ def search_name(self):
def materialization_enforces_constraints(self) -> bool:
return self.config.materialized in ["table", "incremental"]

def infer_primary_key(self, data_tests: List["GenericTestNode"]) -> List[str]:
"""
Infers the columns that can be used as primary key of a model in the following order:
1. Columns with primary key constraints
2. Columns with unique and not_null data tests
3. Columns with enabled unique or dbt_utils.unique_combination_of_columns data tests
4. Columns with disabled unique or dbt_utils.unique_combination_of_columns data tests
"""
for constraint in self.constraints:
if constraint.type == ConstraintType.primary_key:
return constraint.columns

for column, column_info in self.columns.items():
for column_constraint in column_info.constraints:
if column_constraint.type == ConstraintType.primary_key:
return [column]

columns_with_enabled_unique_tests = set()
columns_with_disabled_unique_tests = set()
columns_with_not_null_tests = set()
for test in data_tests:
columns = []
if "column_name" in test.test_metadata.kwargs:
columns = [test.test_metadata.kwargs["column_name"]]
elif "combination_of_columns" in test.test_metadata.kwargs:
columns = test.test_metadata.kwargs["combination_of_columns"]

for column in columns:
if test.test_metadata.name in ["unique", "unique_combination_of_columns"]:
if test.config.enabled:
columns_with_enabled_unique_tests.add(column)
else:
columns_with_disabled_unique_tests.add(column)
elif test.test_metadata.name == "not_null":
columns_with_not_null_tests.add(column)

columns_with_unique_and_not_null_tests = []
for column in columns_with_not_null_tests:
if (
column in columns_with_enabled_unique_tests
or column in columns_with_disabled_unique_tests
):
columns_with_unique_and_not_null_tests.append(column)
if columns_with_unique_and_not_null_tests:
return columns_with_unique_and_not_null_tests

if columns_with_enabled_unique_tests:
return list(columns_with_enabled_unique_tests)

if columns_with_disabled_unique_tests:
return list(columns_with_disabled_unique_tests)

return []

def same_contents(self, old, adapter_type) -> bool:
return super().same_contents(old, adapter_type) and self.same_ref_representation(old)

Expand Down
81 changes: 81 additions & 0 deletions tests/unit/fixtures.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
from dbt.contracts.files import FileHash
from dbt.contracts.graph.nodes import (
DependsOn,
InjectedCTE,
ModelNode,
ModelConfig,
GenericTestNode,
)
from dbt.node_types import NodeType

from dbt.artifacts.resources import Contract, TestConfig, TestMetadata


def model_node():
return ModelNode(
package_name="test",
path="/root/models/foo.sql",
original_file_path="models/foo.sql",
language="sql",
raw_code='select * from {{ ref("other") }}',
name="foo",
resource_type=NodeType.Model,
unique_id="model.test.foo",
fqn=["test", "models", "foo"],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=True,
description="",
database="test_db",
schema="test_schema",
alias="bar",
tags=[],
config=ModelConfig(),
contract=Contract(),
meta={},
compiled=True,
extra_ctes=[InjectedCTE("whatever", "select * from other")],
extra_ctes_injected=True,
compiled_code="with whatever as (select * from other) select * from whatever",
checksum=FileHash.from_contents(""),
unrendered_config={},
)


def generic_test_node():
return GenericTestNode(
package_name="test",
path="/root/x/path.sql",
original_file_path="/root/path.sql",
language="sql",
raw_code='select * from {{ ref("other") }}',
name="foo",
resource_type=NodeType.Test,
unique_id="model.test.foo",
fqn=["test", "models", "foo"],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=False,
description="",
database="test_db",
schema="dbt_test__audit",
alias="bar",
tags=[],
config=TestConfig(severity="warn"),
contract=Contract(),
meta={},
compiled=True,
extra_ctes=[InjectedCTE("whatever", "select * from other")],
extra_ctes_injected=True,
compiled_code="with whatever as (select * from other) select * from whatever",
column_name="id",
test_metadata=TestMetadata(namespace=None, name="foo", kwargs={}),
checksum=FileHash.from_contents(""),
unrendered_config={
"severity": "warn",
},
)
70 changes: 4 additions & 66 deletions tests/unit/test_contracts_graph_compiled.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
from dbt.contracts.graph.nodes import (
DependsOn,
GenericTestNode,
InjectedCTE,
ModelNode,
ModelConfig,
)
from dbt.artifacts.resources import Contract, TestConfig, TestMetadata
from dbt.artifacts.resources import TestConfig, TestMetadata
from tests.unit.fixtures import generic_test_node, model_node
from dbt.node_types import NodeType

from .utils import (
Expand Down Expand Up @@ -57,36 +57,7 @@ def basic_uncompiled_model():

@pytest.fixture
def basic_compiled_model():
return ModelNode(
package_name="test",
path="/root/models/foo.sql",
original_file_path="models/foo.sql",
language="sql",
raw_code='select * from {{ ref("other") }}',
name="foo",
resource_type=NodeType.Model,
unique_id="model.test.foo",
fqn=["test", "models", "foo"],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=True,
description="",
database="test_db",
schema="test_schema",
alias="bar",
tags=[],
config=ModelConfig(),
contract=Contract(),
meta={},
compiled=True,
extra_ctes=[InjectedCTE("whatever", "select * from other")],
extra_ctes_injected=True,
compiled_code="with whatever as (select * from other) select * from whatever",
checksum=FileHash.from_contents(""),
unrendered_config={},
)
return model_node()


@pytest.fixture
Expand Down Expand Up @@ -432,40 +403,7 @@ def basic_uncompiled_schema_test_node():

@pytest.fixture
def basic_compiled_schema_test_node():
return GenericTestNode(
package_name="test",
path="/root/x/path.sql",
original_file_path="/root/path.sql",
language="sql",
raw_code='select * from {{ ref("other") }}',
name="foo",
resource_type=NodeType.Test,
unique_id="model.test.foo",
fqn=["test", "models", "foo"],
refs=[],
sources=[],
metrics=[],
depends_on=DependsOn(),
deferred=False,
description="",
database="test_db",
schema="dbt_test__audit",
alias="bar",
tags=[],
config=TestConfig(severity="warn"),
contract=Contract(),
meta={},
compiled=True,
extra_ctes=[InjectedCTE("whatever", "select * from other")],
extra_ctes_injected=True,
compiled_code="with whatever as (select * from other) select * from whatever",
column_name="id",
test_metadata=TestMetadata(namespace=None, name="foo", kwargs={}),
checksum=FileHash.from_contents(""),
unrendered_config={
"severity": "warn",
},
)
return generic_test_node()


@pytest.fixture
Expand Down
Loading
Loading