port scala test
paleolimbot committed Feb 4, 2025
1 parent d709d06 commit f0e70d3
Showing 1 changed file with 21 additions and 1 deletion.
@@ -21,6 +21,7 @@ package org.apache.sedona.sql
import org.apache.sedona.core.enums.{GridType, IndexType}
import org.apache.sedona.core.spatialOperator.{JoinQuery, SpatialPredicate}
import org.apache.sedona.core.spatialRDD.CircleRDD
import org.apache.spark.sql.functions.spark_partition_id
import org.apache.spark.sql.Row
import org.apache.spark.sql.sedona_sql.adapters.StructuredAdapter
import org.junit.Assert.assertEquals
@@ -105,6 +106,25 @@ class structuredAdapterTestScala extends TestBaseScala with GivenWhenThen {
        assertEquals(0, spatialRdd.rawSpatialRDD.count())
        assertEquals(0, spatialRdd.schema.size)
      }
    }

    it("can convert spatial RDD to Dataframe preserving spatial partitioning") {
      var pointCsvDF = sparkSession.read
        .format("csv")
        .option("delimiter", ",")
        .option("header", "false")
        .load(csvPointInputLocation)
      pointCsvDF.createOrReplaceTempView("pointtable")
      var pointDf = sparkSession.sql(
        "select ST_Point(cast(pointtable._c0 as Decimal(24,20)), cast(pointtable._c1 as Decimal(24,20))) as arealandmark from pointtable")
      var srcRdd = StructuredAdapter.toSpatialRdd(pointDf, "arealandmark")
      srcRdd.analyze()
      srcRdd.spatialPartitioning(GridType.KDBTREE, 16)
      var numSpatialPartitions = srcRdd.spatialPartitionedRDD.getNumPartitions
      assert(numSpatialPartitions >= 16)

      var partitionedDF = StructuredAdapter.toSpatialPartitionedDf(srcRdd, sparkSession)
      val dfPartitions: Long = partitionedDF.select(spark_partition_id).distinct().count()
      assert(dfPartitions == numSpatialPartitions)
    }
  }
}
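
The new test exercises the round trip DataFrame → SpatialRDD → spatially partitioned DataFrame. Below is a minimal sketch of that round trip outside the test harness, reusing only the `StructuredAdapter`, `GridType`, and `spark_partition_id` calls that appear in the diff above; the Sedona-enabled `sparkSession` and the input DataFrame `pointDf` with a geometry column named `geom` are assumed placeholders.

```scala
import org.apache.sedona.core.enums.GridType
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.spark_partition_id
import org.apache.spark.sql.sedona_sql.adapters.StructuredAdapter

// DataFrame -> SpatialRDD; "geom" is the assumed geometry column name.
val spatialRdd = StructuredAdapter.toSpatialRdd(pointDf, "geom")
spatialRdd.analyze()

// Spatially partition the RDD. The partitioner may produce more partitions
// than requested, which is why the test asserts >= 16 rather than == 16.
spatialRdd.spatialPartitioning(GridType.KDBTREE, 16)

// SpatialRDD -> DataFrame that keeps the spatial partitioning, so the number
// of distinct Spark partition ids matches the number of spatial partitions.
val partitionedDf: DataFrame = StructuredAdapter.toSpatialPartitionedDf(spatialRdd, sparkSession)
val numDfPartitions: Long = partitionedDf.select(spark_partition_id).distinct().count()
```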
