Skip to content

Commit

Permalink
feat(python): Support passing instantiated adbc/alchemy connection objects to `write_database` (#16099)
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie authored May 9, 2024
1 parent 8d35193 commit 9aecfd3
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 41 deletions.
97 changes: 80 additions & 17 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
ColumnWidthsDefinition,
ComparisonOperator,
ConditionalFormatDict,
ConnectionOrCursor,
CsvQuoteStyle,
DbWriteEngine,
FillNullStrategy,
Expand Down Expand Up @@ -3161,13 +3162,14 @@ def write_parquet(
def write_database(
self,
table_name: str,
connection: str,
connection: ConnectionOrCursor | str,
*,
if_table_exists: DbWriteMode = "fail",
engine: DbWriteEngine = "sqlalchemy",
engine: DbWriteEngine | None = None,
engine_options: dict[str, Any] | None = None,
) -> int:
"""
Write a polars frame to a database.
Write the data in a Polars DataFrame to a database.
Parameters
----------
Expand All @@ -3176,7 +3178,8 @@ def write_database(
SQL database. If your table name contains special characters, it should
be quoted.
connection
Connection URI string, for example:
An existing SQLAlchemy or ADBC connection against the target database, or
a URI string that will be used to instantiate such a connection, such as:
* "postgresql://user:pass@server:port/database"
* "sqlite:////path/to/database.db"
Expand All @@ -3187,7 +3190,38 @@ def write_database(
* 'append' will append to an existing table.
* 'fail' will fail if table already exists.
engine : {'sqlalchemy', 'adbc'}
Select the engine to use for writing frame data.
Select the engine to use for writing frame data; only necessary when
supplying a URI string (defaults to 'sqlalchemy' if unset)
engine_options
Additional options to pass to the engine's associated insert method:
* "sqlalchemy" - currently inserts using Pandas' `to_sql` method, though
this will eventually be phased out in favour of a native solution.
* "adbc" - inserts using the ADBC cursor's `adbc_ingest` method.
Examples
--------
Insert into a temporary table using a PostgreSQL URI and the ADBC engine:
>>> df.write_database(
... table_name="target_table",
... connection="postgresql://user:pass@server:port/database",
... engine="adbc",
... engine_options={"temporary": True},
... ) # doctest: +SKIP
Insert into a table using a `pyodbc` SQLAlchemy connection to SQL Server
that was instantiated with "fast_executemany=True" to improve performance:
>>> pyodbc_uri = (
... "mssql+pyodbc://user:pass@server:1433/test?"
... "driver=ODBC+Driver+18+for+SQL+Server"
... )
>>> engine = create_engine(pyodbc_uri, fast_executemany=True) # doctest: +SKIP
>>> df.write_database(
... table_name="target_table",
... connection=engine,
... ) # doctest: +SKIP
Returns
-------
Expand All @@ -3200,6 +3234,16 @@ def write_database(
msg = f"write_database `if_table_exists` must be one of {{{allowed}}}, got {if_table_exists!r}"
raise ValueError(msg)

if engine is None:
if (
isinstance(connection, str)
or (module_root := type(connection).__module__.split(".", 1)[0])
== "sqlalchemy"
):
engine = "sqlalchemy"
elif module_root.startswith("adbc"):
engine = "adbc"

def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
"""Unpack optionally qualified table name to catalog/schema/table tuple."""
from csv import reader as delimited_read
Expand Down Expand Up @@ -3237,7 +3281,12 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
)
raise ValueError(msg)

with _open_adbc_connection(connection) as conn, conn.cursor() as cursor:
conn = (
_open_adbc_connection(connection)
if isinstance(connection, str)
else connection
)
with conn, conn.cursor() as cursor:
catalog, db_schema, unpacked_table_name = unpack_table_name(table_name)
n_rows: int
if adbc_version >= (0, 7):
Expand Down Expand Up @@ -3265,26 +3314,35 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
)
else:
n_rows = cursor.adbc_ingest(
unpacked_table_name, self.to_arrow(), mode
table_name=unpacked_table_name,
data=self.to_arrow(),
mode=mode,
**(engine_options or {}),
)
conn.commit()
return n_rows

elif engine == "sqlalchemy":
if not _PANDAS_AVAILABLE:
msg = "writing with engine 'sqlalchemy' currently requires pandas.\n\nInstall with: pip install pandas"
msg = "writing with 'sqlalchemy' engine currently requires pandas.\n\nInstall with: pip install pandas"
raise ModuleNotFoundError(msg)
elif parse_version(pd.__version__) < (1, 5):
msg = f"writing with engine 'sqlalchemy' requires pandas 1.5.x or higher, found {pd.__version__!r}"
elif (pd_version := parse_version(pd.__version__)) < (1, 5):
msg = f"writing with 'sqlalchemy' engine requires pandas >= 1.5; found {pd.__version__!r}"
raise ModuleUpgradeRequired(msg)
try:
from sqlalchemy import create_engine
except ModuleNotFoundError as exc:
msg = "'sqlalchemy' not found\n\nInstall with: pip install polars[sqlalchemy]"
raise ModuleNotFoundError(msg) from exc

import_optional(
module_name="sqlalchemy",
min_version=("2.0" if pd_version >= parse_version("2.2") else "1.4"),
min_err_prefix="pandas >= 2.2 requires",
)
# note: the catalog (database) should be a part of the connection string
engine_sa = create_engine(connection)
from sqlalchemy.engine import create_engine

engine_sa = (
create_engine(connection)
if isinstance(connection, str)
else connection.engine # type: ignore[union-attr]
)
catalog, db_schema, unpacked_table_name = unpack_table_name(table_name)
if catalog:
msg = f"Unexpected three-part table name; provide the database/catalog ({catalog!r}) on the connection URI"
Expand All @@ -3300,11 +3358,16 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
con=engine_sa,
if_exists=if_table_exists,
index=False,
**(engine_options or {}),
)
return -1 if res is None else res
else:

elif isinstance(engine, str):
msg = f"engine {engine!r} is not supported"
raise ValueError(msg)
else:
msg = f"unrecognised connection type {connection!r}"
raise TypeError(msg)

@overload
def write_delta(
Expand Down
11 changes: 9 additions & 2 deletions py-polars/polars/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def import_optional(
err_prefix: str = "required package",
err_suffix: str = "not found",
min_version: str | tuple[int, ...] | None = None,
min_err_prefix: str = "requires",
install_message: str | None = None,
) -> Any:
"""
Expand All @@ -244,6 +245,8 @@ def import_optional(
Error suffix to use in the raised exception (follows the module name).
min_version : {str, tuple[int]}, optional
If a minimum module version is required, specify it here.
min_err_prefix : str, optional
Override the standard "requires" prefix for the minimum version error message.
install_message : str, optional
Override the standard "Please install it using..." exception message fragment.
Expand All @@ -268,15 +271,19 @@ def import_optional(
suffix = f" {err_suffix.strip(' ')}" if err_suffix else ""
err_message = f"{prefix}'{module_name}'{suffix}.\n" + (
install_message
or f"Please install it using the command `pip install {module_root}`."
or f"Please install using the command `pip install {module_root}`."
)
raise ModuleNotFoundError(err_message) from None

if min_version:
min_version = parse_version(min_version)
mod_version = parse_version(module.__version__)
if mod_version < min_version:
msg = f"requires {module_root} {min_version} or higher; found {mod_version}"
msg = (
f"{min_err_prefix} {module_root} "
f"{'.'.join(str(v) for v in min_version)} or higher"
f" (found {'.'.join(str(v) for v in mod_version)})"
)
raise ModuleUpgradeRequired(msg)

return module
Expand Down
5 changes: 3 additions & 2 deletions py-polars/polars/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
if TYPE_CHECKING:
import sys

from sqlalchemy import Engine
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import Session

from polars import DataFrame, Expr, LazyFrame, Series
Expand Down Expand Up @@ -247,4 +247,5 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any:
"""Fetch results in batches."""


ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, "Engine", "Session"]
AlchemyConnection = Union["Connection", "Engine", "Session"]
ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, AlchemyConnection]
85 changes: 65 additions & 20 deletions py-polars/tests/unit/io/database/test_write.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import pytest
from sqlalchemy import create_engine

import polars as pl
from polars.io.database._utils import _open_adbc_connection
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
Expand All @@ -17,11 +18,21 @@

@pytest.mark.write_disk()
@pytest.mark.parametrize(
"engine",
("engine", "uri_connection"),
[
"sqlalchemy",
("sqlalchemy", True),
("sqlalchemy", False),
pytest.param(
"adbc",
True,
marks=pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc not available on Windows or <= Python 3.8",
),
),
pytest.param(
"adbc",
False,
marks=pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc not available on Windows or <= Python 3.8",
Expand All @@ -32,7 +43,18 @@
class TestWriteDatabase:
"""Database write tests that share common pytest/parametrize options."""

def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> None:
@staticmethod
def _get_connection(uri: str, engine: DbWriteEngine, uri_connection: bool) -> Any:
if uri_connection:
return uri
elif engine == "sqlalchemy":
return create_engine(uri)
else:
return _open_adbc_connection(uri)

def test_write_database_create(
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Test basic database table creation."""
df = pl.DataFrame(
{
Expand All @@ -42,14 +64,15 @@ def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> N
}
)
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"
test_db_uri = f"sqlite:///{tmp_path}/test_create_{int(uri_connection)}.db"

table_name = "test_create"
conn = self._get_connection(test_db_uri, engine, uri_connection)

assert (
df.write_database(
table_name=table_name,
connection=test_db_uri,
connection=conn,
engine=engine,
)
== 2
Expand All @@ -60,8 +83,11 @@ def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> N
)
assert_frame_equal(result, df)

if hasattr(conn, "close"):
conn.close()

def test_write_database_append_replace(
self, engine: DbWriteEngine, tmp_path: Path
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Test append/replace ops against existing database table."""
df = pl.DataFrame(
Expand All @@ -71,11 +97,11 @@ def test_write_database_append_replace(
"other": [5.5, 7.0, None],
}
)

tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"
table_name = f"test_append_{engine}"
test_db_uri = f"sqlite:///{tmp_path}/test_append_{int(uri_connection)}.db"

table_name = "test_append"
conn = self._get_connection(test_db_uri, engine, uri_connection)

assert (
df.write_database(
Expand Down Expand Up @@ -123,18 +149,26 @@ def test_write_database_append_replace(
)
assert_frame_equal(result, pl.concat([df, df[:2]]))

if hasattr(conn, "close"):
conn.close()

def test_write_database_create_quoted_tablename(
self, engine: DbWriteEngine, tmp_path: Path
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Test parsing/handling of quoted database table names."""
df = pl.DataFrame({"col x": [100, 200, 300], "col y": ["a", "b", "c"]})

df = pl.DataFrame(
{
"col x": [100, 200, 300],
"col y": ["a", "b", "c"],
}
)
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"
test_db_uri = f"sqlite:///{tmp_path}/test_create_quoted.db"

# table name requires quoting, and is qualified with the implicit 'main' schema
qualified_table_name = f'main."test-append-{engine}"'
# table name has some special chars, so requires quoting, and
# is explicitly qualified with the sqlite 'main' schema
qualified_table_name = f'main."test-append-{engine}-{int(uri_connection)}"'
conn = self._get_connection(test_db_uri, engine, uri_connection)

assert (
df.write_database(
Expand All @@ -159,7 +193,12 @@ def test_write_database_create_quoted_tablename(
)
assert_frame_equal(result, df)

def test_write_database_errors(self, engine: DbWriteEngine, tmp_path: Path) -> None:
if hasattr(conn, "close"):
conn.close()

def test_write_database_errors(
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Confirm that expected errors are raised."""
df = pl.DataFrame({"colx": [1, 2, 3]})

Expand All @@ -182,3 +221,9 @@ def test_write_database_errors(self, engine: DbWriteEngine, tmp_path: Path) -> N
if_table_exists="do_something", # type: ignore[arg-type]
engine=engine,
)

with pytest.raises(
TypeError,
match="unrecognised connection type",
):
df.write_database(connection=True, table_name="misc") # type: ignore[arg-type]

0 comments on commit 9aecfd3

Please sign in to comment.