diff --git a/pyproject.toml b/pyproject.toml index 8e7e872f80..d311cd1e41 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,7 @@ check-hidden = true ignore-words-list = "crate,arithmetics,ser" # Feel free to un-skip examples, and experimental, you will just need to # work through many typos (--write-changes and --interactive will help) -skip = "tests/series/*,target,.git,.venv,venv,data,*.csv,*.csv.*,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb,*.tiktoken,*.sql,tests/table/utf8/*,tests/table/binary/*" +skip = "tests/series/*,tests/table/utf8/*,tests/table/binary/*,target,.git,.venv,venv,data,*.csv,*.csv.*,*.html,*.json,*.jsonl,*.pdf,*.txt,*.ipynb,*.tiktoken,*.sql" [tool.maturin] # "python" tells pyo3 we want to build an extension module (skips linking against libpython.so) diff --git a/tests/table/utf8/test_concat.py b/tests/table/utf8/test_concat.py new file mode 100644 index 0000000000..d3bda39d8e --- /dev/null +++ b/tests/table/utf8/test_concat.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import pytest + +from daft.expressions import col, lit +from daft.table import MicroPartition + + +@pytest.mark.parametrize( + "input_a,input_b,expected_result", + [ + # Basic ASCII concatenation + ( + ["Hello", "Test", "", "End"], + [" World", "ing", "Empty", "!"], + ["Hello World", "Testing", "Empty", "End!"], + ), + # Special UTF-8 sequences + ( + [ + "☃", # Snowman + "😉", # Winking face + "🌈", # Rainbow + "Hello☃", # String with UTF-8 + "Hello\u0000", # String with null + ], + [ + "😉", # Winking face + "☃", # Snowman + "☃", # Snowman + "World", # ASCII + "\u0000World", # Null and string + ], + [ + "☃😉", # Snowman + Winking face + "😉☃", # Winking face + Snowman + "🌈☃", # Rainbow + Snowman + "Hello☃World", # String with UTF-8 + ASCII + "Hello\u0000\u0000World", # String with multiple nulls + ], + ), + # Nulls and empty strings + ( + ["Hello", None, "", "Test", None, "End", ""], + [" World", "!", None, None, "ing", "", "Empty"], + ["Hello World", None, None, None, None, "End", "Empty"], + ), + # Mixed length concatenation + ( + ["a", "ab", "abc", "abcd"], + ["1", "12", "123", "1234"], + ["a1", "ab12", "abc123", "abcd1234"], + ), + # Empty string combinations + ( + ["", "", "Hello", "World", ""], + ["", "Test", "", "", "!"], + ["", "Test", "Hello", "World", "!"], + ), + # Complex UTF-8 sequences + ( + [ + "☃", # Snowman + "😉", # Winking face + "🌈", # Rainbow + "☃😉", # Snowman + Winking face + ], + [ + "😉", # Winking face + "☃", # Snowman + "☃", # Snowman + "🌈", # Rainbow + ], + [ + "☃😉", # Snowman + Winking face + "😉☃", # Winking face + Snowman + "🌈☃", # Rainbow + Snowman + "☃😉🌈", # Snowman + Winking face + Rainbow + ], + ), + # Null characters in different positions + ( + [ + "\u0000abc", # Leading null + "abc\u0000", # Trailing null + "ab\u0000c", # Middle null + "\u0000ab\u0000c\u0000", # Multiple nulls + ], + [ + "def\u0000", # Trailing null + "\u0000def", # Leading null + "d\u0000ef", # Middle null + "\u0000de\u0000f\u0000", # Multiple nulls + ], + [ + "\u0000abcdef\u0000", # Nulls at ends + "abc\u0000\u0000def", # Adjacent nulls + "ab\u0000cd\u0000ef", # Separated nulls + "\u0000ab\u0000c\u0000\u0000de\u0000f\u0000", # Many nulls + ], + ), + ], +) +def test_utf8_concat(input_a: list[str | None], input_b: list[str | None], expected_result: list[str | None]) -> None: + table = MicroPartition.from_pydict({"a": input_a, "b": input_b}) + result = table.eval_expression_list([col("a").str.concat(col("b"))]) + assert result.to_pydict() == {"a": expected_result} + + 
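The expected_result tables above encode Daft's elementwise, None-propagating concat semantics. As a minimal pure-Python reference of that contract, handy for sanity-checking new cases (ref_concat is a hypothetical helper, not part of this patch):

from __future__ import annotations

def ref_concat(a: list[str | None], b: list[str | None]) -> list[str | None]:
    # Pairwise concatenation; a None on either side yields None, as the null cases above assert.
    return [x + y if x is not None and y is not None else None for x, y in zip(a, b)]

assert ref_concat(["Hello", None, ""], [" World", "!", "Empty"]) == ["Hello World", None, "Empty"]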
+@pytest.mark.parametrize( + "input_data,literal,expected_result", + [ + # Basic broadcasting + ( + ["Hello", "Goodbye", "Test"], + " World!", + ["Hello World!", "Goodbye World!", "Test World!"], + ), + # Broadcasting with nulls + ( + ["Hello", None, "Test"], + " World!", + ["Hello World!", None, "Test World!"], + ), + # Broadcasting with UTF-8 sequences + ( + ["Hello", "Test", "Goodbye"], + "☃", # Snowman + ["Hello☃", "Test☃", "Goodbye☃"], + ), + # Broadcasting with null characters + ( + ["Hello", "Test\u0000", "\u0000World"], + "\u0000", + ["Hello\u0000", "Test\u0000\u0000", "\u0000World\u0000"], + ), + # Broadcasting with empty strings + ( + ["", "Test", ""], + "☃", + ["☃", "Test☃", "☃"], + ), + # Broadcasting with complex UTF-8 + ( + ["Hello", "Test", "Goodbye"], + "☃😉🌈", # Snowman + Winking face + Rainbow + ["Hello☃😉🌈", "Test☃😉🌈", "Goodbye☃😉🌈"], + ), + # Broadcasting with literal None + ( + ["Hello", None, "Test", ""], + None, + [None, None, None, None], # Any concat with None should result in None + ), + ], +) +def test_utf8_concat_broadcast( + input_data: list[str | None], literal: str | None, expected_result: list[str | None] +) -> None: + # Test right-side broadcasting + table = MicroPartition.from_pydict({"a": input_data}) + result = table.eval_expression_list([col("a").str.concat(literal)]) + assert result.to_pydict() == {"a": expected_result} + + # Test left-side broadcasting + table = MicroPartition.from_pydict({"b": input_data}) + result = table.eval_expression_list([lit(literal).str.concat(col("b"))]) + if literal is None: + # When literal is None, all results should be None + assert result.to_pydict() == {"literal": [None] * len(input_data)} + else: + assert result.to_pydict() == { + "literal": [ + lit + data if data is not None else None for lit, data in zip([literal] * len(input_data), input_data) + ] + } + + +def test_utf8_concat_edge_cases() -> None: + # Test various edge cases + table = MicroPartition.from_pydict( + { + "a": [ + "", # Empty string + "\u0000", # Single null character + "Hello", # Normal string + None, # Null value + "☃", # UTF-8 sequence + "😉", # Another UTF-8 sequence + ], + "b": [ + "", # Empty + Empty + "\u0000", # Null + Null + "", # Normal + Empty + None, # Null + Null + "😉", # UTF-8 + UTF-8 + "☃", # UTF-8 + UTF-8 + ], + } + ) + result = table.eval_expression_list([col("a").str.concat(col("b"))]) + assert result.to_pydict() == { + "a": [ + "", # Empty + Empty = Empty + "\u0000\u0000", # Null + Null = Two nulls + "Hello", # Normal + Empty = Normal + None, # Null + Null = Null + "☃😉", # Snowman + Winking face + "😉☃", # Winking face + Snowman + ] + } diff --git a/tests/table/utf8/test_length.py b/tests/table/utf8/test_length.py index c96815ed64..8925d493fc 100644 --- a/tests/table/utf8/test_length.py +++ b/tests/table/utf8/test_length.py @@ -1,10 +1,101 @@ from __future__ import annotations +import pytest + from daft.expressions import col from daft.table import MicroPartition -def test_utf8_length(): - table = MicroPartition.from_pydict({"col": ["foo", None, "barbaz", "quux", "😉test", ""]}) +@pytest.mark.parametrize( + "input_data,expected_lengths", + [ + # Basic ASCII strings + ( + ["Hello", "World!", "", "Test"], + [5, 6, 0, 4], + ), + # Special UTF-8 sequences + ( + [ + "☃", # UTF-8 encoded snowman + "😉", # UTF-8 encoded winking face + "🌈", # UTF-8 encoded rainbow + "Hello☃World", # Mixed ASCII and UTF-8 + "☃😉🌈", # Multiple UTF-8 characters + "Hello\u0000World", # String with null character + ], + [1, 1, 1, 11, 3, 11], + ), + # Nulls and empty 
strings + ( + ["Hello", None, "", "\u0000", None, "Test", ""], + [5, None, 0, 1, None, 4, 0], + ), + # Large strings + ( + ["x" * 1000, "y" * 10000, "z" * 100000], + [1000, 10000, 100000], + ), + # Mixed strings with different sizes + ( + [ + "a", # Single character + "ab", # Two characters + "abc", # Three characters + "☃", # Single UTF-8 character + "☃☃", # Two UTF-8 characters + "☃☃☃", # Three UTF-8 characters + ], + [1, 2, 3, 1, 2, 3], + ), + # Strings with repeated patterns + ( + [ + "\u0000" * 5, # Repeated null characters + "ab" * 5, # Repeated ASCII pattern + "☃" * 5, # Repeated UTF-8 snowman + "😉" * 5, # Repeated UTF-8 winking face + ], + [5, 10, 5, 5], + ), + # Edge cases with single characters + ( + [ + "\u0000", # Null character + "\u0001", # Start of heading + "\u001f", # Unit separator + " ", # Space + "\u007f", # Delete + "☃", # Snowman + "😉", # Winking face + ], + [1, 1, 1, 1, 1, 1, 1], + ), + # Complex UTF-8 sequences + ( + [ + "☃", # Snowman + "😉", # Winking face + "☃😉", # Snowman + Winking face + "🌈", # Rainbow + "🌈☃", # Rainbow + Snowman + "☃🌈😉", # Snowman + Rainbow + Winking face + ], + [1, 1, 2, 1, 2, 3], + ), + # Mixed content lengths + ( + [ + "Hello☃World", # ASCII + UTF-8 + ASCII + "\u0000Hello\u0000World\u0000", # Null-separated + "☃Hello☃World☃", # UTF-8-separated + "Hello😉World", # ASCII + UTF-8 + ASCII + ], + [11, 13, 13, 11], + ), + ], +) +def test_utf8_length(input_data: list[str | None], expected_lengths: list[int | None]) -> None: + table = MicroPartition.from_pydict({"col": input_data}) result = table.eval_expression_list([col("col").str.length()]) - assert result.to_pydict() == {"col": [3, None, 6, 4, 5, 0]} + assert result.to_pydict() == {"col": expected_lengths} diff --git a/tests/table/utf8/test_substr.py b/tests/table/utf8/test_substr.py index 9989e9bd86..311c9aa075 100644 --- a/tests/table/utf8/test_substr.py +++ b/tests/table/utf8/test_substr.py @@ -1,10 +1,412 @@ from __future__ import annotations +import pytest + +from daft import DataType from daft.expressions import col from daft.table import MicroPartition -def test_utf8_substr(): - table = MicroPartition.from_pydict({"col": ["foo", None, "barbarbar", "quux", "1", ""]}) +def test_utf8_substr() -> None: + table = MicroPartition.from_pydict( + { + "col": [ + "foo", + None, + "barbarbar", + "quux", + "1", + "", + "Hello☃World", # UTF-8 character in middle + "😉test", # UTF-8 character at start + "test🌈", # UTF-8 character at end + "☃😉🌈", # Multiple UTF-8 characters + "Hello\u0000World", # Null character + ] + } + ) result = table.eval_expression_list([col("col").str.substr(0, 5)]) - assert result.to_pydict() == {"col": ["foo", None, "barba", "quux", "1", None]} + assert result.to_pydict() == { + "col": [ + "foo", + None, + "barba", + "quux", + "1", + None, + "Hello", # Should handle UTF-8 correctly + "😉test", # Should include full emoji + "test🌈", # Should include full emoji + "☃😉🌈", # Should include all characters + "Hello", # Should handle null character + ] + } + + +@pytest.mark.parametrize( + "input_data,start_data,length_data,expected_result", + [ + # Test with column for start position + ( + ["hello", "world", "test", "Hello☃World", "😉test", "test🌈"], + [1, 0, 2, 5, 1, 4], + 3, + ["ell", "wor", "st", "☃Wo", "tes", "🌈"], + ), + # Test with column for length + ( + ["hello", "world", "test", "Hello☃World", "😉best", "test🌈"], + 1, + [2, 3, 4, 5, 2, 1], + ["el", "orl", "est", "ello☃", "be", "e"], + ), + # Test with both start and length as columns + ( + ["hello", "world", "test", "Hello☃World", 
"😉test", "test🌈"], + [1, 0, 2, 5, 1, 4], + [2, 3, 1, 2, 3, 1], + ["el", "wor", "s", "☃W", "tes", "🌈"], + ), + # Test with nulls in start column + ( + ["hello", "world", "test", "Hello☃World", "😉test", "test🌈"], + [1, None, 2, None, 1, None], + 3, + ["ell", None, "st", None, "tes", None], + ), + # Test with nulls in length column + ( + ["hello", "world", "test", "Hello☃World", "😉best", "test🌈"], + 1, + [2, None, 4, None, 2, None], + ["el", "orld", "est", "ello☃World", "be", "est🌈"], + ), + # Test with nulls in both columns + ( + ["hello", "world", "test", "Hello☃World", "😉test", "test🌈"], + [1, None, 2, 5, None, 4], + [2, 3, None, None, 2, None], + ["el", None, "st", "☃World", None, "🌈"], + ), + ], +) +def test_utf8_substr_with_columns( + input_data: list[str | None], + start_data: list[int | None] | int, + length_data: list[int | None] | int, + expected_result: list[str | None], +) -> None: + table_data = {"col": input_data} + if isinstance(start_data, list): + table_data["start"] = start_data + start = col("start") + else: + start = start_data + + if isinstance(length_data, list): + table_data["length"] = length_data + length = col("length") + else: + length = length_data + + table = MicroPartition.from_pydict(table_data) + result = table.eval_expression_list([col("col").str.substr(start, length)]) + assert result.to_pydict() == {"col": expected_result} + + +@pytest.mark.parametrize( + "input_data,start,length,expected_result", + [ + # Test start beyond string length + ( + ["hello", "world", "Hello☃World", "😉test", "test🌈"], + [10, 20, 15, 10, 10], + 2, + [None, None, None, None, None], + ), + # Test start way beyond string length + ( + [ + "hello", # len 5 + "world", # len 5 + "test", # len 4 + "☃😉🌈", # len 3 + ], + [100, 1000, 50, 25], + 5, + [None, None, None, None], + ), + # Test start beyond length with None length + ( + [ + "hello", + "world", + "test", + "☃😉🌈", + ], + [10, 20, 15, 8], + None, + [None, None, None, None], + ), + # Test zero length + ( + ["hello", "world", "Hello☃World", "😉test", "test🌈"], + [1, 0, 5, 0, 4], + 0, + [None, None, None, None, None], + ), + # Test very large length + ( + ["hello", "world", "Hello☃World", "😉test", "test🌈"], + [0, 1, 5, 0, 4], + 100, + ["hello", "orld", "☃World", "😉test", "🌈"], + ), + # Test empty strings + ( + ["", "", ""], + [0, 1, 2], + 3, + [None, None, None], + ), + # Test start + length overflow + ( + ["hello", "world", "Hello☃World", "😉test", "test🌈"], + [2, 3, 5, 0, 4], + 9999999999, + ["llo", "ld", "☃World", "😉test", "🌈"], + ), + # Test UTF-8 character boundaries + ( + ["Hello☃World", "😉test", "test🌈", "☃😉🌈"], + [4, 0, 3, 1], + 2, + ["o☃", "😉t", "t🌈", "😉🌈"], + ), + ], +) +def test_utf8_substr_edge_cases( + input_data: list[str], + start: list[int], + length: int, + expected_result: list[str | None], +) -> None: + table = MicroPartition.from_pydict({"col": input_data, "start": start}) + result = table.eval_expression_list([col("col").str.substr(col("start"), length)]) + assert result.to_pydict() == {"col": expected_result} + + +def test_utf8_substr_errors() -> None: + # Test negative start + table = MicroPartition.from_pydict({"col": ["hello", "world", "Hello☃World"], "start": [-1, -2, -3]}) + with pytest.raises(Exception, match="Error in repeat: failed to cast length as usize"): + table.eval_expression_list([col("col").str.substr(col("start"), 2)]) + + # Test negative length + table = MicroPartition.from_pydict({"col": ["hello", "world", "Hello☃World"]}) + with pytest.raises(Exception, match="Error in substr: failed to cast length 
as usize -3"): + table.eval_expression_list([col("col").str.substr(0, -3)]) + + # Test both negative + table = MicroPartition.from_pydict({"col": ["hello", "world", "Hello☃World"], "start": [-2, -1, -3]}) + with pytest.raises(Exception, match="Error in substr: failed to cast length as usize -2"): + table.eval_expression_list([col("col").str.substr(col("start"), -2)]) + + # Test negative length in column + table = MicroPartition.from_pydict({"col": ["hello", "world", "Hello☃World"], "length": [-2, -3, -4]}) + with pytest.raises(Exception, match="Error in repeat: failed to cast length as usize"): + table.eval_expression_list([col("col").str.substr(0, col("length"))]) + + +def test_utf8_substr_computed() -> None: + # Test with computed start index (length - 5) + table = MicroPartition.from_pydict( + { + "col": [ + "hello world", # len=11, start=6, expect "wor" + "python programming", # len=17, start=12, expect "mmi" + "data science", # len=12, start=7, expect "ien" + "artificial", # len=10, start=5, expect "ici" + "intelligence", # len=12, start=7, expect "gen" + "Hello☃World", # len=11, start=6, expect "Wor" + "test😉best", # len=9, start=4, expect "😉be" + "test🌈best", # len=9, start=4, expect "🌈be" + ] + } + ) + result = table.eval_expression_list( + [ + col("col").str.substr( + (col("col").str.length() - 5).cast(DataType.int32()), # start 5 chars from end + 3, # take 3 chars + ) + ] + ) + assert result.to_pydict() == {"col": ["wor", "mmi", "ien", "ici", "gen", "Wor", "😉be", "🌈be"]} + + # Test with computed length (half of string length) + table = MicroPartition.from_pydict( + { + "col": [ + "hello world", # len=11, len/2=5, expect "hello" + "python programming", # len=17, len/2=8, expect "python pr" + "data science", # len=12, len/2=6, expect "data s" + "artificial", # len=10, len/2=5, expect "artif" + "intelligence", # len=12, len/2=6, expect "intell" + "Hello☃World", # len=11, len/2=5, expect "Hello" + "test😉test", # len=9, len/2=4, expect "test" + "test🌈test", # len=9, len/2=4, expect "test" + ] + } + ) + result = table.eval_expression_list( + [ + col("col").str.substr( + 0, # start from beginning + (col("col").str.length() / 2).cast(DataType.int32()), # take half of string + ) + ] + ) + assert result.to_pydict() == {"col": ["hello", "python pr", "data s", "artif", "intell", "Hello", "test", "test"]} + + # Test with both computed start and length + table = MicroPartition.from_pydict( + { + "col": [ + "hello world", # len=11, start=2, len=3, expect "llo" + "python programming", # len=17, start=3, len=5, expect "hon pr" + "data science", # len=12, start=2, len=4, expect "ta s" + "artificial", # len=10, start=2, len=3, expect "tif" + "intelligence", # len=12, start=2, len=4, expect "tell" + "Hello☃World", # len=11, start=2, len=3, expect "llo" + "test😉test", # len=9, start=1, len=3, expect "est😉" + "test🌈test", # len=9, start=1, len=3, expect "est🌈" + ] + } + ) + result = table.eval_expression_list( + [ + col("col").str.substr( + (col("col").str.length() / 5).cast(DataType.int32()), # start at 1/5 of string + (col("col").str.length() / 3).cast(DataType.int32()), # take 1/3 of string + ) + ] + ) + assert result.to_pydict() == {"col": ["llo", "hon pr", "ta s", "tif", "tell", "llo", "est", "est"]} + + +def test_utf8_substr_multiple_slices() -> None: + # Test taking multiple different substrings from the same string data + table = MicroPartition.from_pydict( + { + "col": [ + "hello", # Simple ASCII + "Hello☃World", # With UTF-8 character + "😉test🌈", # Multiple UTF-8 characters + "data science", 
# With space + ] + } + ) + + # Get multiple slices + result = table.eval_expression_list( + [ + col("col").str.substr(1, 3).alias("slice1"), # Middle slice + col("col").str.substr(0, 1).alias("slice2"), # First character + col("col").str.substr(2, 2).alias("slice3"), # Another middle slice + col("col") + .str.substr((col("col").str.length().cast(DataType.int64()) - 1), 1) + .alias("slice4"), # Last character + ] + ) + + assert result.to_pydict() == { + "slice1": ["ell", "ell", "tes", "ata"], + "slice2": ["h", "H", "😉", "d"], + "slice3": ["ll", "ll", "es", "ta"], + "slice4": ["o", "d", "🌈", "e"], + } + + # Test with computed indices + result = table.eval_expression_list( + [ + # First half + col("col").str.substr(0, (col("col").str.length() / 2).cast(DataType.int32())).alias("first_half"), + # Second half + col("col") + .str.substr( + (col("col").str.length() / 2).cast(DataType.int32()), + (col("col").str.length() / 2).cast(DataType.int32()), + ) + .alias("second_half"), + # Middle third + col("col") + .str.substr( + (col("col").str.length() / 3).cast(DataType.int32()), + (col("col").str.length() / 3).cast(DataType.int32()), + ) + .alias("middle_third"), + ] + ) + + assert result.to_pydict() == { + "first_half": ["he", "Hello", "😉te", "data s"], + "second_half": ["ll", "☃Worl", "st🌈", "cience"], + "middle_third": ["e", "lo☃", "es", " sci"], + } + + +def test_utf8_substr_zero_length() -> None: + # Test various combinations of zero length substrings + table = MicroPartition.from_pydict( + { + "col": ["hello", "Hello☃World", "😉test", "test🌈", "☃😉🌈"], + "start": [0, 2, 4, 1, 3], + "length": [0, 0, 0, 0, 0], # All zero lengths + "length_mixed": [0, 1, 0, 2, 0], # Mix of zero and non-zero lengths + } + ) + + # Test with scalar zero length + result = table.eval_expression_list([col("col").str.substr(0, 0)]) + assert result.to_pydict() == {"col": [None, None, None, None, None]} + + # Test with column start and scalar zero length + result = table.eval_expression_list([col("col").str.substr(col("start"), 0)]) + assert result.to_pydict() == {"col": [None, None, None, None, None]} + + # Test with scalar start and column zero length + result = table.eval_expression_list([col("col").str.substr(1, col("length"))]) + assert result.to_pydict() == {"col": [None, None, None, None, None]} + + # Test with both column start and zero length + result = table.eval_expression_list([col("col").str.substr(col("start"), col("length"))]) + assert result.to_pydict() == {"col": [None, None, None, None, None]} + + # Test with mixed zero and non-zero lengths + result = table.eval_expression_list([col("col").str.substr(col("start"), col("length_mixed"))]) + assert result.to_pydict() == { + "col": [ + None, # length 0 + "l", # length 1, start 2 + None, # length 0 + "es", # length 2, start 1 + None, # length 0 + ] + } + + # Test zero length at string boundaries + result = table.eval_expression_list( + [ + col("col").str.substr(0, 0).alias("start"), # At start + col("col").str.substr(col("col").str.length(), 0).alias("end"), # At end + col("col") + .str.substr(col("col").str.length().cast(DataType.int32()) - 1, 0) + .alias("before_end"), # Before end + ] + ) + assert result.to_pydict() == { + "start": [None, None, None, None, None], + "end": [None, None, None, None, None], + "before_end": [None, None, None, None, None], + } diff --git a/tests/table/utf8/test_substr_baseline.py b/tests/table/utf8/test_substr_baseline.py new file mode 100644 index 0000000000..88ee03a0cd --- /dev/null +++ b/tests/table/utf8/test_substr_baseline.py 
@@ -0,0 +1,410 @@ +from __future__ import annotations + +import pandas as pd +import polars as pl +from pyspark.sql import SparkSession +from pyspark.sql.functions import col as spark_col +from pyspark.sql.functions import substring + +# Test strings that cover various UTF-8 scenarios +TEST_STRINGS = [ + "Hello☃World", # UTF-8 in middle + "😉test", # UTF-8 at start + "test🌈", # UTF-8 at end + "☃😉🌈", # Multiple UTF-8 characters + "", # Empty string +] + + +def test_negative_start_index() -> None: + """Test behavior when start index is negative.""" + print("\n" + "=" * 80) + print("TEST: Negative Start Index") + print("=" * 80) + test_data = TEST_STRINGS + start = -1 # Take last character + length = 1 # Intended as one character; pandas reads this as a stop index + + # Pandas - negative start counts from the end, but str.slice takes (start, stop), not (start, length) + pd_result = pd.Series(test_data).str.slice(start, length) + + # Polars - negative offset counts from the end; str.slice takes (offset, length) + pl_result = pl.Series("col", test_data).str.slice(start, length) + + # Spark - substring positions are 1-based, so start + 1 turns -1 into 0 here + spark = SparkSession.builder.appName("SubstrTest").getOrCreate() + spark_df = spark.createDataFrame([(x,) for x in test_data], ["col"]) + try: + spark_result = spark_df.select( + substring(spark_col("col"), start + 1, length) # Spark uses 1-based indexing + ).collect() + spark_list = [row[0] for row in spark_result] + except Exception as e: + spark_list = [f"Error: {e!s}"] * len(test_data) + finally: + spark.stop() + + print("\nTest negative start index (start=-1, length=1):") + print("Input:", test_data) + print("Pandas:", pd_result.tolist()) + print("Polars:", pl_result.to_list()) + print("Spark:", spark_list) + # Expected behavior: + # - Pandas: slice(-1, 1) is s[-1:1] (stop, not length), so it returns empty strings here + # - Polars: Returns last character for each string + # - Spark: Treats position 0 like position 1, so this returns the first character rather than erroring + + +def test_start_at_string_boundaries() -> None: + """Test behavior when start index is at string boundaries (0, len, len+1).""" + print("\n" + "=" * 80) + print("TEST: String Boundaries (Start at 0, len, len+1)") + print("=" * 80) + test_data = TEST_STRINGS + substr_length = 3 # Fixed length to see what happens at boundaries + + # Test start = 0 (beginning of string) + pd_start_0 = pd.Series(test_data).str.slice(0, substr_length) + pl_start_0 = pl.Series("col", test_data).str.slice(0, substr_length) + + # Test start = len(string) (end of string) + pd_df = pd.DataFrame({"col": test_data}) + pd_start_len = pd_df.apply( + lambda row: row["col"][len(str(row["col"])) : len(str(row["col"])) + substr_length] + if pd.notna(row["col"]) + else None, + axis=1, + ) + + pl_df = pl.DataFrame({"col": test_data}) + pl_start_len = pl_df.select(pl.col("col").str.slice(pl.col("col").str.len_chars(), substr_length)).to_series() + + # Test start > len(string) (beyond string) + pd_start_beyond = pd_df.apply( + lambda row: row["col"][len(str(row["col"])) + 1 : len(str(row["col"])) + 1 + substr_length] + if pd.notna(row["col"]) + else None, + axis=1, + ) + + pl_start_beyond = pl_df.select( + pl.col("col").str.slice(pl.col("col").str.len_chars() + 1, substr_length) + ).to_series() + + # Spark implementation + spark = SparkSession.builder.appName("SubstrTest").getOrCreate() + spark_df = spark.createDataFrame([(x,) for x in test_data], ["col"]) + + spark_start_0 = spark_df.select( + substring(spark_col("col"), 1, substr_length) # Spark uses 1-based indexing + ).collect() + + # Using SQL expression for column arithmetic + spark_start_len = spark_df.selectExpr(f"substring(col, length(col) + 1, {substr_length}) as result").collect() + +
spark_start_beyond = spark_df.selectExpr(f"substring(col, length(col) + 2, {substr_length}) as result").collect() + + spark.stop() + + print("\nTest start index at boundaries (length=3):") + print("Input:", test_data) + print("\nStart at beginning (0):") + print("Pandas:", pd_start_0.tolist()) + print("Polars:", pl_start_0.to_list()) + print("Spark:", [row[0] for row in spark_start_0]) + + print("\nStart at end (len):") + print("Pandas:", pd_start_len.tolist()) + print("Polars:", pl_start_len.to_list()) + print("Spark:", [row[0] for row in spark_start_len]) + + print("\nStart beyond end (len+1):") + print("Pandas:", pd_start_beyond.tolist()) + print("Polars:", pl_start_beyond.to_list()) + print("Spark:", [row[0] for row in spark_start_beyond]) + # Expected behavior: + # - Start at 0: Returns up to length characters from start + # - Start at len: Returns empty string or null + # - Start beyond len: Returns empty string or null + + +def test_length_not_specified() -> None: + """Test behavior when length parameter is not specified.""" + print("\n" + "=" * 80) + print("TEST: Unspecified Length Parameter") + print("=" * 80) + test_data = TEST_STRINGS + start = 2 # Start from third character + + # Pandas - should go to end of string + pd_result = pd.Series(test_data).str.slice(start) + + # Polars - should go to end of string + pl_result = pl.Series("col", test_data).str.slice(start) + + # Spark - requires length parameter, use max possible length + spark = SparkSession.builder.appName("SubstrTest").getOrCreate() + spark_df = spark.createDataFrame([(x,) for x in test_data], ["col"]) + max_length = 1000 # Use a large number to effectively get rest of string + spark_result = spark_df.select(substring(spark_col("col"), start + 1, max_length)).collect() + spark_list = [row[0] for row in spark_result] + spark.stop() + + print("\nTest length not specified (start=2):") + print("Input:", test_data) + print("Pandas:", pd_result.tolist()) + print("Polars:", pl_result.to_list()) + print("Spark (using max_length=1000):", spark_list) + + +def test_length_none() -> None: + """Test behavior when length parameter is explicitly set to None.""" + print("\n" + "=" * 80) + print("TEST: None Length Parameter") + print("=" * 80) + test_data = TEST_STRINGS + start = 2 + length = None + + # Pandas - None length + pd_result = pd.Series(test_data).str.slice(start, length) + + # Polars - None length + pl_result = pl.Series("col", test_data).str.slice(start, length) + + # Spark - None length (not typically supported) + spark = SparkSession.builder.appName("SubstrTest").getOrCreate() + spark_df = spark.createDataFrame([(x,) for x in test_data], ["col"]) + try: + spark_result = spark_df.select(substring(spark_col("col"), start + 1, None)).collect() + spark_list = [row[0] for row in spark_result] + except Exception as e: + spark_list = [f"Error: {e!s}"] * len(test_data) + finally: + spark.stop() + + print("\nTest explicit None length (start=2, length=None):") + print("Input:", test_data) + print("Pandas:", pd_result.tolist()) + print("Polars:", pl_result.to_list()) + print("Spark:", spark_list) + # Expected behavior: May differ between libraries, but should handle None gracefully + + +def test_zero_and_negative_length() -> None: + """Test behavior with zero and negative length values.""" + print("\n" + "=" * 80) + print("TEST: Zero and Negative Length") + print("=" * 80) + test_data = TEST_STRINGS + start = 2 + + # Zero length + pd_zero = pd.Series(test_data).str.slice(start, 0) + pl_zero = pl.Series("col", 
test_data).str.slice(start, 0) + + # Negative length - pandas treats -2 as a stop index (s[2:-2]); polars is expected to reject a negative length + pd_neg = pd.Series(test_data).str.slice(start, -2) + try: + pl_neg = pl.Series("col", test_data).str.slice(start, -2) + except Exception as e: + pl_neg = [f"Error: {e!s}"] * len(test_data) + + # Spark implementation + spark = SparkSession.builder.appName("SubstrTest").getOrCreate() + spark_df = spark.createDataFrame([(x,) for x in test_data], ["col"]) + + # Zero length + try: + spark_zero = spark_df.select(substring(spark_col("col"), start + 1, 0)).collect() + spark_zero_list = [row[0] for row in spark_zero] + except Exception as e: + spark_zero_list = [f"Error: {e!s}"] * len(test_data) + + # Negative length + try: + spark_neg = spark_df.select(substring(spark_col("col"), start + 1, -2)).collect() + spark_neg_list = [row[0] for row in spark_neg] + except Exception as e: + spark_neg_list = [f"Error: {e!s}"] * len(test_data) + + spark.stop() + + print("\nTest zero length (start=2, length=0):") + print("Input:", test_data) + print("Pandas:", pd_zero.tolist()) + print("Polars:", pl_zero.to_list()) + print("Spark:", spark_zero_list) + + print("\nTest negative length (start=2, length=-2):") + print("Pandas:", pd_neg.tolist()) + print("Polars:", pl_neg) + print("Spark:", spark_neg_list) + + +def test_length_overflow() -> None: + """Test behavior when length extends beyond string end or is very large.""" + print("\n" + "=" * 80) + print("TEST: Length Overflow") + print("=" * 80) + test_data = TEST_STRINGS + start = 2 + + # Length > remaining string + remaining_length = 100 # Much larger than any test string + pd_overflow = pd.Series(test_data).str.slice(start, remaining_length) + pl_overflow = pl.Series("col", test_data).str.slice(start, remaining_length) + + # Very large length (near system limits) + large_length = 999999999 + pd_large = pd.Series(test_data).str.slice(start, large_length) + pl_large = pl.Series("col", test_data).str.slice(start, large_length) + + # Spark implementation + spark = SparkSession.builder.appName("SubstrTest").getOrCreate() + spark_df = spark.createDataFrame([(x,) for x in test_data], ["col"]) + + # Length > remaining string + spark_overflow = spark_df.select(substring(spark_col("col"), start + 1, remaining_length)).collect() + + # Very large length + try: + spark_large = spark_df.select(substring(spark_col("col"), start + 1, large_length)).collect() + spark_large_list = [row[0] for row in spark_large] + except Exception as e: + spark_large_list = [f"Error: {e!s}"] * len(test_data) + + spark.stop() + + print("\nTest length > remaining string (start=2, length=100):") + print("Input:", test_data) + print("Pandas:", pd_overflow.tolist()) + print("Polars:", pl_overflow.to_list()) + print("Spark:", [row[0] for row in spark_overflow]) + + print("\nTest very large length (start=2, length=999999999):") + print("Pandas:", pd_large.tolist()) + print("Polars:", pl_large.to_list()) + print("Spark:", spark_large_list) + # Expected behavior: + # Overflow: Should return remaining string from start + # Large length: Should handle gracefully, same as overflow + + +def test_utf8_character_boundaries() -> None: + """Test behavior when substring boundaries intersect UTF-8 characters.""" + print("\n" + "=" * 80) + print("TEST: UTF-8 Character Boundaries") + print("=" * 80) + # Test strings where start/length would split UTF-8 characters + test_data = [ + "Hello☃World", # 3-byte UTF-8 char + "Hi😉Smile", # 4-byte UTF-8 char + "🌈Rainbow", # Start with 4-byte + "Star⭐End", # 3-byte in
middle + ] + + # Test with start in middle of UTF-8 char + pd_mid_char = pd.Series(test_data).str.slice(5, 3) + pl_mid_char = pl.Series("col", test_data).str.slice(5, 3) + + # Test with length that would split UTF-8 char + pd_split_char = pd.Series(test_data).str.slice(4, 2) + pl_split_char = pl.Series("col", test_data).str.slice(4, 2) + + # Spark implementation + spark = SparkSession.builder.appName("SubstrTest").getOrCreate() + spark_df = spark.createDataFrame([(x,) for x in test_data], ["col"]) + + spark_mid_char = spark_df.select( + substring(spark_col("col"), 6, 3) # Spark uses 1-based indexing + ).collect() + + spark_split_char = spark_df.select(substring(spark_col("col"), 5, 2)).collect() + + spark.stop() + + print("\nTest UTF-8 character boundary handling:") + print("Input:", test_data) + print("\nStart in middle of UTF-8 char (start=5, length=3):") + print("Pandas:", pd_mid_char.tolist()) + print("Polars:", pl_mid_char.to_list()) + print("Spark:", [row[0] for row in spark_mid_char]) + + print("\nLength splitting UTF-8 char (start=4, length=2):") + print("Pandas:", pd_split_char.tolist()) + print("Polars:", pl_split_char.to_list()) + print("Spark:", [row[0] for row in spark_split_char]) + + +def test_mixed_ascii_utf8() -> None: + """Test behavior with strings containing both ASCII and UTF-8 characters.""" + print("\n" + "=" * 80) + print("TEST: Mixed ASCII and UTF-8") + print("=" * 80) + test_data = [ + "ABC☃DEF", # ASCII-UTF8-ASCII + "12😉34", # Numbers-UTF8-Numbers + "Hi🌈Bye⭐", # Multiple UTF-8 mixed + "Test⭐", # ASCII ending with UTF-8 + "☃Start", # UTF-8 starting with ASCII + ] + + # Test normal substring + pd_result = pd.Series(test_data).str.slice(2, 3) + pl_result = pl.Series("col", test_data).str.slice(2, 3) + + # Spark implementation + spark = SparkSession.builder.appName("SubstrTest").getOrCreate() + spark_df = spark.createDataFrame([(x,) for x in test_data], ["col"]) + + spark_result = spark_df.select( + substring(spark_col("col"), 3, 3) # Spark uses 1-based indexing + ).collect() + + spark.stop() + + print("\nTest mixed ASCII and UTF-8 handling (start=2, length=3):") + print("Input:", test_data) + print("Pandas:", pd_result.tolist()) + print("Polars:", pl_result.to_list()) + print("Spark:", [row[0] for row in spark_result]) + + +def test_null_values() -> None: + """Test behavior with null/None values in the input.""" + print("\n" + "=" * 80) + print("TEST: Null Values") + print("=" * 80) + test_data = [None, "Hello", None, "World", None] + + # Test with different start/length combinations + pd_result1 = pd.Series(test_data).str.slice(0, 3) + pd_result2 = pd.Series(test_data).str.slice(2, None) + + pl_result1 = pl.Series("col", test_data).str.slice(0, 3) + pl_result2 = pl.Series("col", test_data).str.slice(2, None) + + # Spark implementation + spark = SparkSession.builder.appName("SubstrTest").getOrCreate() + spark_df = spark.createDataFrame([(x,) for x in test_data], ["col"]) + + spark_result1 = spark_df.select(substring(spark_col("col"), 1, 3)).collect() + + # Fixed: providing length parameter for Spark + max_length = 1000 # Use a large number to effectively get rest of string + spark_result2 = spark_df.select(substring(spark_col("col"), 3, max_length)).collect() + + spark.stop() + + print("\nTest null value handling:") + print("Input:", test_data) + print("\nFirst test (start=0, length=3):") + print("Pandas:", pd_result1.tolist()) + print("Polars:", pl_result1.to_list()) + print("Spark:", [row[0] for row in spark_result1]) + + print("\nSecond test (start=2, 
length=None):") + print("Pandas:", pd_result2.tolist()) + print("Polars:", pl_result2.to_list()) + print("Spark (using max_length=1000):", [row[0] for row in spark_result2])