From f0e70d3d8f511c3c9a31e3d6ebcd6d10daaeb6c9 Mon Sep 17 00:00:00 2001
From: Dewey Dunnington
Date: Tue, 4 Feb 2025 21:49:07 +0000
Subject: [PATCH] port scala test

---
 .../sql/structuredAdapterTestScala.scala      | 22 ++++++++++++++++++-
 1 file changed, 21 insertions(+), 1 deletion(-)

diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/structuredAdapterTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/structuredAdapterTestScala.scala
index 6938916a6d..d258ce3b40 100644
--- a/spark/common/src/test/scala/org/apache/sedona/sql/structuredAdapterTestScala.scala
+++ b/spark/common/src/test/scala/org/apache/sedona/sql/structuredAdapterTestScala.scala
@@ -21,6 +21,7 @@ package org.apache.sedona.sql
 import org.apache.sedona.core.enums.{GridType, IndexType}
 import org.apache.sedona.core.spatialOperator.{JoinQuery, SpatialPredicate}
 import org.apache.sedona.core.spatialRDD.CircleRDD
+import org.apache.spark.sql.functions.spark_partition_id
 import org.apache.spark.sql.Row
 import org.apache.spark.sql.sedona_sql.adapters.StructuredAdapter
 import org.junit.Assert.assertEquals
@@ -105,6 +106,25 @@ class structuredAdapterTestScala extends TestBaseScala with GivenWhenThen {
 
       assertEquals(0, spatialRdd.rawSpatialRDD.count())
       assertEquals(0, spatialRdd.schema.size)
     }
-  }
+    it("can convert spatial RDD to Dataframe preserving spatial partitioning") {
+      var pointCsvDF = sparkSession.read
+        .format("csv")
+        .option("delimiter", ",")
+        .option("header", "false")
+        .load(csvPointInputLocation)
+      pointCsvDF.createOrReplaceTempView("pointtable")
+      var pointDf = sparkSession.sql(
+        "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable")
+      var srcRdd = StructuredAdapter.toSpatialRdd(pointDf, "arealandmark")
+      srcRdd.analyze()
+      srcRdd.spatialPartitioning(GridType.KDBTREE, 16)
+      var numSpatialPartitions = srcRdd.spatialPartitionedRDD.getNumPartitions
+      assert(numSpatialPartitions >= 16)
+
+      var partitionedDF = StructuredAdapter.toSpatialPartitionedDf(srcRdd, sparkSession)
+      val dfPartitions: Long = partitionedDF.select(spark_partition_id).distinct().count()
+      assert(dfPartitions == numSpatialPartitions)
+    }
+  }
 }