From da9eac982063cdf05a8924110ea4f43c13b3dd7a Mon Sep 17 00:00:00 2001 From: Kevin J Nguyen Date: Tue, 18 Jul 2023 11:18:58 -0500 Subject: [PATCH] Entity Key Filter Integration Test (#520) (#521) Adds a test for the entity key filter --- clients/python/src/fenlmagic/__init__.py | 22 +-- clients/python/src/kaskada/client.py | 6 +- clients/python/src/kaskada/query.py | 16 ++- clients/python/tests/test_query.py | 51 +++++++ crates/sparrow-runtime/src/prepare.rs | 117 ++++++++++++++++ .../integration/api/query_v1_slicing_test.go | 127 +++++++++++++++++- tests/integration/shared/helpers/helpers.go | 3 +- 7 files changed, 313 insertions(+), 29 deletions(-) diff --git a/clients/python/src/fenlmagic/__init__.py b/clients/python/src/fenlmagic/__init__.py index 2ba36f4d9..dabe23d40 100644 --- a/clients/python/src/fenlmagic/__init__.py +++ b/clients/python/src/fenlmagic/__init__.py @@ -9,6 +9,7 @@ import logging import os import sys +from typing import Optional import IPython import pandas @@ -40,12 +41,10 @@ def set_dataframe(self, dataframe: pandas.DataFrame): @magics_class class FenlMagics(Magics): - client = None - - def __init__(self, shell, client): + def __init__(self, shell, client: Optional[client.Client]): super(FenlMagics, self).__init__(shell) - self.client = client logger.info("extension loaded") + self.client = client @magic_arguments() @argument( @@ -174,19 +173,8 @@ def fenl(self, arg, cell=None): raise UsageError(e) -def load_ipython_extension(ipython): - if client.KASKADA_DEFAULT_CLIENT is None: - logger.warn( - "No client was initialized. Initializing default client to connect to localhost:50051." - ) - default_client = client.Client( - client_id=os.getenv("KASKADA_CLIENT_ID", None), - endpoint=client.KASKADA_DEFAULT_ENDPOINT, - is_secure=client.KASKADA_IS_SECURE, - ) - client.set_default_client(default_client) - - magics = FenlMagics(ipython, client.get_client()) +def load_ipython_extension(ipython, client: Optional[client.Client] = None): + magics = FenlMagics(ipython, client) ipython.register_magics(magics) diff --git a/clients/python/src/kaskada/client.py b/clients/python/src/kaskada/client.py index b6a3e6855..f8a36d384 100644 --- a/clients/python/src/kaskada/client.py +++ b/clients/python/src/kaskada/client.py @@ -270,10 +270,12 @@ def set_default_slice(slice: SliceFilter): Args: slice (SliceFilter): SliceFilter to set the default """ - logger.debug(f"Default slice set to type {type(slice)}") - global KASKADA_DEFAULT_SLICE KASKADA_DEFAULT_SLICE = slice + if KASKADA_DEFAULT_SLICE is None: + logger.info("Slicing disabled") + else: + logger.info(f"Slicing set to: {slice.to_request()}") def set_default_client(client: Client): diff --git a/clients/python/src/kaskada/query.py b/clients/python/src/kaskada/query.py index b091592e8..054daf9b3 100644 --- a/clients/python/src/kaskada/query.py +++ b/clients/python/src/kaskada/query.py @@ -11,7 +11,7 @@ import kaskada.formatters import kaskada.kaskada.v1alpha.destinations_pb2 as destinations_pb import kaskada.kaskada.v1alpha.query_service_pb2 as query_pb -from kaskada.client import KASKADA_DEFAULT_SLICE, Client, get_client +from kaskada.client import Client, get_client from kaskada.slice_filters import SliceFilter from kaskada.utils import get_timestamp, handleException, handleGrpcError @@ -134,7 +134,19 @@ def create_query( query_pb.CreateQueryResponse """ if slice_filter is None: - slice_filter = KASKADA_DEFAULT_SLICE + """ + Subtle Python Implementation Note: + + The KASKADA_DEFAULT_SLICE is a global variable that varies at runtime. 
Users can set the default slice at any point. + The value of the slice is evaluated at execution time of this method. + + Incorrect: from kaskada.client import KASKADA_DEFAULT_SLICE + This value is evaluated once at the import of the query module. + + Correct: import kaskada.client + The value is then fetched from the module every time a query is invoked. + """ + slice_filter = kaskada.client.KASKADA_DEFAULT_SLICE change_since_time = get_timestamp(changed_since_time) final_result_time = get_timestamp(final_result_time) diff --git a/clients/python/tests/test_query.py b/clients/python/tests/test_query.py index 3f49d6c7c..d80fbbad7 100644 --- a/clients/python/tests/test_query.py +++ b/clients/python/tests/test_query.py @@ -8,6 +8,7 @@ import kaskada.kaskada.v1alpha.destinations_pb2 as destinations_pb import kaskada.kaskada.v1alpha.query_service_pb2 as query_pb import kaskada.query +from kaskada.slice_filters import EntityPercentFilter """ def create_query( @@ -48,6 +49,56 @@ def test_create_query_with_defaults(mockClient): ) +@patch("kaskada.client.Client") +def test_query_uses_client_global_slice_filter(mockClient): + filter_percentage = 65 + entity_filter = EntityPercentFilter(filter_percentage) + kaskada.client.set_default_slice(entity_filter) + expression = "test_with_defaults" + expected_request = query_pb.CreateQueryRequest( + query=query_pb.Query( + expression=expression, + destination={ + "object_store": destinations_pb.ObjectStoreDestination( + file_type=common_pb.FILE_TYPE_PARQUET + ) + }, + result_behavior="RESULT_BEHAVIOR_ALL_RESULTS", + slice=common_pb.SliceRequest( + percent=common_pb.SliceRequest.PercentSlice(percent=65), + ), + ), + query_options=query_pb.QueryOptions(presign_results=True), + ) + kaskada.query.create_query(expression, client=mockClient) + mockClient.query_stub.CreateQuery.assert_called_with( + expected_request, metadata=mockClient.get_metadata() + ) + + filter_percentage = 10 + entity_filter = EntityPercentFilter(filter_percentage) + kaskada.client.set_default_slice(entity_filter) + expected_request = query_pb.CreateQueryRequest( + query=query_pb.Query( + expression=expression, + destination={ + "object_store": destinations_pb.ObjectStoreDestination( + file_type=common_pb.FILE_TYPE_PARQUET + ) + }, + result_behavior="RESULT_BEHAVIOR_ALL_RESULTS", + slice=common_pb.SliceRequest( + percent=common_pb.SliceRequest.PercentSlice(percent=10), + ), + ), + query_options=query_pb.QueryOptions(presign_results=True), + ) + kaskada.query.create_query(expression, client=mockClient) + mockClient.query_stub.CreateQuery.assert_called_with( + expected_request, metadata=mockClient.get_metadata() + ) + + @patch("kaskada.client.Client") def test_get_query(mockClient): query_id = "12345" diff --git a/crates/sparrow-runtime/src/prepare.rs b/crates/sparrow-runtime/src/prepare.rs index 204e4b7ed..9e05ffe46 100644 --- a/crates/sparrow-runtime/src/prepare.rs +++ b/crates/sparrow-runtime/src/prepare.rs @@ -278,7 +278,11 @@ async fn reader_from_csv<'a, R: std::io::Read + std::io::Seek + Send + 'static>( #[cfg(test)] mod tests { + use std::sync::Arc; + + use arrow::datatypes::{DataType, Field, Schema}; use futures::StreamExt; + use sparrow_api::kaskada::v1alpha::slice_plan::{EntityKeysSlice, Slice}; use sparrow_api::kaskada::v1alpha::{source_data, SourceData, TableConfig}; use uuid::Uuid; @@ -409,4 +413,117 @@ mod tests { let _prepared_schema = prepared_batch.schema(); let _metadata_schema = metadata.schema(); } + + #[tokio::test] + async fn test_preparation_single_entity_key_slicing() { + 
let entity_keys = vec!["0b00083c-5c1e-47f5-abba-f89b12ae3cf4".to_owned()];
+        let slice = Some(Slice::EntityKeys(EntityKeysSlice { entity_keys }));
+        test_slicing_config(&slice, 23, 1).await;
+    }
+
+    #[tokio::test]
+    async fn test_preparation_no_matching_entity_key_slicing() {
+        let entity_keys = vec!["some-random-invalid-entity-key".to_owned()];
+        let slice = Some(Slice::EntityKeys(EntityKeysSlice { entity_keys }));
+        test_slicing_config(&slice, 0, 0).await;
+    }
+
+    #[tokio::test]
+    async fn test_preparation_multiple_matching_entity_key_slicing() {
+        let entity_keys = vec![
+            "0b00083c-5c1e-47f5-abba-f89b12ae3cf4".to_owned(),
+            "8a16beda-c07a-4625-a805-2d28f5934107".to_owned(),
+        ];
+        let slice = Some(Slice::EntityKeys(EntityKeysSlice { entity_keys }));
+        test_slicing_config(&slice, 41, 2).await;
+    }
+
+    #[tokio::test]
+    async fn test_slicing_issue() {
+        let input_path = sparrow_testing::testdata_path("transactions/transactions_part1.parquet");
+
+        let input_path =
+            source_data::Source::ParquetPath(format!("file:///{}", input_path.display()));
+        let source_data = SourceData {
+            source: Some(input_path),
+        };
+
+        let table_config = TableConfig::new_with_table_source(
+            "transactions_slicing",
+            &Uuid::new_v4(),
+            "transaction_time",
+            Some("idx"),
+            "purchaser_id",
+            "",
+        );
+
+        let entity_keys = vec!["2798e270c7cab8c9eeacc046a3100a57".to_owned()];
+        let slice = Some(Slice::EntityKeys(EntityKeysSlice { entity_keys }));
+
+        let prepared_batches = super::prepared_batches(
+            &ObjectStoreRegistry::default(),
+            &source_data,
+            &table_config,
+            &slice,
+        )
+        .await
+        .unwrap()
+        .collect::<Vec<_>>()
+        .await;
+        assert_eq!(prepared_batches.len(), 1);
+        let (prepared_batch, metadata) = prepared_batches[0].as_ref().unwrap();
+        assert_eq!(prepared_batch.num_rows(), 300);
+        let _prepared_schema = prepared_batch.schema();
+        assert_metadata_schema_eq(metadata.schema());
+        assert_eq!(metadata.num_rows(), 1);
+    }
+
+    async fn test_slicing_config(
+        slice: &Option<Slice>,
+        num_prepared_rows: usize,
+        num_metadata_rows: usize,
+    ) {
+        let input_path = sparrow_testing::testdata_path("eventdata/sample_event_data.parquet");
+
+        let input_path =
+            source_data::Source::ParquetPath(format!("file:///{}", input_path.display()));
+        let source_data = SourceData {
+            source: Some(input_path),
+        };
+
+        let table_config = TableConfig::new_with_table_source(
+            "Event",
+            &Uuid::new_v4(),
+            "timestamp",
+            Some("subsort_id"),
+            "anonymousId",
+            "user",
+        );
+
+        let prepared_batches = super::prepared_batches(
+            &ObjectStoreRegistry::default(),
+            &source_data,
+            &table_config,
+            slice,
+        )
+        .await
+        .unwrap()
+        .collect::<Vec<_>>()
+        .await;
+        assert_eq!(prepared_batches.len(), 1);
+        let (prepared_batch, metadata) = prepared_batches[0].as_ref().unwrap();
+        assert_eq!(prepared_batch.num_rows(), num_prepared_rows);
+        let _prepared_schema = prepared_batch.schema();
+        assert_metadata_schema_eq(metadata.schema());
+        assert_eq!(metadata.num_rows(), num_metadata_rows);
+    }
+
+    fn assert_metadata_schema_eq(metadata_schema: Arc<Schema>) {
+        let fields = vec![
+            Field::new("_hash", DataType::UInt64, false),
+            Field::new("_entity_key", DataType::Utf8, true),
+        ];
+        let schema = Arc::new(Schema::new(fields));
+        assert_eq!(metadata_schema, schema);
+    }
 }
diff --git a/tests/integration/api/query_v1_slicing_test.go b/tests/integration/api/query_v1_slicing_test.go
index 7dbeee619..bfeaf425b 100644
--- a/tests/integration/api/query_v1_slicing_test.go
+++ b/tests/integration/api/query_v1_slicing_test.go
@@ -50,7 +50,7 @@ max_spent_in_single_transaction: max(transactions_slicing.price * transactions_s
         table = &v1alpha.Table{
             TableName: tableName,
             TimeColumnName: "transaction_time",
-            EntityKeyColumnName: "id",
+            EntityKeyColumnName: "purchaser_id",
             SubsortColumnName: &wrapperspb.StringValue{
                 Value: "idx",
             },
@@ -199,12 +199,21 @@ max_spent_in_single_transaction: max(transactions_slicing.price * transactions_s
             helpers.LogLn(fmt.Sprintf("Result set size, with 10%% slice: %d", len(results)))
-            Expect(len(results)).Should(BeNumerically("~", 5000, 250))
+            /*
+             * There are 150 unique entities in this dataset.
+             * Each entity has an average of 333.3 events.
+             * The total dataset size is 50,000 (150 * 333.3 = 49,995)
+             * Assuming uniform distribution (not entirely true), then 10% of the entities = 15 entities
+             * Since the 10% slice is based on a hashing function, there is some room for error.
+             * Random Heuristic: Lower Bound -> 7% (10 entities) and Upper Bound -> 13% (20 entities)
+             * Lower Bound: 3333 (10 * 333.3) and Upper Bound: 6666 (20 * 333.3)
+             */
+            Expect(len(results)).Should(BeNumerically("~", 5000, 1666))
         })
     })
-    Describe("Run the query with a 0.1% slice", func() {
-        It("should return about 0.1% of the results", func() {
+    Describe("Run the query with a 0.3% slice", func() {
+        It("should return about 0.3% of the results", func() {
             destination := &v1alpha.Destination{}
             destination.Destination = &v1alpha.Destination_ObjectStore{
                 ObjectStore: &v1alpha.ObjectStoreDestination{
@@ -219,7 +228,7 @@ max_spent_in_single_transaction: max(transactions_slicing.price * transactions_s
                 Slice: &v1alpha.SliceRequest{
                     Slice: &v1alpha.SliceRequest_Percent{
                         Percent: &v1alpha.SliceRequest_PercentSlice{
-                            Percent: 0.1,
+                            Percent: 0.3,
                         },
                     },
                 },
@@ -243,9 +252,113 @@ max_spent_in_single_transaction: max(transactions_slicing.price * transactions_s
             resultsUrl := res.GetDestination().GetObjectStore().GetOutputPaths().Paths[0]
             results := helpers.DownloadParquet(resultsUrl)
-            helpers.LogLn(fmt.Sprintf("Result set size, with 0.1%% slice: %d", len(results)))
+            helpers.LogLn(fmt.Sprintf("Result set size, with 0.3%% slice: %d", len(results)))
+
+            /*
+             * There are 150 unique entities in this dataset.
+             * Each entity has an average of 333.3 events.
+             * The total dataset size is 50,000 (150 * 333.3 = 49,995)
+             * Assuming uniform distribution (not entirely true), then 0.3% of the entities = ~1 entity (0.45 entities)
+             * Since the 0.3% slice is based on a hashing function, there is some room for error.
+             * Random Heuristic: Lower Bound -> 0% (0 entities) and Upper Bound -> 1% (1.5 entities)
+             * Lower Bound: 0 (0 * 333.3) and Upper Bound: 499.95 (1.5 * 333.3)
+             */
+            Expect(len(results)).Should(BeNumerically("~", 250, 125))
        })
    })
+
+    Describe("Run the query with entity key filter", func() {
+        It("should return 300 results for a single entity key", func() {
+            destination := &v1alpha.Destination{}
+            destination.Destination = &v1alpha.Destination_ObjectStore{
+                ObjectStore: &v1alpha.ObjectStoreDestination{
+                    FileType: v1alpha.FileType_FILE_TYPE_PARQUET,
+                },
+            }
+            createQueryRequest := &v1alpha.CreateQueryRequest{
+                Query: &v1alpha.Query{
+                    Expression: expression,
+                    Destination: destination,
+                    ResultBehavior: v1alpha.Query_RESULT_BEHAVIOR_ALL_RESULTS,
+                    Slice: &v1alpha.SliceRequest{
+                        Slice: &v1alpha.SliceRequest_EntityKeys{
+                            EntityKeys: &v1alpha.SliceRequest_EntityKeysSlice{
+                                EntityKeys: []string{
+                                    "2798e270c7cab8c9eeacc046a3100a57",
+                                },
+                            },
+                        },
+                    },
+                },
+                QueryOptions: &v1alpha.QueryOptions{
+                    PresignResults: true,
+                },
+            }
+            stream, err := queryClient.CreateQuery(ctx, createQueryRequest)
+            Expect(err).ShouldNot(HaveOccurredGrpc())
+            Expect(stream).ShouldNot(BeNil())
+
+            res, err := helpers.GetMergedCreateQueryResponse(stream)
+            Expect(err).ShouldNot(HaveOccurred())
+            Expect(res).ShouldNot(BeNil())
+            Expect(res.RequestDetails.RequestId).ShouldNot(BeEmpty())
+            Expect(res.GetDestination().GetObjectStore().GetOutputPaths().GetPaths()).ShouldNot(BeNil())
+            Expect(res.GetDestination().GetObjectStore().GetOutputPaths().Paths).Should(HaveLen(1))
+
+            resultsUrl := res.GetDestination().GetObjectStore().GetOutputPaths().Paths[0]
+            results := helpers.DownloadParquet(resultsUrl)
+
+            helpers.LogLn(fmt.Sprintf("Result set size, with entity key filter: %d", len(results)))
+
+            // There are 300 rows in this dataset for the provided entity key.
+            Expect(len(results)).Should(BeNumerically("~", 300))
+        })
+
+        It("should return 685 (300 + 385) results for multiple entity keys", func() {
+            destination := &v1alpha.Destination{}
+            destination.Destination = &v1alpha.Destination_ObjectStore{
+                ObjectStore: &v1alpha.ObjectStoreDestination{
+                    FileType: v1alpha.FileType_FILE_TYPE_PARQUET,
+                },
+            }
+            createQueryRequest := &v1alpha.CreateQueryRequest{
+                Query: &v1alpha.Query{
+                    Expression: expression,
+                    Destination: destination,
+                    ResultBehavior: v1alpha.Query_RESULT_BEHAVIOR_ALL_RESULTS,
+                    Slice: &v1alpha.SliceRequest{
+                        Slice: &v1alpha.SliceRequest_EntityKeys{
+                            EntityKeys: &v1alpha.SliceRequest_EntityKeysSlice{
+                                EntityKeys: []string{
+                                    "2798e270c7cab8c9eeacc046a3100a57",
+                                    "79b3ced09d3df7c98fbb04fdfda6ce80",
+                                },
+                            },
+                        },
+                    },
+                },
+                QueryOptions: &v1alpha.QueryOptions{
+                    PresignResults: true,
+                },
+            }
+            stream, err := queryClient.CreateQuery(ctx, createQueryRequest)
+            Expect(err).ShouldNot(HaveOccurredGrpc())
+            Expect(stream).ShouldNot(BeNil())
+
+            res, err := helpers.GetMergedCreateQueryResponse(stream)
+            Expect(err).ShouldNot(HaveOccurred())
+            Expect(res).ShouldNot(BeNil())
+            Expect(res.RequestDetails.RequestId).ShouldNot(BeEmpty())
+            Expect(res.GetDestination().GetObjectStore().GetOutputPaths().GetPaths()).ShouldNot(BeNil())
+            Expect(res.GetDestination().GetObjectStore().GetOutputPaths().Paths).Should(HaveLen(1))
+
+            resultsUrl := res.GetDestination().GetObjectStore().GetOutputPaths().Paths[0]
+            results := helpers.DownloadParquet(resultsUrl)
+
+            helpers.LogLn(fmt.Sprintf("Result set size, with entity key filter: %d", len(results)))
-            Expect(len(results)).Should(BeNumerically("~", 50, 25))
+            // There are 685 rows (300 + 385) in this dataset for the two provided entity keys.
+            Expect(len(results)).Should(BeNumerically("~", 685))
        })
    })
})
diff --git a/tests/integration/shared/helpers/helpers.go b/tests/integration/shared/helpers/helpers.go
index 7801555fe..5c2e4e9e5 100644
--- a/tests/integration/shared/helpers/helpers.go
+++ b/tests/integration/shared/helpers/helpers.go
@@ -135,7 +135,8 @@ func GetFileURI(fileName string) string {
 	if os.Getenv("ENV") == "local-local" {
 		workDir, err := os.Getwd()
 		Expect(err).ShouldNot(HaveOccurred())
-		return fmt.Sprintf("file://%s/../../../testdata/%s", workDir, fileName)
+		path := filepath.Join(workDir, "../../../testdata", fileName)
+		return fmt.Sprintf("file://%s", path)
 	}
 	return fmt.Sprintf("file:///testdata/%s", fileName)
 }
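
Note on the query.py docstring above: its point about importing the module rather than the name can be seen in a minimal, self-contained Python sketch. The module object and names below are illustrative stand-ins, not the real kaskada packages.

import types

# Stand-in for the kaskada.client module, with a module-level default slice.
client = types.ModuleType("client")
client.KASKADA_DEFAULT_SLICE = None

# "Incorrect" pattern: `from kaskada.client import KASKADA_DEFAULT_SLICE`
# copies the binding once, at import time, and never sees later updates.
stale_default = client.KASKADA_DEFAULT_SLICE

def set_default_slice(slice_filter):
    # Mirrors the role of kaskada.client.set_default_slice: rebind the module attribute.
    client.KASKADA_DEFAULT_SLICE = slice_filter

set_default_slice("entity-percent-filter")

print(stale_default)                 # None -- the copied name misses the update
print(client.KASKADA_DEFAULT_SLICE)  # "entity-percent-filter" -- attribute lookup sees it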