Skip to content

Commit

Permalink
Fix flatten array with sliced input (#73)
Browse files Browse the repository at this point in the history
* Fix flatten array with sliced input

* Add Python tests and ci

* Remove extra job

* Add test for sliced struct field access

* arrow-rs bug?

* add dbg test

* workaround

* remove rust test example
  • Loading branch information
kylebarron authored Jul 29, 2024
1 parent e584d0d commit eac16fe
Show file tree
Hide file tree
Showing 13 changed files with 935 additions and 17 deletions.
82 changes: 82 additions & 0 deletions .github/workflows/test-python.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
name: Python

on:
push:
branches:
- main
pull_request:

jobs:
# lint-python:
# name: Lint Python code
# runs-on: ubuntu-latest
# steps:
# - uses: actions/checkout@v4

# - name: Set up Python 3.8
# uses: actions/setup-python@v2
# with:
# python-version: "3.8"

# - name: run pre-commit
# run: |
# python -m pip install pre-commit
# pre-commit run --all-files

test-python:
name: Build and test Python
runs-on: ubuntu-latest
strategy:
fail-fast: true
matrix:
python-version: ["3.9", "3.12"]
steps:
- uses: actions/checkout@v4

- name: Install Rust
uses: dtolnay/rust-toolchain@stable

- uses: Swatinem/rust-cache@v2

- name: Set up Python
id: setup-python
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

- name: Install and configure Poetry
uses: snok/install-poetry@v1
with:
version: 1.8.2
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true

- name: Check Poetry lockfile up to date
run: |
poetry check --lock
- name: Load cached venv
id: cached-poetry-dependencies
uses: actions/cache@v4
with:
path: .venv
key: venv-${{ runner.os }}-${{ steps.setup-python.outputs.python-version }}-${{ hashFiles('poetry.lock') }}

- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-root

- name: Install root project
run: poetry install --no-interaction

- name: Build rust submodules
run: |
# Note: core module must be first, because it's depended on by others
poetry run maturin develop -m arro3-core/Cargo.toml
poetry run maturin develop -m arro3-compute/Cargo.toml
poetry run maturin develop -m arro3-io/Cargo.toml
- name: Run python tests
run: |
poetry run pytest tests
23 changes: 19 additions & 4 deletions arro3-compute/src/list_flatten.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,33 @@ pub fn list_flatten(py: Python, input: AnyArray) -> PyArrowResult<PyObject> {
}

fn flatten_array(array: ArrayRef) -> Result<ArrayRef, ArrowError> {
let offset = array.offset();
let length = array.len();
match array.data_type() {
DataType::List(_) => {
let arr = array.as_list::<i32>();
Ok(arr.values().clone())
let start = arr.offsets().get(offset).unwrap();
let end = arr.offsets().get(offset + length).unwrap();
Ok(arr
.values()
.slice(*start as usize, (*end - *start) as usize)
.clone())
}
DataType::LargeList(_) => {
let arr = array.as_list::<i64>();
Ok(arr.values().clone())
let start = arr.offsets().get(offset).unwrap();
let end = arr.offsets().get(offset + length).unwrap();
Ok(arr
.values()
.slice(*start as usize, (*end - *start) as usize)
.clone())
}
DataType::FixedSizeList(_, _) => {
DataType::FixedSizeList(_, list_size) => {
let arr = array.as_fixed_size_list();
Ok(arr.values().clone())
Ok(arr.values().clone().slice(
offset * (*list_size as usize),
(offset + length) * (*list_size as usize),
))
}
_ => Err(ArrowError::SchemaError(
"Expected list-typed Array".to_string(),
Expand Down
12 changes: 8 additions & 4 deletions arro3-compute/src/struct_field.rs
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,20 @@ pub(crate) fn struct_field(
values: PyArray,
indices: StructIndex,
) -> PyArrowResult<PyObject> {
let (array, field) = values.into_inner();
let (orig_array, field) = values.into_inner();
let indices = indices.into_list();

let mut array_ref = &array;
let mut array_ref = &orig_array;
let mut field_ref = &field;
for i in indices {
(array_ref, field_ref) = get_child(&array, i)?;
(array_ref, field_ref) = get_child(array_ref, i)?;
}

Ok(PyArray::new(array_ref.clone(), field_ref.clone()).to_arro3(py)?)
Ok(PyArray::new(
array_ref.slice(orig_array.offset(), orig_array.len()),
field_ref.clone(),
)
.to_arro3(py)?)
}

fn get_child(array: &ArrayRef, i: usize) -> Result<(&ArrayRef, &FieldRef), ArrowError> {
Expand Down
Loading

0 comments on commit eac16fe

Please sign in to comment.