Merge pull request #5 from mrjsj/fix/dispose-notebookutils
fix: dispose notebookutils
mrjsj authored Dec 1, 2024
2 parents 4a16788 + 2747181 commit 4f26a16
Showing 4 changed files with 65 additions and 47 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "msfabricutils"
version = "0.1.2"
version = "0.1.3"
description = "A Python library that exposes additional functionality for working with Python notebooks in Microsoft Fabric."
authors = [
{ name = "Jimmy Jensen" },
91 changes: 56 additions & 35 deletions src/msfabricutils/fabric_duckdb_connection.py
@@ -9,12 +9,12 @@
from sqlglot import exp

# Avoid import errors outside Fabric environments
try:
from sempy import fabric
import notebookutils
try:
from sempy import fabric #noqa: F401
except ModuleNotFoundError:
pass


class FabricDuckDBConnection:
"""A DuckDB connection wrapper for Microsoft Fabric Lakehouses.
@@ -28,7 +28,7 @@ class FabricDuckDBConnection:
- Delta Lake table support
Args:
access_token (str): The Microsoft Fabric access token for authentication.
access_token (str): The Microsoft Fabric access token for authentication.
In a notebook, use `notebookutils.credentials.getToken('storage')`.
config (dict, optional): DuckDB configuration options. Defaults to {}.
@@ -37,7 +37,7 @@
>>> # Initialize connection
>>> access_token = notebookutils.credentials.getToken('storage')
>>> conn = FabricDuckDBConnection(access_token=access_token)
>>>
>>>
>>> # Register lakehouses from different workspaces
>>> conn.register_workspace_lakehouses(
... workspace_id='12345678-1234-5678-1234-567812345678',
@@ -46,23 +46,24 @@
>>> conn.register_workspace_lakehouses(
... workspace_id='87654321-8765-4321-8765-432187654321',
... lakehouses=['marketing']
... )
>>>
... )
>>>
>>> # Query across workspaces using fully qualified names
>>> df = conn.sql('''
... SELECT
... SELECT
... c.customer_id,
... c.name,
... c.region,
... s.segment,
... s.lifetime_value
... FROM sales_workspace.sales.main.customers c
... JOIN marketing_workspace.marketing.main.customer_segments s
... JOIN marketing_workspace.marketing.main.customer_segments s
... ON c.customer_id = s.customer_id
... WHERE c.region = 'EMEA'
... ''').df()
```
"""
"""

def __init__(self, access_token: str, config: dict = {}):
self._registered_tables = []
self._connection = duckdb.connect(config=config)
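
Because this commit removes the hard dependency on `notebookutils`, the connection can also be constructed outside a Fabric notebook. A minimal sketch, assuming the `azure-identity` package and that a storage-audience token is acceptable (the scope string and import path are assumptions, not something this commit prescribes):

```python
# Sketch: obtain a storage-scoped token without notebookutils, e.g. when running locally.
from azure.identity import DefaultAzureCredential

# Module path taken from this diff; the package may also re-export the class elsewhere.
from msfabricutils.fabric_duckdb_connection import FabricDuckDBConnection

credential = DefaultAzureCredential()
# OneLake accepts tokens issued for the Azure Storage audience.
access_token = credential.get_token("https://storage.azure.com/.default").token

conn = FabricDuckDBConnection(access_token=access_token)
```
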
@@ -95,11 +96,11 @@ def refresh_access_token(self, access_token: str):
```python
>>> # Initialize connection
>>> conn = FabricDuckDBConnection(access_token='old_token')
>>>
>>>
>>> # When token expires, refresh it
>>> new_token = notebookutils.credentials.getToken('storage')
>>> conn.refresh_access_token(new_token)
```
>>> conn.refresh_access_token(new_token)
```
"""
self._access_token = access_token

@@ -123,7 +124,7 @@ def _preprocess_sql_query(self, query: str):
Raises:
Exception: If table references are ambiguous or tables don't exist
"""
"""
parsed = sqlglot.parse_one(sql=query, read="duckdb")
tables_in_query = parsed.find_all(exp.Table)
replace_by = {}
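
The `_preprocess_sql_query` hunk above shows the sqlglot machinery used to resolve table references. As a standalone illustration of that technique, not the library's exact rewriting logic, the sketch below qualifies bare table names with a catalog and schema (the function name and defaults are invented for the example):

```python
import sqlglot
from sqlglot import exp


def qualify_tables(query: str, catalog: str, schema: str = "main") -> str:
    """Illustrative only: prefix unqualified table names with `catalog.schema`."""
    parsed = sqlglot.parse_one(sql=query, read="duckdb")
    for table in parsed.find_all(exp.Table):
        if not table.catalog and not table.db:  # leave already-qualified names alone
            table.set("db", exp.to_identifier(schema))
            table.set("catalog", exp.to_identifier(catalog))
    return parsed.sql(dialect="duckdb")


print(qualify_tables("SELECT customer_id FROM customers WHERE region = 'EMEA'", "sales"))
# roughly: SELECT customer_id FROM sales.main.customers WHERE region = 'EMEA'
```
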
@@ -276,7 +277,7 @@ def _modify_input_query(self, args, kwargs):
Returns:
tuple: Modified (args, kwargs)
"""
"""
modified_args = list(args)

if "query" in kwargs:
@@ -295,7 +296,7 @@ def _preprocess_input(self, query: str) -> str:
Returns:
str: The preprocessed SQL statements
"""
"""
query_separator_indices = _separator_indices(query, ";")
query_separator_indices = [0] + [idx + 1 for idx in query_separator_indices]

@@ -328,7 +329,14 @@ def _create_or_replace_fabric_lakehouse_secret(self, catalog_name: str) -> None:
def _register_lakehouse_tables(
self, workspace_name: str, workspace_id: str, lakehouse_id: str, lakehouse_name: str
) -> None:
tables = notebookutils.lakehouse.listTables(lakehouse_name, workspace_id) # noqa: F821
from sempy import fabric

client = fabric.FabricRestClient()

response = client.get(f"v1/workspaces/{workspace_id}/lakehouses/{lakehouse_id}/tables")
response.raise_for_status()

tables = response.json()["data"]

if not tables:
table_information = {
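
The core of the fix is visible in the hunk above: `_register_lakehouse_tables` now lists tables through `sempy`'s `FabricRestClient` instead of `notebookutils.lakehouse.listTables`. A rough equivalent using plain `requests` against the public Fabric endpoint (pagination via `continuationToken` is omitted, and the base URL is an assumption about where the relative path resolves):

```python
import requests


def list_lakehouse_tables(access_token: str, workspace_id: str, lakehouse_id: str) -> list[dict]:
    """Sketch: list the tables of a lakehouse via the Fabric REST API (no pagination handling)."""
    url = (
        "https://api.fabric.microsoft.com/v1/"
        f"workspaces/{workspace_id}/lakehouses/{lakehouse_id}/tables"
    )
    response = requests.get(url, headers={"Authorization": f"Bearer {access_token}"})
    response.raise_for_status()
    # This endpoint returns its items under "data", as the code above also expects.
    return response.json()["data"]
```
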
@@ -388,46 +396,59 @@ def register_workspace_lakehouses(self, workspace_id: str, lakehouses: str | lis
>>> # Initialize connection with access token
>>> access_token = notebookutils.credentials.getToken('storage')
>>> conn = FabricDuckDBConnection(access_token=access_token)
>>>
>>>
>>> # Register a single lakehouse
>>> conn.register_workspace_lakehouses(
... workspace_id='12345678-1234-5678-1234-567812345678',
... lakehouses='sales_lakehouse'
... )
>>>
>>>
>>> # Register multiple lakehouses
>>> conn.register_workspace_lakehouses(
... workspace_id='12345678-1234-5678-1234-567812345678',
... lakehouses=['sales_lakehouse', 'marketing_lakehouse']
... )
```
"""
"""

if isinstance(lakehouses, str):
lakehouses = [lakehouses]

from sempy import fabric

workspaces = fabric.list_workspaces()

workspace_name = workspaces[workspaces.Id == workspace_id]["Name"].iat[0]

if isinstance(lakehouses, str):
lakehouses = [lakehouses]
client = fabric.FabricRestClient()

for lakehouse in lakehouses:
lakehouse_properties = notebookutils.lakehouse.getWithProperties( # noqa: F821
lakehouse, workspace_id
)
response = client.get(f"v1/workspaces/{workspace_id}/lakehouses")
response.raise_for_status()
lakehouse_properties = response.json()["value"]

is_schema_enabled = (
lakehouse_properties.get("properties").get("defaultSchema") is not None
)
lakehouse_id = lakehouse_properties.get("id")
selected_lakehouses = [
lakehouse
for lakehouse in lakehouse_properties
if lakehouse["displayName"] in lakehouses
]

for lakehouse in selected_lakehouses:
is_schema_enabled = lakehouse.get("properties").get("defaultSchema") is not None

lakehouse_name = lakehouse.get("displayName")
lakehouse_id = lakehouse.get("id")

if is_schema_enabled:
raise Exception(f"""
The lakehouse `{lakehouse}` is using the schema-enabled preview feature.\n
The lakehouse `{lakehouse_name}` is using the schema-enabled preview feature.\n
This utility class does not support schema-enabled lakehouses (yet).
""")

self._attach_lakehouse(workspace_name, lakehouse)
self._create_or_replace_fabric_lakehouse_secret(f"{workspace_name}{lakehouse}")
self._register_lakehouse_tables(workspace_name, workspace_id, lakehouse_id, lakehouse)
self._attach_lakehouse(workspace_name, lakehouse_name)
self._create_or_replace_fabric_lakehouse_secret(f"{workspace_name}{lakehouse_name}")
self._register_lakehouse_tables(
workspace_name, workspace_id, lakehouse_id, lakehouse_name
)

def print_lakehouse_catalog(self):
"""Print a hierarchical view of all registered lakehouses, schemas, and tables.
@@ -480,7 +501,7 @@ def write(
workspace_id: str = None,
workspace_name: str = None,
*args: Any,
**kwargs: Any
**kwargs: Any,
):
"""Write a DataFrame to a Fabric Lakehouse table.
6 changes: 2 additions & 4 deletions src/msfabricutils/helpers/separator_indices.py
@@ -1,5 +1,3 @@


def _separator_indices(string: str, separator: str):
"""Find indices of a separator character in a string, ignoring separators inside quotes.
@@ -23,9 +21,9 @@ def _separator_indices(string: str, separator: str):
inside_double_quotes = not inside_double_quotes
elif char == "'" and not inside_double_quotes:
inside_single_quotes = not inside_single_quotes
elif (inside_double_quotes or inside_single_quotes):
elif inside_double_quotes or inside_single_quotes:
continue
elif char == separator:
indices.append(idx)

return indices
return indices
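
The helper above is what lets `_preprocess_input` split a multi-statement string on `;` without breaking on semicolons inside quoted literals. A small usage sketch (the import path is assumed from the file location shown here and may differ from the package's public surface):

```python
# Assumed import path, based on src/msfabricutils/helpers/separator_indices.py
from msfabricutils.helpers.separator_indices import _separator_indices

sql = "SELECT ';' AS sep; SELECT 1; SELECT 2"
cuts = _separator_indices(sql, ";")  # [17, 27] -- the quoted ';' at index 8 is ignored
pieces = [sql[i:j].strip(" ;") for i, j in zip([0] + cuts, cuts + [len(sql)])]
print(pieces)  # ["SELECT ';' AS sep", 'SELECT 1', 'SELECT 2']
```
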
13 changes: 6 additions & 7 deletions tests/helpers/test_separator_indices.py
@@ -2,21 +2,20 @@


def test_seperator_indices_simple_string():

chars = "text.to.be.tested"
assert _separator_indices(chars, ".") == [4,7,10]
assert _separator_indices(chars, ".") == [4, 7, 10]

def test_seperator_indices_string_with_quotes():

def test_seperator_indices_string_with_quotes():
chars = "'te.xt'.to.be.tested"
assert _separator_indices(chars, ".") == [7,10,13]
assert _separator_indices(chars, ".") == [7, 10, 13]

def test_seperator_indices_string_with_both_quotes():

def test_seperator_indices_string_with_both_quotes():
chars = """'te"."xt'.to.be.tested"""
assert _separator_indices(chars, ".") == [9, 12, 15]

def test_seperator_indices_only_dots():

def test_seperator_indices_only_dots():
chars = "..."
assert _separator_indices(chars, ".") == [0, 1, 2]
assert _separator_indices(chars, ".") == [0, 1, 2]
