Skip to content

Commit

Permalink
Dask integration (#56)
Browse files Browse the repository at this point in the history
* Move data proxy

* Improve data proxy serialization

* Throw away unwanted keys

* Record distributed in setup
  • Loading branch information
DPeterK authored Oct 21, 2020
1 parent 484762d commit 6c77eaa
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 39 deletions.
119 changes: 119 additions & 0 deletions nctotdb/proxy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
from distributed.protocol import dask_serialize, dask_deserialize
import numpy as np
import tiledb


# Inspired by https://github.com/SciTools/iris/blob/master/lib/iris/fileformats/netcdf.py#L418.
class TileDBDataProxy(object):
    """A proxy to the data of a single TileDB array attribute.

    Holds just enough state to lazily open the TileDB array and read the
    attribute on demand (via `__getitem__`), so the proxy itself is cheap to
    ship between processes (see `serialize_state` and the dask hooks below).
    """

    __slots__ = ("shape", "dtype", "path", "var_name", "ctx", "handle_nan")

    def __init__(self, shape, dtype, path, var_name, ctx=None, handle_nan=None):
        """
        Args:
            shape: Shape of the proxied data — either a sequence of ints or a
                tuple of slices (both forms are handled by `serialize_state`).
            dtype: Anything `np.dtype` accepts; dtype of the attribute's data.
            path: Path / URI of the TileDB array.
            var_name: Name of the TileDB attribute to read.
            ctx: Optional `tiledb.Ctx` to use when opening the array.
            handle_nan: NaN handling for data reads: `'mask'` to return a
                masked array, an int/float to use as a NaN fill value, or
                `None` to leave NaNs untouched.
        """
        self.shape = shape
        self.dtype = dtype
        self.path = path
        self.var_name = var_name
        self.ctx = ctx
        self.handle_nan = handle_nan

    @property
    def ndim(self):
        """Number of dimensions of the proxied data."""
        return len(self.shape)

    def __getitem__(self, keys):
        """Read `keys` from the TileDB array attribute, applying NaN handling."""
        with tiledb.open(self.path, 'r', ctx=self.ctx) as A:
            data = A[keys][self.var_name]
            if self.handle_nan is not None:
                if self.handle_nan == 'mask':
                    # BUG FIX: `masked_invalid(a, copy=True)` has no fill-value
                    # argument — the previous `masked_invalid(data, np.nan)`
                    # passed NaN into `copy`, which only worked by accident.
                    data = np.ma.masked_invalid(data)
                elif isinstance(self.handle_nan, (int, float)):
                    data = np.nan_to_num(data, nan=self.handle_nan, copy=False)
                else:
                    raise ValueError(f'Not a valid nan-handling approach: {self.handle_nan!r}.')
        return data

    def serialize_state(self, dummy=None):
        """
        Take the current state of `self` and make it serializable.

        Returns a dict of plain Python primitives (suitable for msgpack/pickle).

        Note the apparently unused kwarg `dummy`. This allows `serialize_state` to be used
        as the 'default' serialization function for msgpack. For example:
        ```
        msgpack.dumps(my_data_proxy, default=my_data_proxy.serialize_state)
        ```
        In such instances, msgpack calls `default` with the object to be dumped, which makes
        no sense in this application.
        """
        state = {}
        for attr in self.__slots__:
            value = getattr(self, attr)
            if attr == "shape":
                # `shape` could either be a simple list (of np.int!) or a tuple of slices...
                result = {"type": None, "value": None}
                if isinstance(value, tuple):
                    result["type"] = "tuple"
                    # NOTE(review): assumes every slice has concrete
                    # start/stop/step — `slice(None)` would raise here. Confirm
                    # callers never produce open-ended slices.
                    result["value"] = [[int(s.start), int(s.stop), int(s.step)] for s in value]
                else:
                    result["type"] = "list"
                    result["value"] = [int(i) for i in value]
                state[attr] = result
            elif attr == "dtype":
                # Encode the dtype as its string code (e.g. '<f8') for portability.
                state[attr] = np.dtype(value).str
            elif attr == "ctx":
                # ctx is based on a C library that doesn't pickle...
                state[attr] = value.config().dict() if value is not None else None
            else:
                state[attr] = value
        return state

    def __getstate__(self):
        """Simplify a complex object for pickling."""
        return self.serialize_state()

    def __setstate__(self, state):
        """Restore the complex object from the simple pickled dict."""
        deserialized_state = deserialize_state(state)
        for key, value in deserialized_state.items():
            if key in self.__slots__:
                setattr(self, key, value)


def deserialize_state(s_state):
    """
    Take a serialized dictionary of state (as produced by
    `TileDBDataProxy.serialize_state`) and deserialize it to set state
    on a `TileDBDataProxy` instance.
    """
    def _decode_shape(spec):
        # Reverse the shape encoding: a tuple of slices, or a plain int list.
        kind = spec["type"]
        if kind == "tuple":
            return tuple(slice(*bounds) for bounds in spec["value"])
        if kind == "list":
            return spec["value"]
        raise RuntimeError(f"Cannot deserialize {'shape'!r} with type {kind!r}.")

    def _decode_ctx(cfg):
        # A ctx round-trips as its config dict; None means "no ctx".
        return tiledb.Ctx(config=tiledb.Config(cfg)) if cfg is not None else None

    decoders = {
        "shape": _decode_shape,
        "dtype": np.dtype,
        "ctx": _decode_ctx,
    }

    d_state = {}
    for key, raw_value in s_state.items():
        decode = decoders.get(key)
        d_state[key] = raw_value if decode is None else decode(raw_value)
    return d_state


@dask_serialize.register(TileDBDataProxy)
def tdb_data_proxy_dumps(data_proxy):
    """Dask serialization hook: the proxy's state dict as header, no frames."""
    header = data_proxy.serialize_state()
    frames = []
    return header, frames


@dask_deserialize.register(TileDBDataProxy)
def tdb_data_proxy_loads(header, frames):
    """Dask deserialization hook: rebuild a proxy from the state header."""
    state = deserialize_state(header)
    return TileDBDataProxy(**state)
41 changes: 2 additions & 39 deletions nctotdb/readers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import zarr

from .grid_mappings import GridMapping
from .proxy import TileDBDataProxy
from . import utils


Expand All @@ -39,44 +40,6 @@
])


# Inspired by https://github.com/SciTools/iris/blob/master/lib/iris/fileformats/netcdf.py#L418.
class TileDBDataProxy:
"""A proxy to the data of a single TileDB array attribute."""

__slots__ = ("shape", "dtype", "path", "var_name", "ctx", "handle_nan")

def __init__(self, shape, dtype, path, var_name, ctx=None, handle_nan=None):
self.shape = shape
self.dtype = dtype
self.path = path
self.var_name = var_name
self.ctx = ctx
self.handle_nan = handle_nan

@property
def ndim(self):
return len(self.shape)

def __getitem__(self, keys):
with tiledb.open(self.path, 'r', ctx=self.ctx) as A:
data = A[keys][self.var_name]
if self.handle_nan is not None:
if self.handle_nan == 'mask':
data = np.ma.masked_invalid(data, np.nan)
elif type(self.handle_nan) in [int, float]:
data = np.nan_to_num(data, nan=self.handle_nan, copy=False)
else:
raise ValueError(f'Not a valid nan-handling approach: {self.handle_nan!r}.')
return data

def __getstate__(self):
return {attr: getattr(self, attr) for attr in self.__slots__}

def __setstate__(self, state):
for key, value in state.items():
setattr(self, key, value)


class Reader(object):
"""
Abstract reader class that defines the API.
Expand Down Expand Up @@ -348,7 +311,7 @@ def _from_tdb_array(self, array_path, naming_key,
def _load_dim(self, dim_path, grid_mapping):
"""
Create an Iris DimCoord from a TileDB array describing a dimension.
# TODO not handled here: circular.
"""
Expand Down
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
],
python_requires=">=3.7",
install_requires=[
"distributed>=2.28.0",
"tiledb>=0.6.6",
"scitools-iris>=2.4.0",
"xarray>=0.15.1",
Expand Down

0 comments on commit 6c77eaa

Please sign in to comment.