Skip to content

Commit

Permalink
update tests, lint
Browse files Browse the repository at this point in the history
  • Loading branch information
alexander-beedie committed May 7, 2024
1 parent a75500c commit 3f8f6b2
Showing 1 changed file with 132 additions and 126 deletions.
258 changes: 132 additions & 126 deletions py-polars/tests/unit/io/database/test_write.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

import sys
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import pytest
from sqlalchemy import create_engine
Expand All @@ -18,11 +18,21 @@

@pytest.mark.write_disk()
@pytest.mark.parametrize(
"engine",
("engine", "uri_connection"),
[
"sqlalchemy",
("sqlalchemy", True),
("sqlalchemy", False),
pytest.param(
"adbc",
True,
marks=pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc not available on Windows or <= Python 3.8",
),
),
pytest.param(
"adbc",
False,
marks=pytest.mark.skipif(
sys.version_info < (3, 9) or sys.platform == "win32",
reason="adbc not available on Windows or <= Python 3.8",
Expand All @@ -33,7 +43,18 @@
class TestWriteDatabase:
"""Database write tests that share common pytest/parametrize options."""

def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> None:
@staticmethod
def _get_connection(uri: str, engine: DbWriteEngine, uri_connection: bool) -> Any:
if uri_connection:
return uri
elif engine == "sqlalchemy":
return create_engine(uri)
else:
return _open_adbc_connection(uri)

def test_write_database_create(
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Test basic database table creation."""
df = pl.DataFrame(
{
Expand All @@ -43,37 +64,30 @@ def test_write_database_create(self, engine: DbWriteEngine, tmp_path: Path) -> N
}
)
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"
table_name_stub = "test_create"

for idx, conn in enumerate(
(
test_db_uri,
create_engine(test_db_uri),
_open_adbc_connection(test_db_uri),
)
):
table_name = f"{table_name_stub}{idx}"
assert (
df.write_database(
table_name=table_name,
connection=test_db_uri,
engine=engine,
)
== 2
)
result = pl.read_database(
query=f"SELECT * FROM {table_name}",
connection=create_engine(test_db_uri),
test_db_uri = f"sqlite:///{tmp_path}/test_create_{int(uri_connection)}.db"

table_name = "test_create"
conn = self._get_connection(test_db_uri, engine, uri_connection)

assert (
df.write_database(
table_name=table_name,
connection=conn,
engine=engine,
)
assert_frame_equal(result, df)
== 2
)
result = pl.read_database(
query=f"SELECT * FROM {table_name}",
connection=create_engine(test_db_uri),
)
assert_frame_equal(result, df)

if hasattr(conn, "close"):
conn.close()
if hasattr(conn, "close"):
conn.close()

def test_write_database_append_replace(
self, engine: DbWriteEngine, tmp_path: Path
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Test append/replace ops against existing database table."""
df = pl.DataFrame(
Expand All @@ -83,116 +97,108 @@ def test_write_database_append_replace(
"other": [5.5, 7.0, None],
}
)

tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"
table_name_stub = "test_create"

for idx, conn in enumerate(
(
test_db_uri,
create_engine(test_db_uri),
_open_adbc_connection(test_db_uri),
)
):
table_name = f"{table_name_stub}{idx}"
assert (
df.write_database(
table_name=table_name,
connection=test_db_uri,
engine=engine,
)
== 3
)
with pytest.raises(Exception): # noqa: B017
df.write_database(
table_name=table_name,
connection=test_db_uri,
if_table_exists="fail",
engine=engine,
)

assert (
df.write_database(
table_name=table_name,
connection=test_db_uri,
if_table_exists="replace",
engine=engine,
)
== 3
test_db_uri = f"sqlite:///{tmp_path}/test_append_{int(uri_connection)}.db"

table_name = "test_append"
conn = self._get_connection(test_db_uri, engine, uri_connection)

assert (
df.write_database(
table_name=table_name,
connection=test_db_uri,
engine=engine,
)
result = pl.read_database(
query=f"SELECT * FROM {table_name}",
connection=create_engine(test_db_uri),
== 3
)
with pytest.raises(Exception): # noqa: B017
df.write_database(
table_name=table_name,
connection=test_db_uri,
if_table_exists="fail",
engine=engine,
)
assert_frame_equal(result, df)

assert (
df[:2].write_database(
table_name=table_name,
connection=test_db_uri,
if_table_exists="append",
engine=engine,
)
== 2

assert (
df.write_database(
table_name=table_name,
connection=test_db_uri,
if_table_exists="replace",
engine=engine,
)
result = pl.read_database(
query=f"SELECT * FROM {table_name}",
connection=create_engine(test_db_uri),
== 3
)
result = pl.read_database(
query=f"SELECT * FROM {table_name}",
connection=create_engine(test_db_uri),
)
assert_frame_equal(result, df)

assert (
df[:2].write_database(
table_name=table_name,
connection=test_db_uri,
if_table_exists="append",
engine=engine,
)
assert_frame_equal(result, pl.concat([df, df[:2]]))
== 2
)
result = pl.read_database(
query=f"SELECT * FROM {table_name}",
connection=create_engine(test_db_uri),
)
assert_frame_equal(result, pl.concat([df, df[:2]]))

if hasattr(conn, "close"):
conn.close()
if hasattr(conn, "close"):
conn.close()

def test_write_database_create_quoted_tablename(
self, engine: DbWriteEngine, tmp_path: Path
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Test parsing/handling of quoted database table names."""
df = pl.DataFrame({"col x": [100, 200, 300], "col y": ["a", "b", "c"]})

df = pl.DataFrame(
{
"col x": [100, 200, 300],
"col y": ["a", "b", "c"],
}
)
tmp_path.mkdir(exist_ok=True)
test_db = str(tmp_path / f"test_{engine}.db")
test_db_uri = f"sqlite:///{test_db}"

for idx, conn in enumerate(
(
test_db_uri,
create_engine(test_db_uri),
_open_adbc_connection(test_db_uri),
)
):
# table name has some special chars, so requires quoting, and
# is expliocitly qualified with the sqlite 'main' schema
qualified_table_name = f'main."test-append-{engine}{idx}"'
assert (
df.write_database(
table_name=qualified_table_name,
connection=test_db_uri,
engine=engine,
)
== 3
)
assert (
df.write_database(
table_name=qualified_table_name,
connection=test_db_uri,
if_table_exists="replace",
engine=engine,
)
== 3
test_db_uri = f"sqlite:///{tmp_path}/test_create_quoted.db"

# table name has some special chars, so requires quoting, and
# is explicitly qualified with the sqlite 'main' schema
qualified_table_name = f'main."test-append-{engine}-{int(uri_connection)}"'
conn = self._get_connection(test_db_uri, engine, uri_connection)

assert (
df.write_database(
table_name=qualified_table_name,
connection=test_db_uri,
engine=engine,
)
result = pl.read_database(
query=f"SELECT * FROM {qualified_table_name}",
connection=create_engine(test_db_uri),
== 3
)
assert (
df.write_database(
table_name=qualified_table_name,
connection=test_db_uri,
if_table_exists="replace",
engine=engine,
)
assert_frame_equal(result, df)
== 3
)
result = pl.read_database(
query=f"SELECT * FROM {qualified_table_name}",
connection=create_engine(test_db_uri),
)
assert_frame_equal(result, df)

if hasattr(conn, "close"):
conn.close()
if hasattr(conn, "close"):
conn.close()

def test_write_database_errors(self, engine: DbWriteEngine, tmp_path: Path) -> None:
def test_write_database_errors(
self, engine: DbWriteEngine, uri_connection: bool, tmp_path: Path
) -> None:
"""Confirm that expected errors are raised."""
df = pl.DataFrame({"colx": [1, 2, 3]})

Expand Down

0 comments on commit 3f8f6b2

Please sign in to comment.