feat(python): Improve write_database, accounting for latest adbc fixes/updates #12713

Merged
94 changes: 68 additions & 26 deletions py-polars/polars/dataframe/frame.py
@@ -3434,71 +3434,113 @@ def write_database(
Parameters
----------
table_name
Name of the table to create or append to in the target SQL database.
If your table name contains special characters, it should be quoted.
Schema-qualified name of the table to create or append to in the target
SQL database. If your table name contains special characters, it should
be quoted.
connection
Connection URI string, for example:

* "postgresql://user:pass@server:port/database"
* "sqlite:////path/to/database.db"
if_exists : {'append', 'replace', 'fail'}
The insert mode.
'replace' will create a new database table, overwriting an existing one.
'append' will append to an existing table.
'fail' will fail if table already exists.
The insert mode:

* 'replace' will create a new database table, overwriting an existing one.
* 'append' will append to an existing table.
* 'fail' will fail if table already exists.
engine : {'sqlalchemy', 'adbc'}
Select the engine used for writing the data.
"""
from polars.io.database import _open_adbc_connection

def unpack_table_name(name: str) -> tuple[str | None, str]:
"""Unpack optionally qualified table name into schema/table pair."""
from csv import reader as delimited_read

table_ident = next(delimited_read([name], delimiter="."))
if len(table_ident) > 2:
raise ValueError(f"`table_name` appears to be invalid: {name!r}")
elif len(table_ident) > 1:
schema = table_ident[0]
tbl = table_ident[1]
else:
schema = None
tbl = table_ident[0]
return schema, tbl
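# illustrative only (not part of the diff): the csv-style parsing means a
# quoted identifier can safely contain the '.' delimiter, e.g.
#   unpack_table_name('main."test-append"')  # -> ("main", "test-append")
#   unpack_table_name("test_append")         # -> (None, "test_append")
#   unpack_table_name("w.x.y.z")             # raises ValueError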

if engine == "adbc":
if if_exists == "fail":
raise NotImplementedError(
"`if_exists = 'fail'` not supported for ADBC engine"
try:
import adbc_driver_manager

adbc_version = parse_version(
getattr(adbc_driver_manager, "__version__", "0.0")
)
elif if_exists == "replace":
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"adbc_driver_manager not found"
"\n\nInstall Polars with: pip install adbc_driver_manager"
) from exc

if if_exists == "fail":
# if the table exists, 'create' will raise an error,
# resulting in behaviour equivalent to 'fail'
mode = "create"
elif if_exists == "replace":
if adbc_version < (0, 7):
adbc_str_version = ".".join(str(v) for v in adbc_version)
raise ModuleNotFoundError(
f"`if_exists = 'replace'` requires ADBC version >= 0.7, found {adbc_str_version}"
)
mode = "replace"
elif if_exists == "append":
mode = "append"
else:
raise ValueError(
f"unexpected value for `if_exists`: {if_exists!r}"
f"\n\nChoose one of {{'fail', 'replace', 'append'}}"
)

with _open_adbc_connection(connection) as conn, conn.cursor() as cursor:
cursor.adbc_ingest(table_name, self.to_arrow(), mode)
if adbc_version >= (0, 7):
db_schema, unpacked_table_name = unpack_table_name(table_name)
if "sqlite" in conn.adbc_get_info()["driver_name"].lower():
if if_exists == "replace":
# note: adbc doesn't (yet) support 'replace' for sqlite
cursor.execute(f"DROP TABLE IF EXISTS {table_name}")
mode = "create"
catalog, db_schema = db_schema, None
else:
catalog = None

cursor.adbc_ingest(
unpacked_table_name,
data=self.to_arrow(),
mode=mode,
catalog_name=catalog,
db_schema_name=db_schema,
)
else:
cursor.adbc_ingest(table_name, self.to_arrow(), mode)
conn.commit()

elif engine == "sqlalchemy":
if parse_version(pd.__version__) < parse_version("1.5"):
raise ModuleNotFoundError(
f"writing with engine 'sqlalchemy' requires pandas 1.5.x or higher, found pandas {pd.__version__!r}"
f"writing with engine 'sqlalchemy' requires pandas 1.5.x or higher, found {pd.__version__!r}"
)

try:
from sqlalchemy import create_engine
except ModuleNotFoundError as exc:
raise ModuleNotFoundError(
"sqlalchemy not found"
"\n\nInstall Polars with: pip install polars[sqlalchemy]"
) from exc
from csv import reader as delimited_read

# the table name may also include the db schema; ensure that we identify
# both components and pass them through unquoted (sqlalchemy will quote)
table_ident = next(delimited_read([table_name], delimiter="."))
if len(table_ident) > 2:
raise ValueError(f"`table_name` appears to be invalid: {table_name!r}")
elif len(table_ident) > 1:
db_schema = table_ident[0]
table_name = table_ident[1]
else:
table_name = table_ident[0]
db_schema = None

# ensure conversion to pandas uses the pyarrow extension array option
# so that we can make use of the sql/db export without copying data
engine_sa = create_engine(connection)
db_schema, table_name = unpack_table_name(table_name)

self.to_pandas(use_pyarrow_extension_array=True).to_sql(
name=table_name,
schema=db_schema,
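For context, a minimal usage sketch of the updated method (the DataFrame, table name, and local SQLite path are illustrative, not part of the diff): a schema-qualified, quoted table name now works with both engines, and the ADBC path maps if_exists onto adbc_ingest modes.

import polars as pl

df = pl.DataFrame({"id": [1, 2, 3], "value": ["a", "b", "c"]})

# 'replace' creates or overwrites the table, 'append' adds rows,
# and 'fail' raises if the table already exists
df.write_database(
    table_name='main."my-table"',       # schema-qualified, quoted name
    connection="sqlite:///example.db",  # illustrative local database
    if_exists="replace",
    engine="adbc",                      # or "sqlalchemy"
)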
4 changes: 2 additions & 2 deletions py-polars/polars/io/database.py
@@ -190,10 +190,10 @@ def _normalise_cursor(self, conn: ConnectionOrCursor) -> Cursor:
if conn.driver == "databricks-sql-python": # type: ignore[union-attr]
# take advantage of the raw connection to get arrow integration
self.driver_name = "databricks"
return conn.raw_connection().cursor() # type: ignore[union-attr]
return conn.raw_connection().cursor() # type: ignore[union-attr, return-value]
else:
# sqlalchemy engine; direct use is deprecated, so prefer the connection
return conn.connect() # type: ignore[union-attr]
return conn.connect() # type: ignore[union-attr, return-value]

elif hasattr(conn, "cursor"):
# connection has a dedicated cursor; prefer over direct execute
4 changes: 3 additions & 1 deletion py-polars/polars/type_aliases.py
@@ -21,6 +21,8 @@
if TYPE_CHECKING:
import sys

from sqlalchemy import Engine

from polars import DataFrame, Expr, LazyFrame, Series
from polars.datatypes import DataType, DataTypeClass, IntegerType, TemporalType
from polars.dependencies import numpy as np
@@ -233,4 +235,4 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any:
"""Fetch results in batches."""


ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor]
ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, "Engine"]
2 changes: 1 addition & 1 deletion py-polars/pyproject.toml
@@ -79,7 +79,7 @@ disable_error_code = [
[[tool.mypy.overrides]]
module = [
"IPython.*",
"adbc_driver_postgresql.*",
"adbc_driver_manager.*",
"adbc_driver_sqlite.*",
"arrow_odbc",
"backports",
1 change: 1 addition & 0 deletions py-polars/requirements-dev.txt
@@ -25,6 +25,7 @@ backports.zoneinfo; python_version < '3.9'
tzdata; platform_system == 'Windows'
# Database
SQLAlchemy
adbc_driver_manager; python_version >= '3.9' and platform_system != 'Windows'
adbc_driver_sqlite; python_version >= '3.9' and platform_system != 'Windows'
# TODO: Remove version constraint for connectorx when Python 3.12 is supported:
# https://github.com/sfu-db/connector-x/issues/527
106 changes: 47 additions & 59 deletions py-polars/tests/unit/io/test_database_write.py
@@ -1,10 +1,10 @@
from __future__ import annotations

import sys
from contextlib import suppress
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING

import pytest
from sqlalchemy import create_engine

import polars as pl
from polars.testing import assert_frame_equal
@@ -15,18 +15,6 @@
from polars.type_aliases import DbWriteEngine


def adbc_sqlite_driver_version(*args: Any, **kwargs: Any) -> str:
with suppress(ModuleNotFoundError): # not available on 3.8/windows
import adbc_driver_sqlite

return getattr(adbc_driver_sqlite, "__version__", "n/a")
return "n/a"


@pytest.mark.skipif(
sys.version_info > (3, 11),
reason="connectorx cannot be installed on Python 3.12 yet.",
)
@pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc_driver_sqlite not available below Python 3.9 / on Windows",
@@ -43,29 +31,28 @@ def test_write_database_create(engine: DbWriteEngine, tmp_path: Path) -> None:
)
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"
table_name = "test_create"

df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
if_exists="replace",
connection=test_db_uri,
engine=engine,
)
result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}")
result = pl.read_database(
query=f"SELECT * FROM {table_name}",
connection=create_engine(test_db_uri),
)
assert_frame_equal(result, df)


@pytest.mark.skipif(
sys.version_info > (3, 11),
reason="connectorx cannot be installed on Python 3.12 yet.",
)
@pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc_driver_sqlite not available below Python 3.9 / on Windows",
)
@pytest.mark.write_disk()
@pytest.mark.parametrize("engine", ["adbc", "sqlalchemy"])
def test_write_database_append(engine: DbWriteEngine, tmp_path: Path) -> None:
def test_write_database_append_replace(engine: DbWriteEngine, tmp_path: Path) -> None:
df = pl.DataFrame(
{
"key": ["xx", "yy", "zz"],
@@ -76,31 +63,45 @@ def test_write_database_append(engine: DbWriteEngine, tmp_path: Path) -> None:

tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
table_name = "test_append"
test_db_uri = f"sqlite:///{test_db}"
table_name = f"test_append_{engine}"

df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
if_exists="replace",
connection=test_db_uri,
engine=engine,
)

ExpectedError = NotImplementedError if engine == "adbc" else ValueError
with pytest.raises(ExpectedError):
with pytest.raises(Exception): # noqa: B017
df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
connection=test_db_uri,
if_exists="fail",
engine=engine,
)

df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
connection=test_db_uri,
if_exists="replace",
engine=engine,
)
result = pl.read_database(
query=f"SELECT * FROM {table_name}",
connection=create_engine(test_db_uri),
)
assert_frame_equal(result, df)

df.write_database(
table_name=table_name,
connection=test_db_uri,
if_exists="append",
engine=engine,
)
result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}")
result = pl.read_database(
query=f"SELECT * FROM {table_name}",
connection=create_engine(test_db_uri),
)
assert_frame_equal(result, pl.concat([df, df]))


@@ -111,21 +112,7 @@ def test_write_database_append(engine: DbWriteEngine, tmp_path: Path) -> None:
@pytest.mark.write_disk()
@pytest.mark.parametrize(
"engine",
[
pytest.param(
"adbc",
marks=pytest.mark.xfail( # see: https://github.com/apache/arrow-adbc/issues/1000
reason="ADBC SQLite driver has a bug with quoted/qualified table names",
),
),
pytest.param(
"sqlalchemy",
marks=pytest.mark.skipif(
sys.version_info > (3, 11),
reason="connectorx cannot be installed on Python 3.12 yet.",
),
),
],
["adbc", "sqlalchemy"],
)
def test_write_database_create_quoted_tablename(
engine: DbWriteEngine, tmp_path: Path
@@ -134,17 +121,26 @@ def test_write_database_create_quoted_tablename(

tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"

# table name requires quoting, and is qualified with the implicit 'main' schema
table_name = 'main."test-append"'
qualified_table_name = f'main."test-append-{engine}"'

df.write_database(
table_name=table_name,
connection=f"sqlite:///{test_db}",
table_name=qualified_table_name,
connection=test_db_uri,
engine=engine,
)
df.write_database(
table_name=qualified_table_name,
connection=test_db_uri,
if_exists="replace",
engine=engine,
)
result = pl.read_database_uri(f"SELECT * FROM {table_name}", f"sqlite:///{test_db}")
result = pl.read_database(
query=f"SELECT * FROM {qualified_table_name}",
connection=create_engine(test_db_uri),
)
assert_frame_equal(result, df)


@@ -154,19 +150,11 @@ def test_write_database_errors() -> None:

with pytest.raises(
ValueError, match="`table_name` appears to be invalid: 'w.x.y.z'"
):
df.write_database(
connection="sqlite:///:memory:", table_name="w.x.y.z", engine="sqlalchemy"
)

with pytest.raises(
NotImplementedError, match="`if_exists = 'fail'` not supported for ADBC engine"
):
df.write_database(
connection="sqlite:///:memory:",
table_name="test_errs",
if_exists="fail",
engine="adbc",
table_name="w.x.y.z",
engine="sqlalchemy",
)

with pytest.raises(ValueError, match="'do_something' is not valid for if_exists"):
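As a final illustration of the error path exercised above (sketch only; the in-memory URI and frame mirror the test), an unrecognised if_exists value is rejected with a ValueError when using the SQLAlchemy engine:

import polars as pl

df = pl.DataFrame({"colx": [1, 2, 3]})
df.write_database(
    table_name="test_errs",
    connection="sqlite:///:memory:",
    if_exists="do_something",  # not a valid mode -> ValueError
    engine="sqlalchemy",
)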