diff --git a/docs/api/sql/Stac.md b/docs/api/sql/Stac.md
index 98909e025b..26ebd082b7 100644
--- a/docs/api/sql/Stac.md
+++ b/docs/api/sql/Stac.md
@@ -146,6 +146,156 @@ In this example, the data source will push down the temporal filter to the under
 
 In this example, the data source will push down the spatial filter to the underlying data source.
 
+# Python API
+
+The Python API allows you to interact with a SpatioTemporal Asset Catalog (STAC) API using the `Client` class. This class provides methods to open a connection to a STAC API, retrieve collections, and search for items with various filters.
+
+## Client Class
+
+### `open(url: str) -> Client`
+
+Opens a connection to the specified STAC API URL.
+
+**Parameters:**
+
+- `url` (*str*): The URL of the STAC API to connect to.
+  **Example:** `"https://planetarycomputer.microsoft.com/api/stac/v1"`
+
+**Returns:**
+
+- `Client`: An instance of the `Client` class connected to the specified URL.
+
+---
+
+### `get_collection(collection_id: str) -> CollectionClient`
+
+Retrieves a collection client for the specified collection ID.
+
+**Parameters:**
+
+- `collection_id` (*str*): The ID of the collection to retrieve.
+  **Example:** `"aster-l1t"`
+
+**Returns:**
+
+- `CollectionClient`: An instance of the `CollectionClient` class for the specified collection.
+
+---
+
+### `search(*ids: Union[str, list], collection_id: str, bbox: Optional[list] = None, datetime: Optional[Union[str, datetime.datetime, list]] = None, max_items: Optional[int] = None, return_dataframe: bool = True) -> Union[Iterator[PyStacItem], DataFrame]`
+
+Searches for items in the specified collection with optional filters.
+
+**Parameters:**
+
+- `ids` (*Union[str, list]*): A variable number of item IDs to filter the items.
+  **Example:** `"item_id1"` or `["item_id1", "item_id2"]`
+- `collection_id` (*str*): The ID of the collection to search in.
+  **Example:** `"aster-l1t"`
+- `bbox` (*Optional[list]*): A list of bounding boxes for filtering the items. Each bounding box is represented as a list of four float values: `[min_lon, min_lat, max_lon, max_lat]`.
+  **Example:** `[[-180.0, -90.0, 180.0, 90.0]]`
+- `datetime` (*Optional[Union[str, datetime.datetime, list]]*): A single datetime, RFC 3339-compliant timestamp, or a list of date-time ranges for filtering the items.
+  **Example:**
+    - `"2020-01-01T00:00:00Z"`
+    - `datetime.datetime(2020, 1, 1)`
+    - `[["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]]`
+- `max_items` (*Optional[int]*): The maximum number of items to return from the search, even if there are more matching results.
+  **Example:** `100`
+- `return_dataframe` (*bool*): If `True` (default), return the result as a Spark DataFrame instead of an iterator of `PyStacItem` objects.
+  **Example:** `True`
+
+**Returns:**
+
+- *Union[Iterator[PyStacItem], DataFrame]*: An iterator of `PyStacItem` objects or a Spark DataFrame that matches the specified filters.
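+
+When `return_dataframe` is `False`, the result is a plain Python iterator of `PyStacItem` objects and can be consumed directly. A minimal sketch (assuming a client opened as in the Sample Code below; the filter values are illustrative):
+
+```python
+items = client.search(
+    collection_id="aster-l1t",
+    datetime="2020",
+    max_items=3,
+    return_dataframe=False,
+)
+for item in items:
+    print(item.id)
+```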
+
+## Sample Code
+
+### Initialize the Client
+
+```python
+from sedona.stac.client import Client
+
+# Initialize the client
+client = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
+```
+
+### Search Items in a Collection Within a Year
+
+```python
+items = client.search(
+    collection_id="aster-l1t",
+    datetime="2020",
+    return_dataframe=False
+)
+```
+
+### Search Items in a Collection Within a Month, Limiting Max Items
+
+```python
+items = client.search(
+    collection_id="aster-l1t",
+    datetime="2020-05",
+    return_dataframe=False,
+    max_items=5
+)
+```
+
+### Search Items with an ID, a Bounding Box, and an Interval
+
+Item IDs are passed positionally because `search` declares them as a variadic `*ids` parameter:
+
+```python
+items = client.search(
+    "AST_L1T_00312272006020322_20150518201805",
+    collection_id="aster-l1t",
+    bbox=[-180.0, -90.0, 180.0, 90.0],
+    datetime=["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"],
+    return_dataframe=False
+)
+```
+
+### Search Multiple Items with Multiple Bounding Boxes
+
+```python
+bbox_list = [
+    [-180.0, -90.0, 180.0, 90.0],
+    [-100.0, -50.0, 100.0, 50.0]
+]
+items = client.search(
+    collection_id="aster-l1t",
+    bbox=bbox_list,
+    return_dataframe=False
+)
+```
+
+### Search Items and Return a DataFrame with Multiple Intervals
+
+```python
+interval_list = [
+    ["2020-01-01T00:00:00Z", "2020-06-01T00:00:00Z"],
+    ["2020-07-01T00:00:00Z", "2021-01-01T00:00:00Z"]
+]
+df = client.search(
+    collection_id="aster-l1t",
+    datetime=interval_list,
+    return_dataframe=True
+)
+df.show()
+```
+
+### Save Items in a DataFrame to GeoParquet with Bounding Boxes and Intervals
+
+```python
+# Save items in a DataFrame to GeoParquet with both bounding boxes and intervals
+client.get_collection("aster-l1t").save_to_geoparquet(
+    output_path="/path/to/output",
+    bbox=bbox_list,
+    datetime="2020-05"
+)
+```
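+
+### Work with a Collection Client Directly
+
+The same filters can also be applied through the `CollectionClient` returned by `get_collection`. A minimal sketch (the filter values are illustrative):
+
+```python
+collection = client.get_collection("aster-l1t")
+items = collection.get_items(
+    bbox=[[-180.0, -90.0, 180.0, 90.0]],
+    max_items=3,
+)
+df = collection.get_dataframe(max_items=3)
+```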
+
+These examples demonstrate how to use the `Client` class to search for items in a STAC collection with various filters and return the results as either an iterator of `PyStacItem` objects or a Spark DataFrame.
+
 # References
 
 - STAC Specification: https://stacspec.org/
diff --git a/python/Pipfile b/python/Pipfile
index 8c899b263e..389c56c0ac 100644
--- a/python/Pipfile
+++ b/python/Pipfile
@@ -29,6 +29,7 @@ attrs="*"
 pyarrow="*"
 keplergl = "==0.3.2"
 pydeck = "===0.8.0"
+pystac = "===1.5.0"
 rasterio = ">=1.2.10"
 
 [requires]
diff --git a/python/sedona/stac/__init__.py b/python/sedona/stac/__init__.py
new file mode 100644
index 0000000000..a67d5ea255
--- /dev/null
+++ b/python/sedona/stac/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/python/sedona/stac/client.py b/python/sedona/stac/client.py
new file mode 100644
index 0000000000..3e8eeacefa
--- /dev/null
+++ b/python/sedona/stac/client.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from typing import Union, Optional, Iterator
+
+from sedona.stac.collection_client import CollectionClient
+
+import datetime as python_datetime
+from pystac import Item as PyStacItem
+
+from pyspark.sql import DataFrame
+
+
+class Client:
+    def __init__(self, url: str):
+        self.url = url
+
+    @classmethod
+    def open(cls, url: str):
+        """
+        Opens a connection to the specified STAC API URL.
+
+        This class method creates an instance of the Client class with the given URL.
+
+        Parameters:
+        - url (str): The URL of the STAC API to connect to.
+          Example: "https://planetarycomputer.microsoft.com/api/stac/v1"
+
+        Returns:
+        - Client: An instance of the Client class connected to the specified URL.
+        """
+        return cls(url)
+
+    def get_collection(self, collection_id: str):
+        """
+        Retrieves a collection client for the specified collection ID.
+
+        This method creates an instance of the CollectionClient class for the given collection ID,
+        allowing interaction with the specified collection in the STAC API.
+
+        Parameters:
+        - collection_id (str): The ID of the collection to retrieve.
+          Example: "aster-l1t"
+
+        Returns:
+        - CollectionClient: An instance of the CollectionClient class for the specified collection.
+        """
+        return CollectionClient(self.url, collection_id)
+
+    def search(
+        self,
+        *ids: Union[str, list],
+        collection_id: str,
+        bbox: Optional[list] = None,
+        datetime: Optional[Union[str, python_datetime.datetime, list]] = None,
+        max_items: Optional[int] = None,
+        return_dataframe: bool = True,
+    ) -> Union[Iterator[PyStacItem], DataFrame]:
+        """
+        Searches for items in the specified collection with optional filters.
+
+        Parameters:
+        - ids (Union[str, list]): A variable number of item IDs to filter the items.
+          Example: "item_id1" or ["item_id1", "item_id2"]
+
+        - collection_id (str): The ID of the collection to search in.
+          Example: "aster-l1t"
+
+        - bbox (Optional[list]): A list of bounding boxes for filtering the items.
+          Each bounding box is represented as a list of four float values: [min_lon, min_lat, max_lon, max_lat].
+          Example: [[-180.0, -90.0, 180.0, 90.0]]  # This bounding box covers the entire world.
+
+        - datetime (Optional[Union[str, python_datetime.datetime, list]]): A single datetime, RFC 3339-compliant timestamp,
+          or a list of date-time ranges for filtering the items.
+          The datetime can be specified in various formats:
+          - "YYYY" expands to ["YYYY-01-01T00:00:00Z", "YYYY-12-31T23:59:59Z"]
+          - "YYYY-mm" expands to ["YYYY-mm-01T00:00:00Z", "YYYY-mm-<last_day>T23:59:59Z"], where <last_day> is the last day of that month
+          - "YYYY-mm-dd" expands to ["YYYY-mm-ddT00:00:00Z", "YYYY-mm-ddT23:59:59Z"]
+          - "YYYY-mm-ddTHH:MM:SSZ" remains as ["YYYY-mm-ddTHH:MM:SSZ", "YYYY-mm-ddTHH:MM:SSZ"]
+          - A list of date-time ranges can be provided for multiple intervals.
+          Example: "2020-01-01T00:00:00Z" or python_datetime.datetime(2020, 1, 1) or [["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]]
+
+        - max_items (Optional[int]): The maximum number of items to return from the search, even if there are more matching results.
+          Example: 100
+
+        - return_dataframe (bool): If True, return the result as a Spark DataFrame instead of an iterator of PyStacItem objects.
+          Example: True
+
+        Returns:
+        - Union[Iterator[PyStacItem], DataFrame]: An iterator of PyStacItem objects or a Spark DataFrame that matches the specified filters.
+        """
+        client = self.get_collection(collection_id)
+        if return_dataframe:
+            return client.get_dataframe(
+                *ids, bbox=bbox, datetime=datetime, max_items=max_items
+            )
+        else:
+            return client.get_items(
+                *ids, bbox=bbox, datetime=datetime, max_items=max_items
+            )
diff --git a/python/sedona/stac/collection_client.py b/python/sedona/stac/collection_client.py
new file mode 100644
index 0000000000..b1cae6df39
--- /dev/null
+++ b/python/sedona/stac/collection_client.py
@@ -0,0 +1,398 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import logging
+from typing import Iterator, Union
+from typing import Optional
+
+import datetime as python_datetime
+from pyspark.sql import DataFrame, SparkSession
+from pystac import Item as PyStacItem
+
+
+def get_collection_url(url: str, collection_id: Optional[str] = None) -> str:
+    """
+    Constructs the collection URL based on the provided base URL and optional collection ID.
+
+    If the collection ID is provided and the URL starts with 'http' or 'https', the collection ID
+    is appended to the URL. Otherwise, an exception is raised.
+
+    Parameters:
+    - url (str): The base URL of the STAC collection.
+    - collection_id (Optional[str]): The optional collection ID to append to the URL.
+
+    Returns:
+    - str: The constructed collection URL.
+
+    Raises:
+    - ValueError: If a collection ID is provided but the URL does not start with 'http' or 'https'.
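+
+    Example (illustrative):
+        get_collection_url("https://planetarycomputer.microsoft.com/api/stac/v1", "aster-l1t")
+        returns "https://planetarycomputer.microsoft.com/api/stac/v1/collections/aster-l1t"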
+    """
+    if not collection_id:
+        return url
+    elif url.startswith("http"):  # covers both http:// and https:// URLs
+        return f"{url}/collections/{collection_id}"
+    else:
+        raise ValueError(
+            "Cannot append the collection ID because the URL does not start with http or https"
+        )
+
+
+class CollectionClient:
+    def __init__(self, url: str, collection_id: Optional[str] = None):
+        self.url = url
+        self.collection_id = collection_id
+        self.collection_url = get_collection_url(url, collection_id)
+        self.spark = SparkSession.getActiveSession()
+
+    @staticmethod
+    def _move_attributes_to_properties(item_dict: dict) -> dict:
+        """
+        Moves specified attributes from the item dictionary to the 'properties' field.
+
+        This method ensures that certain attributes are nested under the 'properties' key
+        in the item dictionary. If the 'properties' key does not exist, it is initialized.
+
+        Parameters:
+        - item_dict (dict): The dictionary representation of a STAC item.
+
+        Returns:
+        - dict: The updated item dictionary with the specified attributes moved to 'properties'.
+        """
+        # List of attributes to move to 'properties'
+        attributes_to_move = [
+            "title",
+            "description",
+            "keywords",
+            "datetime",
+            "start_datetime",
+            "end_datetime",
+            "created",
+            "instruments",
+            "statistics",
+            "platform",
+            "gsd",
+        ]
+
+        # Initialize 'properties' if it doesn't exist
+        if "properties" not in item_dict:
+            item_dict["properties"] = {}
+
+        # Move the specified attributes to 'properties'
+        for attr in attributes_to_move:
+            if attr in item_dict:
+                item_dict["properties"][attr] = str(item_dict.pop(attr))
+
+        return item_dict
+
+    @staticmethod
+    def _apply_spatial_temporal_filters(
+        df: DataFrame, bbox=None, datetime=None
+    ) -> DataFrame:
+        """
+        Applies spatial and temporal filters to a Spark DataFrame.
+
+        Parameters:
+        - df (DataFrame): The input Spark DataFrame to be filtered.
+        - bbox (Optional[list]): A list of bounding boxes for filtering the items.
+          Each bounding box is represented as a list of four float values: [min_lon, min_lat, max_lon, max_lat].
+          Example: [[-180.0, -90.0, 180.0, 90.0]]  # This bounding box covers the entire world.
+        - datetime (Optional[list]): A list of date-time ranges for filtering the items.
+          Each date-time range is represented as a list of two strings in ISO 8601 format: [start_datetime, end_datetime].
+          Example: [["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]]  # This interval covers the entire year of 2020.
+
+        Returns:
+        - DataFrame: The filtered Spark DataFrame.
+
+        The function constructs SQL conditions for the spatial and temporal filters and applies them to the DataFrame.
+        If bbox is provided, it builds spatial conditions using st_contains and ST_GeomFromText.
+        If datetime is provided, it builds temporal conditions on the datetime column.
+        Multiple bounding boxes (or multiple intervals) are combined with OR logic.
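+
+        For example, a bbox of [[-180.0, -90.0, 180.0, 90.0]] yields a condition of the form:
+            st_contains(ST_GeomFromText('POLYGON((-180.0 -90.0, 180.0 -90.0, 180.0 90.0, -180.0 90.0, -180.0 -90.0))'), geometry)
+        and a datetime of [["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]] yields:
+            datetime BETWEEN '2020-01-01T00:00:00Z' AND '2021-01-01T00:00:00Z'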
+        """
+        if bbox:
+            bbox_conditions = []
+            for box in bbox:
+                polygon_wkt = (
+                    f"POLYGON(({box[0]} {box[1]}, {box[2]} {box[1]}, "
+                    f"{box[2]} {box[3]}, {box[0]} {box[3]}, {box[0]} {box[1]}))"
+                )
+                bbox_conditions.append(
+                    f"st_contains(ST_GeomFromText('{polygon_wkt}'), geometry)"
+                )
+            bbox_sql_condition = " OR ".join(bbox_conditions)
+            df = df.filter(bbox_sql_condition)
+
+        if datetime:
+            interval_conditions = []
+            for interval in datetime:
+                interval_conditions.append(
+                    f"datetime BETWEEN '{interval[0]}' AND '{interval[1]}'"
+                )
+            interval_sql_condition = " OR ".join(interval_conditions)
+            df = df.filter(interval_sql_condition)
+
+        return df
+
+    @staticmethod
+    def _expand_date(date_str):
+        """
+        Expands a simple date string to cover the entire time period it denotes.
+
+        This function takes a date string in one of the following formats:
+        - YYYY
+        - YYYY-mm
+        - YYYY-mm-dd
+        - YYYY-mm-ddTHH:MM:SSZ
+
+        It then expands the date string to cover the entire time period for that date.
+
+        Parameters:
+        - date_str (str): The date string to expand.
+
+        Returns:
+        - list: A list containing the start and end datetime strings in ISO 8601 format.
+
+        Raises:
+        - ValueError: If the date string format is invalid.
+
+        Examples:
+        - "2017" expands to ["2017-01-01T00:00:00Z", "2017-12-31T23:59:59Z"]
+        - "2017-06" expands to ["2017-06-01T00:00:00Z", "2017-06-30T23:59:59Z"]
+        - "2017-06-10" expands to ["2017-06-10T00:00:00Z", "2017-06-10T23:59:59Z"]
+        - "2017-06-01T00:00:00Z" remains as ["2017-06-01T00:00:00Z", "2017-06-01T00:00:00Z"]
+        """
+        if len(date_str) == 4:  # YYYY
+            return [f"{date_str}-01-01T00:00:00Z", f"{date_str}-12-31T23:59:59Z"]
+        elif len(date_str) == 7:  # YYYY-mm
+            year, month = map(int, date_str.split("-"))
+            # Find the last day of the month by stepping to the first day of the
+            # following month and subtracting one day (handles the December rollover).
+            if month == 12:
+                first_of_next = python_datetime.date(year + 1, 1, 1)
+            else:
+                first_of_next = python_datetime.date(year, month + 1, 1)
+            last_day = (first_of_next - python_datetime.timedelta(days=1)).day
+            return [f"{date_str}-01T00:00:00Z", f"{date_str}-{last_day}T23:59:59Z"]
+        elif len(date_str) == 10:  # YYYY-mm-dd
+            return [f"{date_str}T00:00:00Z", f"{date_str}T23:59:59Z"]
+        elif len(date_str) == 19:  # YYYY-mm-ddTHH:MM:SS
+            return [date_str, date_str]
+        elif len(date_str) == 20:  # YYYY-mm-ddTHH:MM:SSZ
+            return [date_str, date_str]
+        else:
+            raise ValueError("Invalid date format")
+
+    def get_items(
+        self,
+        *ids: Union[str, list],
+        bbox: Optional[list] = None,
+        datetime: Optional[Union[str, python_datetime.datetime, list]] = None,
+        max_items: Optional[int] = None,
+    ) -> Iterator[PyStacItem]:
+        """
+        Returns an iterator of items matching the supplied item IDs and/or optional spatial and temporal extents.
+
+        This method loads the collection data from the specified collection URL and applies
+        optional filters to the data. The filters include:
+        - ids: A list of item IDs to filter the items. If not provided, no ID filtering is applied.
+        - bbox (Optional[list]): A list of bounding boxes for filtering the items.
+        - datetime (Optional[Union[str, python_datetime.datetime, list]]): A single datetime, RFC 3339-compliant timestamp,
+          or a list of date-time ranges for filtering the items.
+        - max_items (Optional[int]): The maximum number of items to return from the search, even if there are more matching results.
+
+        Returns:
+        - Iterator[PyStacItem]: An iterator of PyStacItem objects that match the specified filters.
+          If no filters are provided, the iterator contains all items in the collection.
+
+        Raises:
+        - RuntimeError: If there is an error loading the data or applying the filters, a RuntimeError
+          is raised with a message indicating the failure.
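+
+        Example (illustrative):
+            for item in collection.get_items("item_id1", bbox=[[-180.0, -90.0, 180.0, 90.0]], max_items=10):
+                print(item.id)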
+        """
+        try:
+            # Load the collection data from the specified collection URL
+            df = self.spark.read.format("stac").load(self.collection_url)
+
+            # Apply ID filters if provided; a single list or tuple of IDs is unpacked first
+            if ids:
+                if len(ids) == 1 and isinstance(ids[0], (list, tuple)):
+                    ids = list(ids[0])
+                else:
+                    ids = list(ids)
+                df = df.filter(df.id.isin(ids))
+
+            # Ensure bbox is a list of lists
+            if bbox and isinstance(bbox[0], float):
+                bbox = [bbox]
+
+            # Handle the datetime parameter, expanding plain dates to full intervals
+            if datetime:
+                if isinstance(datetime, (str, python_datetime.datetime)):
+                    datetime = [self._expand_date(str(datetime))]
+                elif isinstance(datetime, list) and isinstance(datetime[0], str):
+                    datetime = [datetime]
+
+            # Apply spatial and temporal filters
+            df = self._apply_spatial_temporal_filters(df, bbox, datetime)
+
+            # Limit the number of items if max_items is specified
+            if max_items is not None:
+                df = df.limit(max_items)
+
+            # Collect the filtered rows and convert them to PyStacItem objects
+            items = []
+            for row in df.collect():
+                row_dict = row.asDict(True)
+                row_dict = self._move_attributes_to_properties(row_dict)
+                items.append(PyStacItem.from_dict(row_dict))
+
+            # Return an iterator of the items
+            return iter(items)
+        except Exception as e:
+            # Log the error and raise a RuntimeError
+            logging.error(f"Error getting items: {e}")
+            raise RuntimeError("Failed to get items") from e
+
+    def get_dataframe(
+        self,
+        *ids: Union[str, list],
+        bbox: Optional[list] = None,
+        datetime: Optional[Union[str, python_datetime.datetime, list]] = None,
+        max_items: Optional[int] = None,
+    ) -> DataFrame:
+        """
+        Returns a Spark DataFrame of items matching the supplied item IDs and/or optional spatial and temporal extents.
+
+        This method loads the collection data from the specified collection URL and applies
+        optional spatial and temporal filters to the data. The spatial filter is applied using
+        bounding boxes, and the temporal filter is applied using date-time ranges.
+
+        Parameters:
+        - ids (Union[str, list]): A variable number of item IDs to filter the items. If not provided, no ID filtering is applied.
+        - bbox (Optional[list]): A list of bounding boxes for filtering the items.
+          Each bounding box is represented as a list of four float values: [min_lon, min_lat, max_lon, max_lat].
+          Example: [[-180.0, -90.0, 180.0, 90.0]]  # This bounding box covers the entire world.
+        - datetime (Optional[Union[str, python_datetime.datetime, list]]): A single datetime, RFC 3339-compliant timestamp,
+          or a list of date-time ranges for filtering the items.
+          Example: "2020-01-01T00:00:00Z" or python_datetime.datetime(2020, 1, 1) or [["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]]
+        - max_items (Optional[int]): The maximum number of items to return, even if more results match the filters.
+
+        Returns:
+        - DataFrame: A Spark DataFrame containing the filtered items. If no filters are provided,
+          the DataFrame contains all items in the collection.
+
+        Raises:
+        - RuntimeError: If there is an error loading the data or applying the filters, a RuntimeError
+          is raised with a message indicating the failure.
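+
+        Example (illustrative):
+            df = collection.get_dataframe(
+                bbox=[[-180.0, -90.0, 180.0, 90.0]],
+                datetime=[["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]],
+                max_items=100,
+            )
+            df.show()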
+        """
+        try:
+            df = self.spark.read.format("stac").load(self.collection_url)
+
+            # Apply ID filters if provided; a single list or tuple of IDs is unpacked first
+            if ids:
+                if len(ids) == 1 and isinstance(ids[0], (list, tuple)):
+                    ids = list(ids[0])
+                else:
+                    ids = list(ids)
+                df = df.filter(df.id.isin(ids))
+
+            # Ensure bbox is a list of lists
+            if bbox and isinstance(bbox[0], float):
+                bbox = [bbox]
+
+            # Handle the datetime parameter, expanding plain dates to full intervals
+            if datetime:
+                if isinstance(datetime, (str, python_datetime.datetime)):
+                    datetime = [self._expand_date(str(datetime))]
+                elif isinstance(datetime, list) and isinstance(datetime[0], str):
+                    datetime = [datetime]
+
+            df = self._apply_spatial_temporal_filters(df, bbox, datetime)
+
+            # Limit the number of items if max_items is specified
+            if max_items is not None:
+                df = df.limit(max_items)
+
+            return df
+        except Exception as e:
+            logging.error(f"Error getting filtered dataframe: {e}")
+            raise RuntimeError("Failed to get filtered dataframe") from e
+
+    def save_to_geoparquet(
+        self,
+        *ids: Union[str, list],
+        output_path: str,
+        bbox: Optional[list] = None,
+        datetime: Optional[list] = None,
+    ) -> None:
+        """
+        Loads the STAC DataFrame and saves it in GeoParquet format at the given output path.
+
+        This method loads the collection data from the specified collection URL and applies
+        optional spatial and temporal filters to the data. The filtered data is then saved
+        to the specified output path in GeoParquet format.
+
+        Parameters:
+        - ids (Union[str, list]): A variable number of item IDs to filter the items. If not provided, no ID filtering is applied.
+        - output_path (str): The path where the GeoParquet file will be saved.
+        - bbox (Optional[list]): A list of bounding boxes for filtering the items. If not provided,
+          no spatial filtering is applied.
+        - datetime (Optional[list]): A list of date-time ranges for filtering the items. If not provided,
+          no temporal filtering is applied.
+          To match a single datetime, set the start and end datetime to the same value,
+          for example: [["2020-01-01T00:00:00Z", "2020-01-01T00:00:00Z"]]
+
+        Raises:
+        - RuntimeError: If there is an error loading the data, applying the filters, or saving the
+          DataFrame to GeoParquet format, a RuntimeError is raised with a message indicating the failure.
+        """
+        try:
+            df = self.get_dataframe(*ids, bbox=bbox, datetime=datetime)
+            df_geoparquet = self._convert_assets_schema(df)
+            df_geoparquet.write.format("geoparquet").save(output_path)
+            logging.info(f"DataFrame successfully saved to {output_path}")
+        except Exception as e:
+            logging.error(f"Error saving DataFrame to GeoParquet: {e}")
+            raise RuntimeError("Failed to save DataFrame to GeoParquet") from e
+
+    @staticmethod
+    def _convert_assets_schema(df: DataFrame) -> DataFrame:
+        """
+        Converts the schema of the assets column in the DataFrame to have a consistent structure.
+
+        This function first identifies all unique keys in the assets column and then ensures that
+        each row in the DataFrame has these keys with appropriate values.
+
+        The expected input schema of the loaded DataFrame (df) is documented at:
+        https://sedona.apache.org/latest-snapshot/api/sql/Stac/#usage
+
+        Parameters:
+        - df (DataFrame): The input DataFrame with an assets column.
+
+        Returns:
+        - DataFrame: The DataFrame with a consistent schema for the assets column.
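+
+        For example, if one row's assets map has only a "thumbnail" entry while another
+        row's has only a "data" entry, the returned DataFrame exposes both keys on every
+        row under a single assets struct.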
+        """
+        from pyspark.sql.functions import col, explode, struct
+
+        # Explode the assets column to get all unique keys and their corresponding value structs
+        exploded_df = df.select(explode("assets").alias("key", "value"))
+        unique_keys = [
+            row["key"] for row in exploded_df.select("key").distinct().collect()
+        ]
+
+        # Create a new schema with all unique keys and their value structs
+        new_schema = struct(
+            [struct(col(f"assets.{key}.*")).alias(key) for key in unique_keys]
+        )
+
+        # Apply the new schema to the assets column
+        df = df.withColumn("assets", new_schema)
+
+        return df
+
+    def __str__(self):
+        return f"<CollectionClient id={self.collection_id}>"
diff --git a/python/tests/stac/__init__.py b/python/tests/stac/__init__.py
new file mode 100644
index 0000000000..a67d5ea255
--- /dev/null
+++ b/python/tests/stac/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/python/tests/stac/test_client.py b/python/tests/stac/test_client.py
new file mode 100644
index 0000000000..5c6192258a
--- /dev/null
+++ b/python/tests/stac/test_client.py
@@ -0,0 +1,146 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from sedona.stac.client import Client
+from pyspark.sql import DataFrame
+
+from tests.test_base import TestBase
+
+STAC_URLS = {
+    "PLANETARY-COMPUTER": "https://planetarycomputer.microsoft.com/api/stac/v1"
+}
+
+
+class TestStacClient(TestBase):
+    def test_collection_client(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        items = client.search(
+            collection_id="aster-l1t",
+            bbox=[-100.0, -72.0, 105.0, -69.0],
+            datetime=["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"],
+            return_dataframe=False,
+        )
+        assert items is not None
+        assert len(list(items)) == 2
+
+    def test_search_with_ids(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        items = client.search(
+            *["AST_L1T_00312272006020322_20150518201805", "item2"],
+            collection_id="aster-l1t",
+            return_dataframe=False,
+        )
+        assert items is not None
+        assert len(list(items)) == 1
+
+    def test_search_with_single_id(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        items = client.search(
+            "AST_L1T_00312272006020322_20150518201805",
+            collection_id="aster-l1t",
+            return_dataframe=False,
+        )
+        assert items is not None
+        assert len(list(items)) == 1
+
+    def test_search_with_bbox_and_datetime(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        items = client.search(
+            collection_id="aster-l1t",
+            bbox=[-180.0, -90.0, 180.0, 90.0],
+            datetime=["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"],
+            return_dataframe=False,
+        )
+        assert items is not None
+        assert len(list(items)) > 0
+
+    def test_search_with_multiple_bboxes_and_intervals(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        items = client.search(
+            collection_id="aster-l1t",
+            bbox=[
+                [90, -73, 105, -69],
+                [-180.0, -90.0, -170.0, -80.0],
+                [-100.0, -72.0, -90.0, -62.0],
+            ],
+            datetime=[["2006-12-01T00:00:00Z", "2006-12-27T03:00:00Z"]],
+            return_dataframe=False,
+        )
+        assert items is not None
+        assert len(list(items)) == 4
+
+    def test_search_with_bbox_and_non_overlapping_intervals(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        items = client.search(
+            collection_id="aster-l1t",
+            bbox=[-180.0, -90.0, 180.0, 90.0],
+            datetime=[
+                ["2006-01-01T00:00:00Z", "2006-06-01T00:00:00Z"],
+                ["2006-07-01T00:00:00Z", "2007-01-01T00:00:00Z"],
+            ],
+            return_dataframe=False,
+        )
+        assert items is not None
+        assert len(list(items)) == 10
+
+    def test_search_with_max_items(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        items = client.search(
+            collection_id="aster-l1t",
+            bbox=[-180.0, -90.0, 180.0, 90.0],
+            datetime=["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"],
+            max_items=5,
+            return_dataframe=False,
+        )
+        assert items is not None
+        assert len(list(items)) == 5
+
+    def test_search_with_single_datetime(self) -> None:
+        from datetime import datetime
+
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        items = client.search(
+            collection_id="aster-l1t",
+            bbox=[-180.0, -90.0, 180.0, 90.0],
+            datetime=datetime(2006, 12, 26, 18, 3, 22),
+            return_dataframe=False,
+        )
+        assert items is not None
+        assert len(list(items)) == 0
+
+    def test_search_with_YYYY(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        items = client.search(
+            collection_id="aster-l1t",
+            bbox=[-180.0, -90.0, 180.0, 90.0],
+            datetime="2006",
+            return_dataframe=False,
+        )
+        assert items is not None
+        assert len(list(items)) == 10
+
+    def test_search_with_return_dataframe(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        df = client.search(
+            collection_id="aster-l1t",
+            bbox=[-180.0, -90.0, 180.0, 90.0],
+            datetime=["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"],
+        )
+        assert df is not None
+        assert isinstance(df, DataFrame)
diff --git a/python/tests/stac/test_collection_client.py b/python/tests/stac/test_collection_client.py
new file mode 100644
index 0000000000..c30105a4eb
--- /dev/null
+++ b/python/tests/stac/test_collection_client.py
@@ -0,0 +1,189 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from sedona.stac.client import Client
+from sedona.stac.collection_client import CollectionClient
+
+from tests.test_base import TestBase
+
+STAC_URLS = {
+    "PLANETARY-COMPUTER": "https://planetarycomputer.microsoft.com/api/stac/v1"
+}
+
+
+class TestStacReader(TestBase):
+    def test_collection_client(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+
+        assert isinstance(collection, CollectionClient)
+        assert str(collection) == "<CollectionClient id=aster-l1t>"
+
+    def test_get_dataframe_no_filters(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        df = collection.get_dataframe()
+        assert df is not None
+        assert df.count() == 10
+
+    def test_get_dataframe_with_spatial_extent(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        bbox = [[-180.0, -90.0, 180.0, 90.0]]
+        df = collection.get_dataframe(bbox=bbox)
+        assert df is not None
+        assert df.count() > 0
+
+    def test_get_dataframe_with_temporal_extent(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        datetime = [["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"]]
+        df = collection.get_dataframe(datetime=datetime)
+        assert df is not None
+        assert df.count() > 0
+
+    def test_get_dataframe_with_both_extents(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        bbox = [[-180.0, -90.0, 180.0, 90.0]]
+        datetime = [["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"]]
+        df = collection.get_dataframe(bbox=bbox, datetime=datetime)
+        assert df is not None
+        assert df.count() > 0
+
+    def test_get_items_with_spatial_extent(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        bbox = [[-100.0, -72.0, 105.0, -69.0]]
+        items = list(collection.get_items(bbox=bbox))
+        assert items is not None
+        assert len(items) == 2
+
+    def test_get_items_with_temporal_extent(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        datetime = [["2006-12-01T00:00:00Z", "2006-12-27T02:00:00Z"]]
+        items = list(collection.get_items(datetime=datetime))
+        assert items is not None
+        assert len(items) == 6
+
+    def test_get_items_with_both_extents(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        bbox = [[90, -73, 105, -69]]
+        datetime = [["2006-12-01T00:00:00Z", "2006-12-27T03:00:00Z"]]
+        items = list(collection.get_items(bbox=bbox, datetime=datetime))
+        assert items is not None
+        assert len(items) == 4
+
+    def test_get_items_with_multiple_bboxes_and_interval(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        bbox = [
+            [90, -73, 105, -69],  # Bounding box 1
+            [-180.0, -90.0, -170.0, -80.0],  # Bounding box 2 (non-overlapping with bbox 1)
+            [-100.0, -72.0, -90.0, -62.0],  # Bounding box 3 (non-overlapping with bboxes 1 and 2)
+        ]
+        datetime = [["2006-12-01T00:00:00Z", "2006-12-27T03:00:00Z"]]
+        items = list(collection.get_items(bbox=bbox, datetime=datetime))
+        assert items is not None
+        assert len(items) == 4
+
+    def test_get_items_with_ids(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        ids = ["AST_L1T_00312272006020322_20150518201805", "item2", "item3"]
+        items = list(collection.get_items(*ids))
+        assert items is not None
+        assert len(items) == 1
+        for item in items:
+            assert item.id in ids
+
+    def test_get_items_with_id(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        items = list(collection.get_items("AST_L1T_00312272006020322_20150518201805"))
+        assert items is not None
+        assert len(items) == 1
+
+    def test_get_items_with_bbox_and_non_overlapping_intervals(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        bbox = [[-180.0, -90.0, 180.0, 90.0]]
+        datetime = [
+            ["2006-01-01T00:00:00Z", "2006-06-01T00:00:00Z"],
+            ["2006-07-01T00:00:00Z", "2007-01-01T00:00:00Z"],
+        ]
+        items = list(collection.get_items(bbox=bbox, datetime=datetime))
+        assert items is not None
+        assert len(items) == 10
+
+    def test_get_items_with_bbox_and_interval(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        bbox = [-180.0, -90.0, 180.0, 90.0]
+        interval = ["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"]
+        items = list(collection.get_items(bbox=bbox, datetime=interval))
+        assert items is not None
+        assert len(items) > 0
+
+    def test_get_dataframe_with_bbox_and_interval(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+        bbox = [-180.0, -90.0, 180.0, 90.0]
+        interval = ["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"]
+        df = collection.get_dataframe(bbox=bbox, datetime=interval)
+        assert df is not None
+        assert df.count() > 0
+
+    def test_save_to_geoparquet(self) -> None:
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+
+        # Create a temporary directory for the output path and clean it up after the test
+        import tempfile
+
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            output_path = f"{tmpdirname}/test_geoparquet_output"
+
+            # Define spatial and temporal extents
+            bbox = [[-180.0, -90.0, 180.0, 90.0]]
+            datetime = [["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"]]
+
+            # Call the method to save the DataFrame to GeoParquet
+            collection.save_to_geoparquet(
+                output_path=output_path, bbox=bbox, datetime=datetime
+            )
+
+            # Check if the file was created
+            import os
+
+            assert os.path.exists(output_path), "GeoParquet file was not created"
+
+            # Optionally, load the file back and check its contents
+            df_loaded = collection.spark.read.format("geoparquet").load(output_path)
+            assert df_loaded.count() == 10, "Loaded GeoParquet file does not contain the expected 10 items"