Skip to content

Commit

Permalink
RD-10439 data duplication on MySQL.InferAndRead (#338)
Browse files Browse the repository at this point in the history
  • Loading branch information
torcato authored Jan 29, 2024
1 parent 37d8998 commit f7a3df5
Show file tree
Hide file tree
Showing 21 changed files with 96 additions and 24 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -305,4 +305,12 @@ trait MySQLPackageTest extends CompilerTestContext with CredentialsTestContext w
s"""failed to read from database mysql:${mysqlCreds.database}: Table '${mysqlCreds.database}.dont_exist' doesn't exist"""
it should evaluateTo(s"""[3L, Error.Build("$error")]""")
}

// Regression check for RD-10439 (data duplication on MySQL.InferAndRead):
// reading table "rd10439" must produce exactly the three rows listed below.
test(s"""MySQL.InferAndRead("$mysqlRegDb", "rd10439")""") { it =>
  // Expected collection, written out once so the assertion line stays short.
  val expectedRows =
    """[
      | {id: 1, name: "john", salary: 23.5},
      | {id: 2, name: "jane", salary: 30.4},
      | {id: 3, name: "bob", salary: 17.8}
      |]""".stripMargin
  it should evaluateTo(expectedRows)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ abstract class JdbcClient()(implicit settings: RawSettings) extends StrictLoggin

def vendor: String

// The database is optional because some vendors (e.g. Teradata and SQLite) have no concept of a database.
def database: Option[String]

// Wrap vendor-specific calls and ensure only RelationalDatabaseException is thrown.
def wrapSQLException[T](f: => T): T

Expand Down Expand Up @@ -101,8 +104,8 @@ abstract class JdbcClient()(implicit settings: RawSettings) extends StrictLoggin
listTables(schema).close()
}

def testAccess(maybeSchema: Option[String], table: String): Unit = {
tableMetadata(maybeSchema, table)
def testAccess(database: Option[String], maybeSchema: Option[String], table: String): Unit = {
tableMetadata(database, maybeSchema, table)
}

def listSchemas: Iterator[String] with Closeable = {
Expand All @@ -121,10 +124,10 @@ abstract class JdbcClient()(implicit settings: RawSettings) extends StrictLoggin
SchemaMetadata()
}

def tableMetadata(maybeSchema: Option[String], table: String): TableMetadata = {
def tableMetadata(database: Option[String], maybeSchema: Option[String], table: String): TableMetadata = {
val conn = getConnection
try {
val res = getTableMetadata(conn, maybeSchema, table)
val res = getTableMetadata(conn, database, maybeSchema, table)
try {
getTableTypeFromTableMetadata(res)
} finally {
Expand All @@ -135,11 +138,16 @@ abstract class JdbcClient()(implicit settings: RawSettings) extends StrictLoggin
}
}

private def getTableMetadata(conn: Connection, maybeSchema: Option[String], table: String): ResultSet = {
private def getTableMetadata(
conn: Connection,
maybeDatabase: Option[String],
maybeSchema: Option[String],
table: String
): ResultSet = {
wrapSQLException {
val metaData = conn.getMetaData
metaData.getColumns(
null, // Database/Catalog is set to null because we assume it is already set as part of the connection string
maybeDatabase.orNull,
maybeSchema.orNull,
table,
null // Read all columns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ abstract class JdbcTableLocation(
with StrictLogging {

final override def testAccess(): Unit = {
jdbcClient.testAccess(maybeSchema, table)
jdbcClient.testAccess(Some(dbName), maybeSchema, table)
}

final def getType(): TableMetadata = {
jdbcClient.tableMetadata(maybeSchema, table)
jdbcClient.tableMetadata(Some(dbName), maybeSchema, table)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ class MySqlClient(db: MySqlCredential)(implicit settings: RawSettings) extends J

override val hostname: String = db.host

override val database: Option[String] = Some(db.database)

override def wrapSQLException[T](f: => T): T = {
try {
f
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class MySqlLocationBuilder extends JdbcLocationBuilder {
location.url match {
case mysqlTableRegex(dbName) =>
val db = MySqlClients.get(dbName, location)
new MySqlLocation(db, dbName)
new MySqlLocation(db, db.database.get)
case _ => throw new LocationException("not a mysql database location")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class MySqlSchemaLocationBuilder extends JdbcSchemaLocationBuilder {
location.url match {
case schemaRegex(dbName) =>
val db = MySqlClients.get(dbName, location)
new MySqlSchema(db, dbName)
new MySqlSchema(db, db.database.get)
case _ => throw new LocationException("not a mysql schema location")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class MySqlTableLocationBuilder extends JdbcTableLocationBuilder {
location.url match {
case mysqlTableRegex(dbName, table) =>
val db = MySqlClients.get(dbName, location)
new MySqlTable(db, dbName, table)
new MySqlTable(db, db.database.get, table)
case _ => throw new LocationException("not a mysql table location")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ class PostgresqlClient(db: PostgresqlCredential)(implicit settings: RawSettings)

override val hostname: String = db.host

override val database: Option[String] = Some(db.database)
// override val datasource: DataSource = {
// val pgDatasource = new PGSimpleDataSource()
// pgDatasource.setURL(connectionString)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class PostgresqlLocationBuilder extends JdbcLocationBuilder {
location.url match {
case postgresqlDatabaseRegex(dbName) =>
val db = PostgresqlClients.get(dbName, location)
new PostgresqlLocation(db, dbName)
new PostgresqlLocation(db, db.database.get)
case _ => throw new LocationException("not a postgresql database location")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class PostgresqlSchemaLocationBuilder extends JdbcSchemaLocationBuilder {
location.url match {
case schemaRegex(dbName, schema) =>
val db = PostgresqlClients.get(dbName, location)
new PostgresqlSchema(db, dbName, schema)
new PostgresqlSchema(db, db.database.get, schema)
case _ => throw new LocationException("not a postgresql schema location")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class PostgresqlTableLocationBuilder extends JdbcTableLocationBuilder {
location.url match {
case postgresqlTableRegex(dbName, schema, table) =>
val db = PostgresqlClients.get(dbName, location)
new PostgresqlTable(db, dbName, schema, table)
new PostgresqlTable(db, db.database.get, schema, table)
case _ => throw new LocationException("not a postgresql location")
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SnowflakeClient(db: SnowflakeCredential)(implicit settings: RawSettings) e
override val password: Option[String] = db.password

override val hostname: String = db.host

override val database: Option[String] = Some(db.database)
override def getConnection: Connection = {
wrapSQLException {
val parameters = db.parameters ++ Seq("db" -> db.database)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class SnowflakeLocationBuilder extends JdbcLocationBuilder {
location.url match {
case snowflakeTableRegex(dbName) =>
val db = SnowflakeClients.get(dbName, location)
new SnowflakeLocation(db, dbName)
new SnowflakeLocation(db, db.database.get)
case _ => throw new LocationException("not an snowflake database location")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ class SnowflakeSchemaLocationBuilder extends JdbcSchemaLocationBuilder {
location.url match {
case schemaRegex(dbName, schema) =>
val db = SnowflakeClients.get(dbName, location)
new SnowflakeSchema(db, dbName, schema)
new SnowflakeSchema(db, db.database.get, schema)
case _ => throw new LocationException("not an snowflake schema location")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class SnowflakeTableLocationBuilder extends JdbcTableLocationBuilder {
location.url match {
case snowflakeTableRegex(dbName, schema, table) =>
val db = SnowflakeClients.get(dbName, location)
new SnowflakeTable(db, dbName, schema, table)
new SnowflakeTable(db, db.database.get, schema, table)
case _ => throw new LocationException("not an snowflake table location")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,10 @@

package raw.sources.jdbc.sqlite

import java.nio.file.{InvalidPathException, Path}
import raw.sources.jdbc.api.{AuthenticationFailedException, JdbcClient, JdbcLocationException}
import raw.sources.jdbc.api._
import raw.utils.RawSettings

import java.nio.file.{InvalidPathException, Path}
import java.sql.SQLException
import scala.util.control.NonFatal

Expand All @@ -41,6 +41,7 @@ class SqliteClient(path: Path)(implicit settings: RawSettings) extends JdbcClien
override val connectionString: String = s"jdbc:$vendor:$sqlitePath"
override val username: Option[String] = None
override val password: Option[String] = None
override val database: Option[String] = None

override val hostname: String = path.toAbsolutePath.toString

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class SqlServerClient(protected val db: SqlServerCredential)(implicit settings:
override val password: Option[String] = db.password

override val hostname: String = db.host

override val database: Option[String] = Some(db.database)
override def wrapSQLException[T](f: => T): T = {
try {
f
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class SqlServerLocationBuilder extends JdbcLocationBuilder {
location.url match {
case sqlServerTableRegex(dbName) =>
val db = SqlServerClients.get(dbName, location)
new SqlServerLocation(db, dbName)
new SqlServerLocation(db, db.database.get)
case _ => throw new LocationException("not a sqlserver database location")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class SqlServerSchemaLocationBuilder extends JdbcSchemaLocationBuilder {
location.url match {
case schemaRegex(dbName, schema) =>
val db = SqlServerClients.get(dbName, location)
new SqlServerSchema(db, dbName, schema)
new SqlServerSchema(db, db.database.get, schema)
case _ => throw new LocationException("not a sqlserver schema location")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class SqlServerTableLocationBuilder extends JdbcTableLocationBuilder {
location.url match {
case sqlServerTableRegex(dbName, schema, table) =>
val db = SqlServerClients.get(dbName, location)
new SqlServerTable(db, dbName, schema, table)
new SqlServerTable(db, db.database.get, schema, table)
case _ => throw new LocationException("not a sqlserver table location")
}
}
Expand Down
52 changes: 52 additions & 0 deletions snapi-frontend/src/test/scala/raw/inferrer/local/RD10439.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
/*
* Copyright 2023 RAW Labs S.A.
*
* Use of this software is governed by the Business Source License
* included in the file licenses/BSL.txt.
*
* As of the Change Date specified in that file, in accordance with
* the Business Source License, use of this software will be governed
* by the Apache License, Version 2.0, included in the file
* licenses/APL.txt.
*/

package raw.inferrer.local

import com.typesafe.scalalogging.StrictLogging
import raw.client.api.{LocationDescription, LocationSettingKey, LocationStringSetting}
import raw.creds.api.MySqlCredential
import raw.inferrer.api._
import raw.inferrer.local.jdbc.JdbcInferrer
import raw.sources.api.SourceContext
import raw.sources.jdbc.api.JdbcTableLocationProvider
import raw.sources.jdbc.mysql.{MySqlClient, MySqlTable}
import raw.utils.{RawTestSuite, SettingsTestContext}

/**
 * Regression suite for RD-10439.
 *
 * Infers the type of the MySQL table `rd10439` through [[JdbcInferrer]] and checks that the
 * result is a collection of records with exactly three attributes (`id`, `name`, `salary`).
 *
 * Connection details are taken from the `RAW_MYSQL_TEST_*` environment variables when the suite
 * is instantiated; `sys.env(...)` throws `NoSuchElementException` if one of them is not set.
 */
class RD10439 extends RawTestSuite with SettingsTestContext with StrictLogging {

  // Connection settings for the MySQL test instance, read from the environment.
  val mysqlHostname: String = sys.env("RAW_MYSQL_TEST_HOST")
  val mysqlDb: String = sys.env("RAW_MYSQL_TEST_DB")
  val mysqlUsername: String = sys.env("RAW_MYSQL_TEST_USER")
  val mysqlPassword: String = sys.env("RAW_MYSQL_TEST_PASSWORD")

  test("infer mysql table which is repeated in another database") { _ =>
    // Build the credential, inferrer, client and table location for `rd10439` in `mysqlDb`.
    val credential =
      MySqlCredential(mysqlHostname, None, mysqlDb, Some(mysqlUsername), Some(mysqlPassword))
    val jdbcInferrer = new JdbcInferrer()
    val mysqlClient = new MySqlClient(credential)
    val tableLocation = new MySqlTable(mysqlClient, mysqlDb, "rd10439")

    val inferredType = jdbcInferrer.getTableType(tableLocation)
    logger.info(s"tipe: $inferredType")

    // The table must be described by exactly these three attributes, in this order.
    val expectedAttributes = Vector(
      SourceAttrType("id", SourceIntType(true)),
      SourceAttrType("name", SourceStringType(true)),
      SourceAttrType("salary", SourceFloatType(true))
    )
    val expectedRecord = SourceRecordType(expectedAttributes, nullable = false)
    val expectedType = SourceCollectionType(expectedRecord, nullable = false)

    assert(inferredType == expectedType)
  }
}

0 comments on commit f7a3df5

Please sign in to comment.