Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python): Support passing instantiated adbc/alchemy connection objects to write_database #16099

Merged
merged 3 commits into from
May 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
97 changes: 80 additions & 17 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
ColumnWidthsDefinition,
ComparisonOperator,
ConditionalFormatDict,
ConnectionOrCursor,
CsvQuoteStyle,
DbWriteEngine,
FillNullStrategy,
Expand Down Expand Up @@ -3161,13 +3162,14 @@ def write_parquet(
def write_database(
self,
table_name: str,
connection: str,
connection: ConnectionOrCursor | str,
*,
if_table_exists: DbWriteMode = "fail",
engine: DbWriteEngine = "sqlalchemy",
engine: DbWriteEngine | None = None,
engine_options: dict[str, Any] | None = None,
) -> int:
"""
Write a polars frame to a database.
Write the data in a Polars DataFrame to a database.

Parameters
----------
Expand All @@ -3176,7 +3178,8 @@ def write_database(
SQL database. If your table name contains special characters, it should
be quoted.
connection
Connection URI string, for example:
An existing SQLAlchemy or ADBC connection against the target database, or
a URI string that will be used to instantiate such a connection, such as:

* "postgresql://user:pass@server:port/database"
* "sqlite:////path/to/database.db"
Expand All @@ -3187,7 +3190,38 @@ def write_database(
* 'append' will append to an existing table.
* 'fail' will fail if table already exists.
engine : {'sqlalchemy', 'adbc'}
Select the engine to use for writing frame data.
Select the engine to use for writing frame data; only necessary when
supplying a URI string (defaults to 'sqlalchemy' if unset)
engine_options
Additional options to pass to the engine's associated insert method:

* "sqlalchemy" - currently inserts using Pandas' `to_sql` method, though
this will eventually be phased out in favour of a native solution.
* "adbc" - inserts using the ADBC cursor's `adbc_ingest` method.

Examples
--------
Insert into a temporary table using a PostgreSQL URI and the ADBC engine:

>>> df.write_database(
... table_name="target_table",
... connection="postgresql://user:pass@server:port/database",
... engine="adbc",
... engine_options={"temporary": True},
... ) # doctest: +SKIP

Insert into a table using a `pyodbc` SQLAlchemy connection to SQL Server
that was instantiated with "fast_executemany=True" to improve performance:

>>> pyodbc_uri = (
... "mssql+pyodbc://user:pass@server:1433/test?"
... "driver=ODBC+Driver+18+for+SQL+Server"
... )
>>> engine = create_engine(pyodbc_uri, fast_executemany=True) # doctest: +SKIP
>>> df.write_database(
... table_name="target_table",
... connection=engine,
... ) # doctest: +SKIP

Returns
-------
Expand All @@ -3200,6 +3234,16 @@ def write_database(
msg = f"write_database `if_table_exists` must be one of {{{allowed}}}, got {if_table_exists!r}"
raise ValueError(msg)

if engine is None:
if (
isinstance(connection, str)
or (module_root := type(connection).__module__.split(".", 1)[0])
== "sqlalchemy"
):
engine = "sqlalchemy"
elif module_root.startswith("adbc"):
engine = "adbc"

def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
"""Unpack optionally qualified table name to catalog/schema/table tuple."""
from csv import reader as delimited_read
Expand Down Expand Up @@ -3237,7 +3281,12 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
)
raise ValueError(msg)

with _open_adbc_connection(connection) as conn, conn.cursor() as cursor:
conn = (
_open_adbc_connection(connection)
if isinstance(connection, str)
else connection
)
with conn, conn.cursor() as cursor:
catalog, db_schema, unpacked_table_name = unpack_table_name(table_name)
n_rows: int
if adbc_version >= (0, 7):
Expand Down Expand Up @@ -3265,26 +3314,35 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
)
else:
n_rows = cursor.adbc_ingest(
unpacked_table_name, self.to_arrow(), mode
table_name=unpacked_table_name,
data=self.to_arrow(),
mode=mode,
**(engine_options or {}),
)
conn.commit()
return n_rows

elif engine == "sqlalchemy":
if not _PANDAS_AVAILABLE:
msg = "writing with engine 'sqlalchemy' currently requires pandas.\n\nInstall with: pip install pandas"
msg = "writing with 'sqlalchemy' engine currently requires pandas.\n\nInstall with: pip install pandas"
raise ModuleNotFoundError(msg)
elif parse_version(pd.__version__) < (1, 5):
msg = f"writing with engine 'sqlalchemy' requires pandas 1.5.x or higher, found {pd.__version__!r}"
elif (pd_version := parse_version(pd.__version__)) < (1, 5):
msg = f"writing with 'sqlalchemy' engine requires pandas >= 1.5; found {pd.__version__!r}"
raise ModuleUpgradeRequired(msg)
try:
from sqlalchemy import create_engine
except ModuleNotFoundError as exc:
msg = "'sqlalchemy' not found\n\nInstall with: pip install polars[sqlalchemy]"
raise ModuleNotFoundError(msg) from exc

import_optional(
module_name="sqlalchemy",
min_version=("2.0" if pd_version >= parse_version("2.2") else "1.4"),
min_err_prefix="pandas >= 2.2 requires",
)
# note: the catalog (database) should be a part of the connection string
engine_sa = create_engine(connection)
from sqlalchemy.engine import create_engine

engine_sa = (
create_engine(connection)
if isinstance(connection, str)
else connection.engine # type: ignore[union-attr]
)
catalog, db_schema, unpacked_table_name = unpack_table_name(table_name)
if catalog:
msg = f"Unexpected three-part table name; provide the database/catalog ({catalog!r}) on the connection URI"
Expand All @@ -3300,11 +3358,16 @@ def unpack_table_name(name: str) -> tuple[str | None, str | None, str]:
con=engine_sa,
if_exists=if_table_exists,
index=False,
**(engine_options or {}),
)
return -1 if res is None else res
else:

elif isinstance(engine, str):
msg = f"engine {engine!r} is not supported"
raise ValueError(msg)
else:
msg = f"unrecognised connection type {connection!r}"
raise TypeError(msg)

@overload
def write_delta(
Expand Down
11 changes: 9 additions & 2 deletions py-polars/polars/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,7 @@ def import_optional(
err_prefix: str = "required package",
err_suffix: str = "not found",
min_version: str | tuple[int, ...] | None = None,
min_err_prefix: str = "requires",
install_message: str | None = None,
) -> Any:
"""
Expand All @@ -244,6 +245,8 @@ def import_optional(
Error suffix to use in the raised exception (follows the module name).
min_version : {str, tuple[int]}, optional
If a minimum module version is required, specify it here.
min_err_prefix : str, optional
Override the standard "requires" prefix for the minimum version error message.
install_message : str, optional
Override the standard "Please install it using..." exception message fragment.

Expand All @@ -268,15 +271,19 @@ def import_optional(
suffix = f" {err_suffix.strip(' ')}" if err_suffix else ""
err_message = f"{prefix}'{module_name}'{suffix}.\n" + (
install_message
or f"Please install it using the command `pip install {module_root}`."
or f"Please install using the command `pip install {module_root}`."
)
raise ModuleNotFoundError(err_message) from None

if min_version:
min_version = parse_version(min_version)
mod_version = parse_version(module.__version__)
if mod_version < min_version:
msg = f"requires {module_root} {min_version} or higher; found {mod_version}"
msg = (
f"{min_err_prefix} {module_root} "
f"{'.'.join(str(v) for v in min_version)} or higher"
f" (found {'.'.join(str(v) for v in mod_version)})"
)
raise ModuleUpgradeRequired(msg)

return module
Expand Down
5 changes: 3 additions & 2 deletions py-polars/polars/type_aliases.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
if TYPE_CHECKING:
import sys

from sqlalchemy import Engine
from sqlalchemy.engine import Connection, Engine
from sqlalchemy.orm import Session

from polars import DataFrame, Expr, LazyFrame, Series
Expand Down Expand Up @@ -247,4 +247,5 @@ def fetchmany(self, *args: Any, **kwargs: Any) -> Any:
"""Fetch results in batches."""


ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, "Engine", "Session"]
AlchemyConnection = Union["Connection", "Engine", "Session"]
ConnectionOrCursor = Union[BasicConnection, BasicCursor, Cursor, AlchemyConnection]
85 changes: 65 additions & 20 deletions py-polars/tests/unit/io/database/test_write.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import pytest
from sqlalchemy import create_engine

import polars as pl
from polars.io.database._utils import _open_adbc_connection
from polars.testing import assert_frame_equal

if TYPE_CHECKING:
Expand All @@ -17,11 +18,21 @@

@pytest.mark.write_disk()
@pytest.mark.parametrize(
"engine",
("engine", "uri_connection"),
[
"sqlalchemy",
("sqlalchemy", True),
("sqlalchemy", False),
pytest.param(
"adbc",
True,
marks=pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc not available on Windows or <= Python 3.8",
),
),
pytest.param(
"adbc",
False,
marks=pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc not available on Windows or <= Python 3.8",
Expand All @@ -32,7 +43,18 @@
class TestWriteDatabase:
"""Database write tests that share common pytest/parametrize options."""

def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> None:
@staticmethod
def _get_connection(uri: str, engine: DbWriteEngine, uri_connection: bool) -> Any:
if uri_connection:
return uri
elif engine == "sqlalchemy":
return create_engine(uri)
else:
return _open_adbc_connection(uri)

def test_write_database_create(
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Test basic database table creation."""
df = pl.DataFrame(
{
Expand All @@ -42,14 +64,15 @@ def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> N
}
)
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"
test_db_uri = f"sqlite:///{tmp_path}/test_create_{int(uri_connection)}.db"

table_name = "test_create"
conn = self._get_connection(test_db_uri, engine, uri_connection)

assert (
df.write_database(
table_name=table_name,
connection=test_db_uri,
connection=conn,
engine=engine,
)
== 2
Expand All @@ -60,8 +83,11 @@ def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> N
)
assert_frame_equal(result, df)

if hasattr(conn, "close"):
conn.close()

def test_write_database_append_replace(
self, engine: DbWriteEngine, tmp_path: Path
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Test append/replace ops against existing database table."""
df = pl.DataFrame(
Expand All @@ -71,11 +97,11 @@ def test_write_database_append_replace(
"other": [5.5, 7.0, None],
}
)

tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"
table_name = f"test_append_{engine}"
test_db_uri = f"sqlite:///{tmp_path}/test_append_{int(uri_connection)}.db"

table_name = "test_append"
conn = self._get_connection(test_db_uri, engine, uri_connection)

assert (
df.write_database(
Expand Down Expand Up @@ -123,18 +149,26 @@ def test_write_database_append_replace(
)
assert_frame_equal(result, pl.concat([df, df[:2]]))

if hasattr(conn, "close"):
conn.close()

def test_write_database_create_quoted_tablename(
self, engine: DbWriteEngine, tmp_path: Path
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Test parsing/handling of quoted database table names."""
df = pl.DataFrame({"col x": [100, 200, 300], "col y": ["a", "b", "c"]})

df = pl.DataFrame(
{
"col x": [100, 200, 300],
"col y": ["a", "b", "c"],
}
)
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"
test_db_uri = f"sqlite:///{tmp_path}/test_create_quoted.db"

# table name requires quoting, and is qualified with the implicit 'main' schema
qualified_table_name = f'main."test-append-{engine}"'
# table name has some special chars, so requires quoting, and
# is explicitly qualified with the sqlite 'main' schema
qualified_table_name = f'main."test-append-{engine}-{int(uri_connection)}"'
conn = self._get_connection(test_db_uri, engine, uri_connection)

assert (
df.write_database(
Expand All @@ -159,7 +193,12 @@ def test_write_database_create_quoted_tablename(
)
assert_frame_equal(result, df)

def test_write_database_errors(self, engine: DbWriteEngine, tmp_path: Path) -> None:
if hasattr(conn, "close"):
conn.close()

def test_write_database_errors(
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Confirm that expected errors are raised."""
df = pl.DataFrame({"colx": [1, 2, 3]})

Expand All @@ -182,3 +221,9 @@ def test_write_database_errors(self, engine: DbWriteEngine, tmp_path: Path) -> N
if_table_exists="do_something", # type: ignore[arg-type]
engine=engine,
)

with pytest.raises(
TypeError,
match="unrecognised connection type",
):
df.write_database(connection=True, table_name="misc") # type: ignore[arg-type]