init research
This commit is contained in:
+43
@@ -0,0 +1,43 @@
|
||||
// Build script for the Exposed + SQLite example module.
import org.jetbrains.kotlin.gradle.dsl.JvmTarget

plugins {
    application
    kotlin("jvm")

    // uses the 'old' Gradle plugin instead of the compiler plugin for now
    id("org.jetbrains.kotlinx.dataframe")

    // only mandatory if `kotlin.dataframe.add.ksp=false` in gradle.properties
    id("com.google.devtools.ksp")
}

repositories {
    mavenLocal() // in case of local dataframe development
    mavenCentral()
}

dependencies {
    // implementation("org.jetbrains.kotlinx:dataframe:X.Y.Z")
    implementation(project(":"))

    // exposed + sqlite database support
    implementation(libs.sqlite)
    implementation(libs.exposed.core)
    implementation(libs.exposed.kotlin.datetime)
    implementation(libs.exposed.jdbc)
    implementation(libs.exposed.json)
    implementation(libs.exposed.money)
}

// Target JVM 8 bytecode; `-Xjdk-release` keeps the Kotlin compiler from using newer JDK APIs.
kotlin {
    compilerOptions {
        jvmTarget = JvmTarget.JVM_1_8
        freeCompilerArgs.add("-Xjdk-release=8")
    }
}

// Keep any Java sources on the same release level as the Kotlin target above.
tasks.withType<JavaCompile> {
    sourceCompatibility = JavaVersion.VERSION_1_8.toString()
    targetCompatibility = JavaVersion.VERSION_1_8.toString()
    options.release.set(8)
}
|
||||
+107
@@ -0,0 +1,107 @@
|
||||
package org.jetbrains.kotlinx.dataframe.examples.exposed
|
||||
|
||||
import org.jetbrains.exposed.v1.core.BiCompositeColumn
|
||||
import org.jetbrains.exposed.v1.core.Column
|
||||
import org.jetbrains.exposed.v1.core.Expression
|
||||
import org.jetbrains.exposed.v1.core.ExpressionAlias
|
||||
import org.jetbrains.exposed.v1.core.ResultRow
|
||||
import org.jetbrains.exposed.v1.core.Table
|
||||
import org.jetbrains.exposed.v1.jdbc.Query
|
||||
import org.jetbrains.kotlinx.dataframe.AnyFrame
|
||||
import org.jetbrains.kotlinx.dataframe.DataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
|
||||
import org.jetbrains.kotlinx.dataframe.api.convertTo
|
||||
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.codeGen.NameNormalizer
|
||||
import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl
|
||||
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
|
||||
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
|
||||
import kotlin.reflect.KProperty1
|
||||
import kotlin.reflect.full.isSubtypeOf
|
||||
import kotlin.reflect.full.memberProperties
|
||||
import kotlin.reflect.typeOf
|
||||
|
||||
/**
 * Retrieves all columns of any [Iterable][Iterable]`<`[ResultRow][ResultRow]`>`, like [Query][Query],
 * from Exposed row by row and converts the resulting [Map] into a [DataFrame], cast to type [T].
 *
 * Delegates to the untyped overload below and then casts via [convertTo].
 * [T] is expected to be a `@DataSchema` type whose columns match the query's.
 *
 * In notebooks, the untyped version works just as well due to runtime inference :)
 */
inline fun <reified T : Any> Iterable<ResultRow>.convertToDataFrame(): DataFrame<T> =
    convertToDataFrame().convertTo<T>()
|
||||
|
||||
/**
 * Retrieves all columns of an [Iterable][Iterable]`<`[ResultRow][ResultRow]`>` from Exposed, like [Query][Query],
 * row by row and converts the resulting [Map] of lists into a [DataFrame] by calling
 * [Map.toDataFrame].
 *
 * Column order follows the order in which expressions are first encountered;
 * column names come from [readableName].
 */
@JvmName("convertToAnyFrame")
fun Iterable<ResultRow>.convertToDataFrame(): AnyFrame {
    // accumulate values per column name; insertion order is preserved
    val columnValues = mutableMapOf<String, MutableList<Any?>>()
    forEach { row ->
        row.fieldIndex.keys.forEach { expression ->
            val values = columnValues.getOrPut(expression.readableName) { mutableListOf() }
            values.add(row[expression])
        }
    }
    return columnValues.toDataFrame()
}
|
||||
|
||||
/**
 * Retrieves a simple column name from [this] [Expression].
 *
 * Might need to be expanded with multiple types of [Expression].
 * Falls back to [toString] for expression types not handled explicitly,
 * which may produce non-unique or verbose names.
 */
val Expression<*>.readableName: String
    get() = when (this) {
        is Column<*> -> name
        is ExpressionAlias<*> -> alias
        // composite columns are flattened by joining the names of their real columns
        is BiCompositeColumn<*, *, *> -> getRealColumns().joinToString("_") { it.readableName }
        else -> toString()
    }
|
||||
|
||||
/**
 * Creates a [DataFrameSchema] from the declared [Table] instance.
 *
 * This is not needed for conversion, but it can be useful to create a DataFrame [@DataSchema][DataSchema] instance.
 *
 * NOTE(review): [DataFrameSchemaImpl] comes from an `impl` package — verify it is intended
 * for public consumption; it may change between DataFrame versions.
 *
 * @param columnNameToAccessor Optional [MutableMap] which will be filled with entries mapping
 *   the SQL column name to the accessor name from the [Table].
 *   This can be used to define a [NameNormalizer] later.
 * @see toDataFrameSchemaWithNameNormalizer
 */
@Suppress("UNCHECKED_CAST")
fun Table.toDataFrameSchema(columnNameToAccessor: MutableMap<String, String> = mutableMapOf()): DataFrameSchema {
    // we use reflection to go over all `Column<*>` properties in the Table object
    val columns = this::class.memberProperties
        .filter { it.returnType.isSubtypeOf(typeOf<Column<*>>()) }
        .associate { prop ->
            prop as KProperty1<Table, Column<*>>

            // retrieve the SQL column name
            val columnName = prop.get(this).name
            // store the SQL column name together with the accessor name in the map
            columnNameToAccessor[columnName] = prop.name

            // get the column type from `val a: Column<Type>`
            // the `!!` is safe for star projections only if columns are always declared with
            // a concrete type argument, as they are in this project's tables
            val type = prop.returnType.arguments.first().type!!

            // and we add the name and column schema type to the `columns` map :)
            columnName to ColumnSchema.Value(type)
        }
    return DataFrameSchemaImpl(columns)
}
|
||||
|
||||
/**
 * Creates a [DataFrameSchema] from the declared [Table] instance with a [NameNormalizer] to
 * convert the SQL column names to the corresponding Kotlin property names.
 *
 * This is not needed for conversion, but it can be useful to create a DataFrame [@DataSchema][DataSchema] instance.
 *
 * @see toDataFrameSchema
 */
fun Table.toDataFrameSchemaWithNameNormalizer(): Pair<DataFrameSchema, NameNormalizer> {
    val columnNameToAccessor = mutableMapOf<String, String>()
    // Bug fix: the map must be passed to toDataFrameSchema() so it is actually filled.
    // Previously it was created but never handed over, leaving it empty and making the
    // NameNormalizer below a no-op identity mapping.
    val schema = toDataFrameSchema(columnNameToAccessor)
    return Pair(schema, NameNormalizer { columnNameToAccessor[it] ?: it })
}
|
||||
+96
@@ -0,0 +1,96 @@
|
||||
package org.jetbrains.kotlinx.dataframe.examples.exposed
|
||||
|
||||
import org.jetbrains.exposed.v1.core.Column
|
||||
import org.jetbrains.exposed.v1.core.SortOrder
|
||||
import org.jetbrains.exposed.v1.core.count
|
||||
import org.jetbrains.exposed.v1.jdbc.Database
|
||||
import org.jetbrains.exposed.v1.jdbc.SchemaUtils
|
||||
import org.jetbrains.exposed.v1.jdbc.batchInsert
|
||||
import org.jetbrains.exposed.v1.jdbc.deleteAll
|
||||
import org.jetbrains.exposed.v1.jdbc.select
|
||||
import org.jetbrains.exposed.v1.jdbc.selectAll
|
||||
import org.jetbrains.exposed.v1.jdbc.transactions.transaction
|
||||
import org.jetbrains.kotlinx.dataframe.api.asSequence
|
||||
import org.jetbrains.kotlinx.dataframe.api.count
|
||||
import org.jetbrains.kotlinx.dataframe.api.describe
|
||||
import org.jetbrains.kotlinx.dataframe.api.groupBy
|
||||
import org.jetbrains.kotlinx.dataframe.api.print
|
||||
import org.jetbrains.kotlinx.dataframe.api.sortByDesc
|
||||
import org.jetbrains.kotlinx.dataframe.size
|
||||
import java.io.File
|
||||
|
||||
/**
 * Describes a simple bridge between [Exposed](https://www.jetbrains.com/exposed/) and DataFrame!
 *
 * Flow: connect to a bundled SQLite database, query it with Exposed, convert the result
 * to a typed DataFrame, analyze it, and finally write the rows back with a batch insert.
 */
fun main() {
    // defining where to find our SQLite database for Exposed
    val resourceDb = "chinook.db"
    // `!!` is justified: the database file is bundled in this module's resources,
    // so a missing resource is a packaging error worth failing fast on
    val dbPath = File(object {}.javaClass.classLoader.getResource(resourceDb)!!.toURI()).absolutePath
    val db = Database.connect(url = "jdbc:sqlite:$dbPath", driver = "org.sqlite.JDBC")

    // let's read the database!
    val df = transaction(db) {
        // addLogger(StdOutSqlLogger) // enable if you want to see verbose logs

        // tables in Exposed need to be defined, see tables.kt
        SchemaUtils.create(Customers, Artists, Albums)

        println()

        // In Exposed, we can write queries like this.
        // Here, we count per country how many customers there are and print the results:
        Customers
            .select(Customers.country, Customers.customerId.count())
            .groupBy(Customers.country)
            .orderBy(Customers.customerId.count() to SortOrder.DESC)
            .forEach {
                println("${it[Customers.country]}: ${it[Customers.customerId.count()]} customers")
            }

        println()

        // Perform the specific query you want to read into the DataFrame.
        // Note: DataFrames are in-memory structures, so don't make it too large if you don't have the RAM ;)
        val query = Customers.selectAll() // .where { Customers.company.isNotNull() }

        println()

        // read and convert the query to a typed DataFrame
        // see compatibilityLayer.kt for how we created convertToDataFrame<>()
        // and see tables.kt for how we created DfCustomers!
        query.convertToDataFrame<DfCustomers>()
    }

    println(df.size())

    // now we have a DataFrame, we can perform DataFrame operations,
    // like doing the same operation as we did in Exposed above
    df.groupBy { country }.count()
        .sortByDesc { "count"<Int>() }
        .print(columnTypes = true, borders = true)

    // or just general statistics
    df.describe()
        .print(columnTypes = true, borders = true)

    // or make plots using Kandy! It's all up to you

    // writing a DataFrame back into an SQL database with Exposed can also be done easily!
    transaction(db) {
        // addLogger(StdOutSqlLogger) // enable if you want to see verbose logs

        // first delete the original contents
        Customers.deleteAll()

        println()

        // batch-insert our dataframe back into the SQL database as a sequence of rows
        Customers.batchInsert(df.asSequence()) { dfRow ->
            // we simply go over each value in the row and put it in the right place in the Exposed statement
            // the cast is unchecked but safe: DataFrame column names mirror the table's SQL names
            for (column in Customers.columns) {
                @Suppress("UNCHECKED_CAST")
                this[column as Column<Any?>] = dfRow[column.name]
            }
        }
    }
}
|
||||
+97
@@ -0,0 +1,97 @@
|
||||
package org.jetbrains.kotlinx.dataframe.examples.exposed
|
||||
|
||||
import org.jetbrains.exposed.v1.core.Column
|
||||
import org.jetbrains.exposed.v1.core.Table
|
||||
import org.jetbrains.kotlinx.dataframe.annotations.ColumnName
|
||||
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
|
||||
import org.jetbrains.kotlinx.dataframe.api.generateDataClasses
|
||||
import org.jetbrains.kotlinx.dataframe.api.print
|
||||
|
||||
// Exposed table definition for the `Albums` table of the Chinook sample database.
object Albums : Table() {
    val albumId: Column<Int> = integer("AlbumId").autoIncrement()
    val title: Column<String> = varchar("Title", 160)
    val artistId: Column<Int> = integer("ArtistId") // references Artists.artistId (no FK declared here)

    override val primaryKey = PrimaryKey(albumId)
}
|
||||
|
||||
// Exposed table definition for the `Artists` table of the Chinook sample database.
object Artists : Table() {
    val artistId: Column<Int> = integer("ArtistId").autoIncrement()
    val name: Column<String> = varchar("Name", 120)

    override val primaryKey = PrimaryKey(artistId)
}
|
||||
|
||||
// Exposed table definition for the `Customers` table of the Chinook sample database.
// Nullability mirrors the database schema: only names, email, and the id are required.
object Customers : Table() {
    val customerId: Column<Int> = integer("CustomerId").autoIncrement()
    val firstName: Column<String> = varchar("FirstName", 40)
    val lastName: Column<String> = varchar("LastName", 20)
    val company: Column<String?> = varchar("Company", 80).nullable()
    val address: Column<String?> = varchar("Address", 70).nullable()
    val city: Column<String?> = varchar("City", 40).nullable()
    val state: Column<String?> = varchar("State", 40).nullable()
    val country: Column<String?> = varchar("Country", 40).nullable()
    val postalCode: Column<String?> = varchar("PostalCode", 10).nullable()
    val phone: Column<String?> = varchar("Phone", 24).nullable()
    val fax: Column<String?> = varchar("Fax", 24).nullable()
    val email: Column<String> = varchar("Email", 60)
    val supportRepId: Column<Int?> = integer("SupportRepId").nullable()

    override val primaryKey = PrimaryKey(customerId)
}
|
||||
|
||||
/**
 * Exposed requires you to provide [Table] instances to
 * provide type-safe access to your columns and data.
 *
 * While DataFrame can infer types at runtime, which is enough for Kotlin Notebook,
 * to get type safe access at compile time, we need to define a [@DataSchema][DataSchema].
 *
 * This is what we created the [toDataFrameSchema] function for!
 *
 * Running this main prints a generated `@DataSchema` data class (like [DfCustomers] below)
 * that can be copy-pasted into the code.
 */
fun main() {
    val (schema, nameNormalizer) = Customers.toDataFrameSchemaWithNameNormalizer()

    // checking whether the schema is converted correctly.
    // schema.print()

    // printing a @DataSchema data class to copy-paste into the code.
    // we use a NameNormalizer to let DataFrame generate the same accessors as in the Table
    // while keeping the correct column names
    schema.generateDataClasses(
        markerName = "DfCustomers",
        nameNormalizer = nameNormalizer,
    ).print()
}
|
||||
|
||||
// created by Customers.toDataFrameSchema()
// The same can be done for the other tables
// @ColumnName keeps the original SQL column names while the properties use Kotlin naming.
@DataSchema
data class DfCustomers(
    @ColumnName("Address")
    val address: String?,
    @ColumnName("City")
    val city: String?,
    @ColumnName("Company")
    val company: String?,
    @ColumnName("Country")
    val country: String?,
    @ColumnName("CustomerId")
    val customerId: Int,
    @ColumnName("Email")
    val email: String,
    @ColumnName("Fax")
    val fax: String?,
    @ColumnName("FirstName")
    val firstName: String,
    @ColumnName("LastName")
    val lastName: String,
    @ColumnName("Phone")
    val phone: String?,
    @ColumnName("PostalCode")
    val postalCode: String?,
    @ColumnName("State")
    val state: String?,
    @ColumnName("SupportRepId")
    val supportRepId: Int?,
)
|
||||
Vendored
BIN
Binary file not shown.
+43
@@ -0,0 +1,43 @@
|
||||
// Build script for the Hibernate + H2 + HikariCP example module.
import org.jetbrains.kotlin.gradle.dsl.JvmTarget

plugins {
    application
    kotlin("jvm")

    // uses the 'old' Gradle plugin instead of the compiler plugin for now
    id("org.jetbrains.kotlinx.dataframe")

    // only mandatory if `kotlin.dataframe.add.ksp=false` in gradle.properties
    id("com.google.devtools.ksp")
}

repositories {
    mavenLocal() // in case of local dataframe development
    mavenCentral()
}

dependencies {
    // implementation("org.jetbrains.kotlinx:dataframe:X.Y.Z")
    implementation(project(":"))

    // Hibernate + H2 + HikariCP (for Hibernate example)
    implementation(libs.hibernate.core)
    implementation(libs.hibernate.hikaricp)
    implementation(libs.hikaricp)

    implementation(libs.h2db)
    // NOTE(review): alias `sl4jsimple` looks like a typo for "slf4j-simple" — confirm against the version catalog
    implementation(libs.sl4jsimple)
}

// JVM 11 here (unlike the other example modules' JVM 8) because modern Hibernate requires Java 11+.
kotlin {
    compilerOptions {
        jvmTarget = JvmTarget.JVM_11
        freeCompilerArgs.add("-Xjdk-release=11")
    }
}

// Keep any Java sources on the same release level as the Kotlin target above.
tasks.withType<JavaCompile> {
    sourceCompatibility = JavaVersion.VERSION_11.toString()
    targetCompatibility = JavaVersion.VERSION_11.toString()
    options.release.set(11)
}
|
||||
+100
@@ -0,0 +1,100 @@
|
||||
package org.jetbrains.kotlinx.dataframe.examples.hibernate
|
||||
|
||||
import jakarta.persistence.Column
|
||||
import jakarta.persistence.Entity
|
||||
import jakarta.persistence.GeneratedValue
|
||||
import jakarta.persistence.GenerationType
|
||||
import jakarta.persistence.Id
|
||||
import jakarta.persistence.Table
|
||||
import org.jetbrains.kotlinx.dataframe.annotations.ColumnName
|
||||
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
|
||||
|
||||
// JPA entity mapping for the `Albums` table. Properties are mutable `var`s with defaults
// because Hibernate instantiates entities reflectively and populates them afterwards.
@Entity
@Table(name = "Albums")
class AlbumsEntity(
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    @Column(name = "AlbumId")
    var albumId: Int? = null, // null until the database assigns an id

    @Column(name = "Title", length = 160, nullable = false)
    var title: String = "",

    @Column(name = "ArtistId", nullable = false)
    var artistId: Int = 0, // plain FK value; no @ManyToOne relation is mapped
)
|
||||
|
||||
// JPA entity mapping for the `Artists` table.
@Entity
@Table(name = "Artists")
class ArtistsEntity(
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    @Column(name = "ArtistId")
    var artistId: Int? = null, // null until the database assigns an id

    @Column(name = "Name", length = 120, nullable = false)
    var name: String = "",
)
|
||||
|
||||
// JPA entity mapping for the `Customers` table.
// Column names/lengths mirror the Exposed `Customers` table of the other example module.
@Entity
@Table(name = "Customers")
class CustomersEntity(
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    @Column(name = "CustomerId")
    var customerId: Int? = null, // null until the database assigns an id

    @Column(name = "FirstName", length = 40, nullable = false)
    var firstName: String = "",

    @Column(name = "LastName", length = 20, nullable = false)
    var lastName: String = "",

    @Column(name = "Company", length = 80)
    var company: String? = null,

    @Column(name = "Address", length = 70)
    var address: String? = null,

    @Column(name = "City", length = 40)
    var city: String? = null,

    @Column(name = "State", length = 40)
    var state: String? = null,

    @Column(name = "Country", length = 40)
    var country: String? = null,

    @Column(name = "PostalCode", length = 10)
    var postalCode: String? = null,

    @Column(name = "Phone", length = 24)
    var phone: String? = null,

    @Column(name = "Fax", length = 24)
    var fax: String? = null,

    @Column(name = "Email", length = 60, nullable = false)
    var email: String = "",

    @Column(name = "SupportRepId")
    var supportRepId: Int? = null,
)
|
||||
|
||||
// DataFrame schema to get typed accessors similar to Exposed example
// @ColumnName keeps the original SQL column names while the properties use Kotlin naming.
@DataSchema
data class DfCustomers(
    @ColumnName("Address") val address: String?,
    @ColumnName("City") val city: String?,
    @ColumnName("Company") val company: String?,
    @ColumnName("Country") val country: String?,
    @ColumnName("CustomerId") val customerId: Int,
    @ColumnName("Email") val email: String,
    @ColumnName("Fax") val fax: String?,
    @ColumnName("FirstName") val firstName: String,
    @ColumnName("LastName") val lastName: String,
    @ColumnName("Phone") val phone: String?,
    @ColumnName("PostalCode") val postalCode: String?,
    @ColumnName("State") val state: String?,
    @ColumnName("SupportRepId") val supportRepId: Int?,
)
|
||||
+251
@@ -0,0 +1,251 @@
|
||||
package org.jetbrains.kotlinx.dataframe.examples.hibernate
|
||||
|
||||
import jakarta.persistence.Tuple
|
||||
import jakarta.persistence.criteria.CriteriaBuilder
|
||||
import jakarta.persistence.criteria.CriteriaDelete
|
||||
import jakarta.persistence.criteria.CriteriaQuery
|
||||
import jakarta.persistence.criteria.Expression
|
||||
import jakarta.persistence.criteria.Root
|
||||
import org.hibernate.FlushMode
|
||||
import org.hibernate.SessionFactory
|
||||
import org.hibernate.cfg.Configuration
|
||||
import org.jetbrains.kotlinx.dataframe.DataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.DataRow
|
||||
import org.jetbrains.kotlinx.dataframe.api.asSequence
|
||||
import org.jetbrains.kotlinx.dataframe.api.count
|
||||
import org.jetbrains.kotlinx.dataframe.api.describe
|
||||
import org.jetbrains.kotlinx.dataframe.api.groupBy
|
||||
import org.jetbrains.kotlinx.dataframe.api.print
|
||||
import org.jetbrains.kotlinx.dataframe.api.sortByDesc
|
||||
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.size
|
||||
|
||||
/**
 * Example showing Kotlin DataFrame with Hibernate ORM + H2 in-memory DB.
 * Mirrors logic from the Exposed example: load data, convert to DataFrame, group/describe, write back.
 */
fun main() {
    // build from resources/hibernate/hibernate.cfg.xml; schema is created via hbm2ddl
    val sessionFactory: SessionFactory = buildSessionFactory()

    // seed the in-memory H2 database with a few rows
    sessionFactory.insertSampleData()

    // load everything into an in-memory typed DataFrame
    val df = sessionFactory.loadCustomersAsDataFrame()

    // Pure Hibernate + Criteria API approach for counting customers per country
    println("=== Hibernate + Criteria API Approach ===")
    sessionFactory.countCustomersPerCountryWithHibernate()

    println("\n=== DataFrame Approach ===")
    df.analyzeAndPrintResults()

    // round-trip: write the DataFrame back into the Customers table
    sessionFactory.replaceCustomersFromDataFrame(df)

    sessionFactory.close()
}
|
||||
|
||||
/**
 * Seeds the database with sample artists, albums, and customers inside one transaction.
 * The customers are the data analyzed later in the example.
 */
private fun SessionFactory.insertSampleData() {
    withTransaction { session ->
        // a few artists and albums (minimal, not used further; just demo schema)
        val artist1 = ArtistsEntity(name = "AC/DC")
        val artist2 = ArtistsEntity(name = "Queen")
        session.persist(artist1)
        session.persist(artist2)
        // flush so the IDENTITY-generated artist ids are available for the albums below
        session.flush()

        // `!!` is safe here: flush() above guarantees the generated ids are populated
        session.persist(AlbumsEntity(title = "High Voltage", artistId = artist1.artistId!!))
        session.persist(AlbumsEntity(title = "Back in Black", artistId = artist1.artistId!!))
        session.persist(AlbumsEntity(title = "A Night at the Opera", artistId = artist2.artistId!!))
        // customers we'll analyze using DataFrame
        session.persist(
            CustomersEntity(
                firstName = "John",
                lastName = "Doe",
                email = "john.doe@example.com",
                country = "USA",
            ),
        )
        session.persist(
            CustomersEntity(
                firstName = "Jane",
                lastName = "Smith",
                email = "jane.smith@example.com",
                country = "USA",
            ),
        )
        session.persist(
            CustomersEntity(
                firstName = "Alice",
                lastName = "Wang",
                email = "alice.wang@example.com",
                country = "Canada",
            ),
        )
    }
}
|
||||
|
||||
/**
 * Loads every [CustomersEntity] via a Criteria query and maps each entity to a
 * [DfCustomers] row, producing a typed in-memory [DataFrame].
 */
private fun SessionFactory.loadCustomersAsDataFrame(): DataFrame<DfCustomers> {
    return withReadOnlyTransaction { session ->
        // SELECT * FROM Customers, expressed with the type-safe Criteria API
        val criteriaBuilder: CriteriaBuilder = session.criteriaBuilder
        val criteriaQuery: CriteriaQuery<CustomersEntity> = criteriaBuilder.createQuery(CustomersEntity::class.java)
        val root: Root<CustomersEntity> = criteriaQuery.from(CustomersEntity::class.java)
        criteriaQuery.select(root)

        session.createQuery(criteriaQuery)
            .resultList
            .map { c ->
                // entity -> data-class row; -1 stands in for a missing (unassigned) id
                DfCustomers(
                    address = c.address,
                    city = c.city,
                    company = c.company,
                    country = c.country,
                    customerId = c.customerId ?: -1,
                    email = c.email,
                    fax = c.fax,
                    firstName = c.firstName,
                    lastName = c.lastName,
                    phone = c.phone,
                    postalCode = c.postalCode,
                    state = c.state,
                    supportRepId = c.supportRepId,
                )
            }
            .toDataFrame()
    }
}
|
||||
|
||||
/**
 * DTO used for aggregation projection.
 * Instantiated by Hibernate via `cb.construct`, so the constructor parameter order
 * must match the projection order (country, count).
 */
private data class CountryCountDto(
    val country: String,
    val customerCount: Long,
)
|
||||
|
||||
/**
 * **Hibernate + Criteria API:**
 * - ✅ Database-level aggregation (efficient)
 * - ✅ Type-safe queries
 * - ❌ Verbose syntax
 * - ❌ Limited to SQL-like operations
 *
 * Prints "country: N customers" lines, ordered by count descending.
 */
private fun SessionFactory.countCustomersPerCountryWithHibernate() {
    withReadOnlyTransaction { session ->
        val cb = session.criteriaBuilder
        val cq: CriteriaQuery<CountryCountDto> = cb.createQuery(CountryCountDto::class.java)
        val root: Root<CustomersEntity> = cq.from(CustomersEntity::class.java)

        val countryPath = root.get<String>("country")
        // NOTE(review): the entity declares customerId as Int?, not Long — the generic here
        // is unchecked at runtime, but confirm the mismatch is intentional
        val idPath = root.get<Long>("customerId")

        // COUNT(customerId); cb.count always yields a Long expression
        val countExpr = cb.count(idPath)

        cq.select(
            cb.construct(
                CountryCountDto::class.java,
                countryPath, // country
                countExpr, // customerCount
            ),
        )
        cq.groupBy(countryPath)
        cq.orderBy(cb.desc(countExpr))

        val results = session.createQuery(cq).resultList
        results.forEach { dto ->
            println("${dto.country}: ${dto.customerCount} customers")
        }
    }
}
|
||||
|
||||
/**
 * **DataFrame approach:**
 * - ✅ Rich analytical operations
 * - ✅ Fluent, readable API
 * - ✅ Flexible data transformations
 * - ❌ In-memory processing (less efficient for large datasets)
 */
private fun DataFrame<DfCustomers>.analyzeAndPrintResults() {
    println(size())

    // same operation as Exposed example: customers per country
    groupBy { country }.count()
        .sortByDesc { "count"<Int>() }
        .print(columnTypes = true, borders = true)

    // general statistics
    describe()
        .print(columnTypes = true, borders = true)
}
|
||||
|
||||
/**
 * Replaces the entire contents of the Customers table with the rows of [df].
 *
 * The delete and the re-insert run in a SINGLE transaction so the operation is atomic:
 * previously they ran in two separate transactions, so a failure during insert would
 * have left the table emptied with the original data already committed away.
 */
private fun SessionFactory.replaceCustomersFromDataFrame(df: DataFrame<DfCustomers>) {
    withTransaction { session ->
        // bulk DELETE FROM Customers via the Criteria API
        val criteriaBuilder: CriteriaBuilder = session.criteriaBuilder
        val criteriaDelete: CriteriaDelete<CustomersEntity> =
            criteriaBuilder.createCriteriaDelete(CustomersEntity::class.java)
        criteriaDelete.from(CustomersEntity::class.java)
        session.createMutationQuery(criteriaDelete).executeUpdate()

        // re-insert every DataFrame row as a fresh entity (ids are re-generated)
        df.asSequence().forEach { row ->
            session.persist(row.toCustomersEntity())
        }
    }
}
|
||||
|
||||
/**
 * Maps one DataFrame row back to a [CustomersEntity] for persisting.
 * The id is left null so the database assigns a fresh IDENTITY value.
 */
private fun DataRow<DfCustomers>.toCustomersEntity(): CustomersEntity {
    return CustomersEntity(
        customerId = null, // let DB generate
        firstName = this.firstName,
        lastName = this.lastName,
        company = this.company,
        address = this.address,
        city = this.city,
        state = this.state,
        country = this.country,
        postalCode = this.postalCode,
        phone = this.phone,
        fax = this.fax,
        email = this.email,
        supportRepId = this.supportRepId,
    )
}
|
||||
|
||||
/** Opens a session, runs [block], and always closes the session (Session is AutoCloseable). */
private inline fun <T> SessionFactory.withSession(block: (session: org.hibernate.Session) -> T): T {
    return openSession().use(block)
}
|
||||
|
||||
/**
 * Runs [block] inside a session-scoped transaction: commits on success,
 * rolls back and rethrows on any exception.
 */
private inline fun SessionFactory.withTransaction(block: (session: org.hibernate.Session) -> Unit) {
    withSession { session ->
        session.beginTransaction()
        try {
            block(session)
            session.transaction.commit()
        } catch (e: Exception) {
            // undo partial work, then propagate so callers see the failure
            session.transaction.rollback()
            throw e
        }
    }
}
|
||||
|
||||
/**
 * Read-only transaction helper for SELECT queries to minimize overhead.
 *
 * Marks the session read-only (skips dirty checking) and disables automatic flushing,
 * then commits on success or rolls back and rethrows on failure, returning [block]'s result.
 */
private inline fun <T> SessionFactory.withReadOnlyTransaction(block: (session: org.hibernate.Session) -> T): T {
    return withSession { session ->
        session.beginTransaction()
        // Minimize overhead for read operations
        session.isDefaultReadOnly = true
        session.hibernateFlushMode = FlushMode.MANUAL
        try {
            val result = block(session)
            session.transaction.commit()
            result
        } catch (e: Exception) {
            session.transaction.rollback()
            throw e
        }
    }
}
|
||||
|
||||
|
||||
/** Builds the [SessionFactory] from the XML config; the caller owns (and must close) it. */
private fun buildSessionFactory(): SessionFactory {
    // Load configuration from resources/hibernate/hibernate.cfg.xml
    return Configuration().configure("hibernate/hibernate.cfg.xml").buildSessionFactory()
}
|
||||
+32
@@ -0,0 +1,32 @@
|
||||
<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE hibernate-configuration PUBLIC
        "-//Hibernate/Hibernate Configuration DTD 5.3//EN"
        "http://hibernate.org/dtd/hibernate-configuration-5.3.dtd">
<!-- Hibernate configuration for the example: H2 in-memory DB pooled via HikariCP,
     with schema auto-created/dropped per run and verbose SQL logging for learning. -->
<hibernate-configuration>
    <session-factory>
        <!-- H2 in-memory; DB_CLOSE_DELAY=-1 keeps the DB alive until the JVM exits -->
        <property name="hibernate.connection.driver_class">org.h2.Driver</property>
        <property name="hibernate.connection.url">jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1</property>
        <property name="hibernate.connection.username">sa</property>
        <property name="hibernate.connection.password"></property>

        <!-- Connection pool: HikariCP via Hibernate integration -->
        <property name="hibernate.connection.provider_class">org.hibernate.hikaricp.internal.HikariCPConnectionProvider</property>
        <property name="hibernate.hikari.maximumPoolSize">5</property>

        <!-- Hibernate Dialect -->
        <property name="hibernate.dialect">org.hibernate.dialect.H2Dialect</property>

        <!-- Automatic schema generation (create-drop: rebuilt each run; example-only setting) -->
        <property name="hibernate.hbm2ddl.auto">create-drop</property>

        <!-- Logging -->
        <property name="hibernate.show_sql">true</property>
        <property name="hibernate.format_sql">true</property>

        <!-- Mappings -->
        <mapping class="org.jetbrains.kotlinx.dataframe.examples.hibernate.CustomersEntity"/>
        <mapping class="org.jetbrains.kotlinx.dataframe.examples.hibernate.ArtistsEntity"/>
        <mapping class="org.jetbrains.kotlinx.dataframe.examples.hibernate.AlbumsEntity"/>
    </session-factory>
</hibernate-configuration>
|
||||
+39
@@ -0,0 +1,39 @@
|
||||
// Build script for the multik interop example module.
import org.jetbrains.kotlin.gradle.dsl.JvmTarget

plugins {
    application
    kotlin("jvm")

    // uses the 'old' Gradle plugin instead of the compiler plugin for now
    id("org.jetbrains.kotlinx.dataframe")

    // only mandatory if `kotlin.dataframe.add.ksp=false` in gradle.properties
    id("com.google.devtools.ksp")
}

repositories {
    mavenLocal() // in case of local dataframe development
    mavenCentral()
}

dependencies {
    // implementation("org.jetbrains.kotlinx:dataframe:X.Y.Z")
    implementation(project(":"))

    // multik support
    implementation(libs.multik.core)
    implementation(libs.multik.default)
}

// Target JVM 8 bytecode; `-Xjdk-release` keeps the Kotlin compiler from using newer JDK APIs.
kotlin {
    compilerOptions {
        jvmTarget = JvmTarget.JVM_1_8
        freeCompilerArgs.add("-Xjdk-release=8")
    }
}

// Keep any Java sources on the same release level as the Kotlin target above.
tasks.withType<JavaCompile> {
    sourceCompatibility = JavaVersion.VERSION_1_8.toString()
    targetCompatibility = JavaVersion.VERSION_1_8.toString()
    options.release.set(8)
}
|
||||
+374
@@ -0,0 +1,374 @@
|
||||
@file:OptIn(ExperimentalTypeInference::class)
|
||||
|
||||
package org.jetbrains.kotlinx.dataframe.examples.multik
|
||||
|
||||
import org.jetbrains.kotlinx.dataframe.AnyFrame
|
||||
import org.jetbrains.kotlinx.dataframe.ColumnSelector
|
||||
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
|
||||
import org.jetbrains.kotlinx.dataframe.DataColumn
|
||||
import org.jetbrains.kotlinx.dataframe.DataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.api.ValueProperty
|
||||
import org.jetbrains.kotlinx.dataframe.api.cast
|
||||
import org.jetbrains.kotlinx.dataframe.api.colsOf
|
||||
import org.jetbrains.kotlinx.dataframe.api.column
|
||||
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
|
||||
import org.jetbrains.kotlinx.dataframe.api.getColumn
|
||||
import org.jetbrains.kotlinx.dataframe.api.getColumns
|
||||
import org.jetbrains.kotlinx.dataframe.api.map
|
||||
import org.jetbrains.kotlinx.dataframe.api.named
|
||||
import org.jetbrains.kotlinx.dataframe.api.toColumn
|
||||
import org.jetbrains.kotlinx.dataframe.api.toColumnGroup
|
||||
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.columns.BaseColumn
|
||||
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
|
||||
import org.jetbrains.kotlinx.multik.api.mk
|
||||
import org.jetbrains.kotlinx.multik.api.ndarray
|
||||
import org.jetbrains.kotlinx.multik.ndarray.complex.Complex
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.D1Array
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.D2Array
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.D3Array
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.NDArray
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.get
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.toList
|
||||
import org.jetbrains.kotlinx.multik.ndarray.operations.toListD2
|
||||
import kotlin.experimental.ExperimentalTypeInference
|
||||
import kotlin.reflect.KClass
|
||||
import kotlin.reflect.KType
|
||||
import kotlin.reflect.full.isSubtypeOf
|
||||
import kotlin.reflect.typeOf
|
||||
|
||||
// region 1D
|
||||
|
||||
/** Converts a one-dimensional array ([D1Array]) to a [DataColumn] with optional [name]. */
inline fun <reified N> D1Array<N>.convertToColumn(name: String = ""): DataColumn<N> =
    // a 1D array maps directly onto a typed list; the reified N spares DataFrame any type inference
    column<N>(this.toList()) named name
|
||||
|
||||
/**
 * Converts a one-dimensional array ([D1Array]) of type [N] into a DataFrame.
 * The resulting DataFrame contains a single column named "value", where each element of the array becomes a row in the DataFrame.
 *
 * @return a DataFrame where each element of the source array is represented as a row in a column named "value" under the schema [ValueProperty].
 */
@JvmName("convert1dArrayToDataFrame")
inline fun <reified N> D1Array<N>.convertToDataFrame(): DataFrame<ValueProperty<N>> =
    // naming the single column "value" lets us cast the frame to the ValueProperty schema
    dataFrameOf(convertToColumn(ValueProperty<*>::value.name))
        .cast<ValueProperty<N>>()
|
||||
|
||||
/** Converts a [DataColumn] to a one-dimensional array ([D1Array]). */
@JvmName("convertNumberColumnToMultik")
inline fun <reified N> DataColumn<N>.convertToMultik(): D1Array<N> where N : Number, N : Comparable<N> =
    // a typed list is all multik needs to build a 1D ndarray
    mk.ndarray(this.toList())
|
||||
|
||||
/** Converts a [DataColumn] to a one-dimensional array ([D1Array]). */
@JvmName("convertComplexColumnToMultik")
inline fun <reified N : Complex> DataColumn<N>.convertToMultik(): D1Array<N> =
    // a typed list is all multik needs to build a 1D ndarray
    mk.ndarray(this.toList())
|
||||
|
||||
/** Converts a [DataColumn] selected by [column] to a one-dimensional array ([D1Array]). */
@JvmName("convertNumberColumnFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N> DataFrame<T>.convertToMultik(
    crossinline column: ColumnSelector<T, N>,
): D1Array<N>
    where N : Number, N : Comparable<N> =
    // resolve the selector against this frame, then reuse the DataColumn overload
    getColumn { column(it) }.convertToMultik()
|
||||
|
||||
/** Converts a [DataColumn] selected by [column] to a one-dimensional array ([D1Array]). */
@JvmName("convertComplexColumnFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N : Complex> DataFrame<T>.convertToMultik(crossinline column: ColumnSelector<T, N>): D1Array<N> =
    // resolve the selector against this frame, then reuse the DataColumn overload
    getColumn { column(it) }.convertToMultik()
|
||||
|
||||
// endregion
|
||||
|
||||
// region 2D
|
||||
|
||||
/**
 * Converts a two-dimensional array ([D2Array]) to a DataFrame.
 * It will contain `shape[0]` rows and `shape[1]` columns.
 *
 * Column names can be specified using the [columnNameGenerator] lambda.
 *
 * The conversion enforces that `multikArray[x][y] == dataframe[x][y]`
 */
@JvmName("convert2dArrayToDataFrame")
inline fun <reified N> D2Array<N>.convertToDataFrame(columnNameGenerator: (Int) -> String = { "col$it" }): AnyFrame =
    List(shape[1]) { colIndex ->
        // slice out column colIndex across every row; reified N keeps the column typed without inference
        this[0..<shape[0], colIndex]
            .toList()
            .toColumn<N>(name = columnNameGenerator(colIndex))
    }.toDataFrame() // one DataColumn per array column -> a DataFrame
|
||||
|
||||
/**
 * Converts a [DataFrame] to a two-dimensional array ([D2Array]).
 * You'll need to specify which columns to convert using the [columns] selector.
 *
 * All columns need to be of the same type. If no columns are supplied, the function
 * will only succeed if all columns are of the same type.
 *
 * @see convertToMultikOf
 */
@JvmName("convertNumberColumnsFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N> DataFrame<T>.convertToMultik(
    crossinline columns: ColumnsSelector<T, N>,
): D2Array<N>
    where N : Number, N : Comparable<N> =
    // resolve the selector against this frame, then reuse the List<DataColumn> overload
    getColumns { columns(it) }.convertToMultik()
|
||||
|
||||
/**
 * Converts a [DataFrame] to a two-dimensional array ([D2Array]).
 * You'll need to specify which columns to convert using the [columns] selector.
 *
 * All columns need to be of the same type. If no columns are supplied, the function
 * will only succeed if all columns are of the same type.
 *
 * @see convertToMultikOf
 */
@JvmName("convertComplexColumnsFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N : Complex> DataFrame<T>.convertToMultik(
    crossinline columns: ColumnsSelector<T, N>,
): D2Array<N> =
    // resolve the selector against this frame, then reuse the List<DataColumn> overload
    getColumns { columns(it) }.convertToMultik()
|
||||
|
||||
/**
 * Converts a [DataFrame] to a two-dimensional array ([D2Array]) by guessing the element type:
 * it succeeds only when all columns in [this] share a single supported type
 * ([Complex] or one of the primitive number types).
 *
 * @throws IllegalStateException when the columns have mixed types, or a single type
 *   that multik cannot store.
 * @see convertToMultikOf
 */
@JvmName("convertToMultikGuess")
fun AnyFrame.convertToMultik(): D2Array<*> {
    val columnTypes = this.columnTypes().distinct()
    // multik arrays are homogeneous, so all columns must share exactly one type
    val type = columnTypes.singleOrNull() ?: error("found multiple column types: $columnTypes")
    return when {
        type == typeOf<Complex>() -> convertToMultik { colsOf<Complex>() }
        type.isSubtypeOf(typeOf<Byte>()) -> convertToMultik { colsOf<Byte>() }
        type.isSubtypeOf(typeOf<Short>()) -> convertToMultik { colsOf<Short>() }
        type.isSubtypeOf(typeOf<Int>()) -> convertToMultik { colsOf<Int>() }
        type.isSubtypeOf(typeOf<Long>()) -> convertToMultik { colsOf<Long>() }
        type.isSubtypeOf(typeOf<Float>()) -> convertToMultik { colsOf<Float>() }
        type.isSubtypeOf(typeOf<Double>()) -> convertToMultik { colsOf<Double>() }
        // at this point the type is single but not representable in multik;
        // the old message ("found multiple column types") was misleading here
        else -> error("unsupported column type: $type")
    }
}
|
||||
|
||||
/**
 * Converts a [DataFrame] to a two-dimensional array ([D2Array]) by taking all
 * columns of type [N].
 *
 * Allows you to write `df.convertToMultikOf<Complex>()`.
 *
 * @see convertToMultik
 */
@JvmName("convertToMultikOfComplex")
@Suppress("LocalVariableName")
inline fun <reified N : Complex> AnyFrame.convertToMultikOf(
    // unused parameter whose only job is to disambiguate overload resolution
    _klass: KClass<Complex> = Complex::class,
): D2Array<N> = convertToMultik { colsOf<N>() }
|
||||
|
||||
/**
 * Converts a [DataFrame] to a two-dimensional array ([D2Array]) by taking all
 * columns of type [N].
 *
 * Allows you to write `df.convertToMultikOf<Int>()`.
 *
 * @see convertToMultik
 */
@JvmName("convertToMultikOfNumber")
@Suppress("LocalVariableName")
inline fun <reified N> AnyFrame.convertToMultikOf(
    // unused parameter whose only job is to disambiguate overload resolution
    _klass: KClass<Number> = Number::class,
): D2Array<N>
    where N : Number, N : Comparable<N> =
    convertToMultik { colsOf<N>() }
|
||||
|
||||
/**
 * Helper that turns a list of same-typed [DataColumn]s into a two-dimensional array ([D2Array]).
 * A plain [DataFrame] cannot guarantee homogeneous column types, hence the typed-list receiver.
 */
@Suppress("UNCHECKED_CAST")
@JvmName("convertNumberColumnsToMultik")
inline fun <reified N> List<DataColumn<N>>.convertToMultik(): D2Array<N> where N : Number, N : Comparable<N> {
    // rows are only accessible through a DataFrame, so rebuild one from the columns first
    val asDataFrame = this.toDataFrame()
    // each row becomes one inner list; the cast is safe because every column holds N
    val rowValues = asDataFrame.map { it.values() as List<N> }
    return mk.ndarray(rowValues)
}
|
||||
|
||||
/**
 * Helper that turns a list of same-typed [DataColumn]s into a two-dimensional array ([D2Array]).
 * A plain [DataFrame] cannot guarantee homogeneous column types, hence the typed-list receiver.
 */
@Suppress("UNCHECKED_CAST")
@JvmName("convertComplexColumnsToMultik")
inline fun <reified N : Complex> List<DataColumn<N>>.convertToMultik(): D2Array<N> {
    // rows are only accessible through a DataFrame, so rebuild one from the columns first
    val asDataFrame = this.toDataFrame()
    // each row becomes one inner list; the cast is safe because every column holds N
    val rowValues = asDataFrame.map { it.values() as List<N> }
    return mk.ndarray(rowValues)
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region higher dimensions
|
||||
|
||||
/**
 * Converts a three-dimensional array ([D3Array]) to a DataFrame.
 * It will contain `shape[0]` rows and `shape[1]` columns containing lists of size `shape[2]`.
 *
 * Column names can be specified using the [columnNameGenerator] lambda.
 *
 * The conversion enforces that `multikArray[x][y][z] == dataframe[x][y][z]`
 *
 * @param columnNameGenerator maps a column index in `0..<shape[1]` to its column name.
 */
inline fun <reified N> D3Array<N>.convertToDataFrameWithLists(
    columnNameGenerator: (Int) -> String = { "col$it" },
): AnyFrame {
    // one DataColumn per index along dimension 1; each cell stores the dimension-2 values as a List
    val columns: List<DataColumn<List<N>>> = List(shape[1]) { y ->
        this[0..<shape[0], y, 0..<shape[2]] // get all cells of column y, each is a 2d array of size shape[0] x shape[2]
            .toListD2() // get a shape[0]-sized list/column filled with lists of size shape[2]
            .toColumn<List<N>>(name = columnNameGenerator(y))
    }
    return columns.toDataFrame()
}
|
||||
|
||||
/**
 * Converts a three-dimensional array ([D3Array]) to a DataFrame.
 * It will contain `shape[0]` rows and `shape[1]` column groups containing `shape[2]` columns each.
 *
 * Column names can be specified using the [columnNameGenerator] lambda.
 *
 * The conversion enforces that `multikArray[x][y][z] == dataframe[x][y][z]`
 *
 * @param columnNameGenerator maps an index to a name; it is used both for the
 *   `shape[1]` column groups and for the `shape[2]` columns inside each group.
 */
@JvmName("convert3dArrayToDataFrame")
inline fun <reified N> D3Array<N>.convertToDataFrame(columnNameGenerator: (Int) -> String = { "col$it" }): AnyFrame {
    val columns: List<ColumnGroup<*>> = List(shape[1]) { y ->
        this[0..<shape[0], y, 0..<shape[2]] // get all cells of column i, each is a 2d array of size shape[0] x shape[2]
            .transpose(1, 0) // flip, so we get shape[2] x shape[0]
            .toListD2() // get a shape[2]-sized list filled with lists of size shape[0]
            .mapIndexed { z, list ->
                // each inner list becomes one column of the group; the transpose above
                // is what keeps multikArray[x][y][z] == dataframe[x][y][z]
                list.toColumn<N>(name = columnNameGenerator(z))
            } // we get shape[2] columns inside each column group
            .toColumnGroup(name = columnNameGenerator(y))
    }
    return columns.toDataFrame()
}
|
||||
|
||||
/**
 * Exploratory recursive function to convert a [MultiArray] of any number of dimensions
 * to a `List<List<...>>` of the same number of dimensions.
 */
fun <T> MultiArray<T, *>.toListDn(): List<*> {
    // builds the list for the dimension at depth prefix.size, given the fixed indices in [prefix]
    fun buildDimension(prefix: IntArray): List<*> =
        if (prefix.size == shape.lastIndex) {
            // innermost dimension: read the actual values
            List(shape[prefix.size]) { i -> this[intArrayOf(*prefix, i)] }
        } else {
            // outer dimension: recurse one level deeper for each index
            List(shape[prefix.size]) { i -> buildDimension(prefix + i) }
        }
    return buildDimension(intArrayOf())
}
|
||||
|
||||
/**
 * Converts a multidimensional array ([NDArray]) to a DataFrame.
 * Inspired by [toListDn].
 *
 * For a single-dimensional array, it will call [D1Array.convertToDataFrame].
 *
 * Column names can be specified using the [columnNameGenerator] lambda.
 *
 * The conversion enforces that `multikArray[a][b][c][d]... == dataframe[a][b][c][d]...`
 *
 * @param columnNameGenerator maps an index to a name at every nesting level
 *   (column groups and leaf columns alike).
 */
@Suppress("UNCHECKED_CAST")
inline fun <reified N> NDArray<N, *>.convertToDataFrameNestedGroups(
    noinline columnNameGenerator: (Int) -> String = { "col$it" },
): AnyFrame {
    // 1D has a dedicated, simpler conversion
    if (shape.size == 1) return (this as D1Array<N>).convertToDataFrame()

    // push the first dimension to the end, because this represents the rows in DataFrame,
    // and they are accessed by []'s first
    return transpose(*(1..<dim.d).toList().toIntArray(), 0)
        .convertToDataFrameNestedGroupsRecursive(
            indices = intArrayOf(),
            type = typeOf<N>(), // cannot inline a recursive function, so pass the type explicitly
            columnNameGenerator = columnNameGenerator,
        ).let {
            // we could just cast this to a DataFrame<*>, because a ColumnGroup<*>: DataFrame
            // however, this can sometimes cause issues where instance checks are done at runtime
            // this converts it to an actual DataFrame instance
            dataFrameOf((it as ColumnGroup<*>).columns())
        }
}
|
||||
|
||||
/**
 * Recursive helper function to handle traversal across dimensions. Do not call directly,
 * use [convertToDataFrameNestedGroups] instead.
 *
 * Builds a leaf [DataColumn] for the innermost dimension and a column group for every
 * outer dimension, descending one dimension per recursion step.
 *
 * @param indices the indices fixed so far for the outer dimensions (one per recursion level).
 * @param type runtime [KType] of the array elements; passed explicitly because a
 *   recursive function cannot be inlined/reified.
 * @param columnNameGenerator maps an index to a column (or group) name.
 */
@PublishedApi
internal fun NDArray<*, *>.convertToDataFrameNestedGroupsRecursive(
    indices: IntArray,
    type: KType,
    columnNameGenerator: (Int) -> String,
): BaseColumn<*> {
    // If we are at the last dimension (1D case)
    if (indices.size == shape.lastIndex) {
        return List(shape[indices.size]) { i ->
            this[intArrayOf(*indices, i)] // Collect values for this dimension
        }.let {
            // name is assigned by the caller via rename(), so leave it blank here
            DataColumn.createByType(name = "", values = it, type = type)
        }
    }

    // For higher dimensions, recursively process smaller dimensions
    return List(shape[indices.size]) { i ->
        convertToDataFrameNestedGroupsRecursive(
            indices = indices + i, // Add `i` to the current index array
            type = type,
            columnNameGenerator = columnNameGenerator,
        ).rename(columnNameGenerator(i))
    }.toColumnGroup("")
}
|
||||
|
||||
// endregion
|
||||
+23
@@ -0,0 +1,23 @@
|
||||
package org.jetbrains.kotlinx.dataframe.examples.multik
|
||||
|
||||
import org.jetbrains.kotlinx.dataframe.api.print
|
||||
import org.jetbrains.kotlinx.multik.api.io.readNPY
|
||||
import org.jetbrains.kotlinx.multik.api.mk
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.D1
|
||||
import java.io.File
|
||||
|
||||
/**
 * Multik can read/write data from NPY/NPZ files.
 * We can use this from DataFrame too!
 *
 * We use compatibilityLayer.kt for the conversions, check it out for the implementation details of the conversion!
 */
fun main() {
    val npyFilename = "a1d.npy"
    // fail fast with a clear message if the resource is missing from the classpath
    // (previously a bare `!!`, which would only produce an unhelpful NPE)
    val npyUrl = checkNotNull(object {}.javaClass.classLoader.getResource(npyFilename)) {
        "resource '$npyFilename' not found on the classpath"
    }
    val npyFile = File(npyUrl.toURI())

    // read the 1D Long array from the NPY file and convert it to a DataFrame
    val mk1 = mk.readNPY<Long, D1>(npyFile)
    val df1 = mk1.convertToDataFrame()

    df1.print(borders = true, columnTypes = true)
}
|
||||
+99
@@ -0,0 +1,99 @@
|
||||
package org.jetbrains.kotlinx.dataframe.examples.multik
|
||||
|
||||
import org.jetbrains.kotlinx.dataframe.api.cast
|
||||
import org.jetbrains.kotlinx.dataframe.api.colsOf
|
||||
import org.jetbrains.kotlinx.dataframe.api.describe
|
||||
import org.jetbrains.kotlinx.dataframe.api.mean
|
||||
import org.jetbrains.kotlinx.dataframe.api.meanFor
|
||||
import org.jetbrains.kotlinx.dataframe.api.print
|
||||
import org.jetbrains.kotlinx.dataframe.api.value
|
||||
import org.jetbrains.kotlinx.multik.api.mk
|
||||
import org.jetbrains.kotlinx.multik.api.rand
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.get
|
||||
|
||||
/**
 * Let's explore some ways we can combine Multik with Kotlin DataFrame.
 *
 * We will use compatibilityLayer.kt for the conversions.
 * Take a look at that file for the implementation details!
 */
fun main() {
    // each section demonstrates conversions for one dimensionality
    oneDimension()
    twoDimensions()
    higherDimensions()
}
|
||||
|
||||
fun oneDimension() {
    // we can convert a 1D ndarray to a column of a DataFrame:
    val mk1 = mk.rand<Double>(50)
    // `by` delegation names the column after the property ("col1")
    val col1 by mk1.convertToColumn()
    println(col1)

    // or straight to a DataFrame. It will become the `value` column.
    val df1 = mk1.convertToDataFrame()
    println(df1)

    // this allows us to perform any DF operation:
    // (`value` is the typed accessor from the ValueProperty schema)
    println(df1.mean { value })
    df1.describe().print(borders = true)

    // we can convert back to Multik:
    val mk2 = df1.convertToMultik { value }
    // or equivalently, straight from the column:
    df1.value.convertToMultik()

    println(mk2)
}
|
||||
|
||||
fun twoDimensions() {
    // we can also convert a 2D ndarray to a DataFrame
    // This conversion will create columns like "col0", "col1", etc.
    // (careful, when the number of columns is too large, this can cause problems)
    // but will allow for similar access like in multik
    // aka: `multikArray[x][y] == dataframe[x][y]`
    val mk1 = mk.rand<Int>(5, 10)
    println(mk1)
    val df = mk1.convertToDataFrame()
    df.print()

    // this allows us to perform any DF operation:
    // ("col0".."col9" selects the ten generated columns by name range)
    val means = df.meanFor { ("col0".."col9").cast<Int>() }
    means.print()

    // we can convert back to Multik in multiple ways.
    // Multik can only store one type of data, so we need to specify the type or select
    // only the columns we want:
    val mk2 = df.convertToMultik { colsOf<Int>() }
    // or
    df.convertToMultikOf<Int>()
    // or if all columns are of the same type:
    df.convertToMultik()

    println(mk2)
}
|
||||
|
||||
fun higherDimensions() {
    // Multik can store higher dimensions as well
    // however; to convert this to a DataFrame, we need to specify how to do a particular conversion
    // for instance, for 3d, we could store a list in each cell of the DF to represent the extra dimension:
    val mk1 = mk.rand<Int>(5, 4, 3)

    println(mk1)

    // 5 rows x 4 columns, each cell holding a list of 3 values
    val df1 = mk1.convertToDataFrameWithLists()
    df1.print()

    // Alternatively, this could be solved using column groups.
    // This subdivides each column into more columns, while ensuring `multikArray[x][y][z] == dataframe[x][y][z]`
    val df2 = mk1.convertToDataFrame()
    df2.print()

    // For even higher dimensions, we can keep adding more column groups
    val mk2 = mk.rand<Int>(5, 4, 3, 2)
    val df3 = mk2.convertToDataFrameNestedGroups()
    df3.print()

    // ...or use nested DataFrames (in FrameColumns)
    // (for instance, a 4D matrix could be stored in a 2D DataFrame where each cell is another DataFrame)
    // but, we'll leave that as an exercise for the reader :)
}
|
||||
+115
@@ -0,0 +1,115 @@
|
||||
package org.jetbrains.kotlinx.dataframe.examples.multik
|
||||
|
||||
import kotlinx.datetime.LocalDate
|
||||
import kotlinx.datetime.Month
|
||||
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
|
||||
import org.jetbrains.kotlinx.dataframe.api.append
|
||||
import org.jetbrains.kotlinx.dataframe.api.cast
|
||||
import org.jetbrains.kotlinx.dataframe.api.mapToFrame
|
||||
import org.jetbrains.kotlinx.dataframe.api.print
|
||||
import org.jetbrains.kotlinx.dataframe.api.single
|
||||
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
|
||||
import org.jetbrains.kotlinx.multik.api.mk
|
||||
import org.jetbrains.kotlinx.multik.api.rand
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.D3Array
|
||||
import org.jetbrains.kotlinx.multik.ndarray.data.D4Array
|
||||
|
||||
/**
 * DataFrames can store anything inside, including Multik ndarrays.
 * This can be useful for storing matrices for easier access later or to simply organize data read from other files.
 * For example, MRI data is often stored as 3D arrays and sometimes even 4D arrays.
 */
fun main() {
    // imaginary list of patient data
    @Suppress("ktlint:standard:argument-list-wrapping")
    val metadata = listOf(
        MriMetadata(10012L, 25, "Healthy", LocalDate(2023, 1, 1)),
        MriMetadata(10013L, 45, "Tuberculosis", LocalDate(2023, 2, 15)),
        MriMetadata(10014L, 32, "Healthy", LocalDate(2023, 3, 22)),
        MriMetadata(10015L, 58, "Pneumonia", LocalDate(2023, 4, 8)),
        MriMetadata(10016L, 29, "Tuberculosis", LocalDate(2023, 5, 30)),
        MriMetadata(10017L, 42, "Healthy", LocalDate(2023, 6, 15)),
        MriMetadata(10018L, 37, "Healthy", LocalDate(2023, 7, 1)),
        MriMetadata(10019L, 55, "Healthy", LocalDate(2023, 8, 15)),
        MriMetadata(10020L, 28, "Healthy", LocalDate(2023, 9, 1)),
        MriMetadata(10021L, 44, "Healthy", LocalDate(2023, 10, 15)),
        MriMetadata(10022L, 31, "Healthy", LocalDate(2023, 11, 1)),
    ).toDataFrame()

    // "reading" the results from "files": keep the metadata columns
    // and add one ndarray column per scan type
    val results = metadata.mapToFrame {
        +patientId
        +age
        +diagnosis
        +scanDate
        "t1WeightedMri" from { readT1WeightedMri(patientId) }
        "fMriBoldSeries" from { readFMRiBoldSeries(patientId) }
    }.cast<MriResults>(verify = true)
        // NOTE(review): `.append()` with no arguments looks like a leftover —
        // confirm it is intended, as it adds a row without any values
        .append()

    results.print(borders = true)

    // now when we want to check and visualize the T1-weighted MRI scan
    // for that one healthy patient in July, we can do:
    val scan = results
        .single { scanDate.month == Month.JULY && diagnosis == "Healthy" }
        .t1WeightedMri

    // easy :)
    visualize(scan)
}
|
||||
|
||||
/** Patient metadata describing a single MRI scan session. */
@DataSchema
data class MriMetadata(
    /** Unique patient ID. */
    val patientId: Long,
    /** Patient age. */
    val age: Int,
    /** Clinical diagnosis (e.g. "Healthy", "Tuberculosis") */
    val diagnosis: String,
    /** Date of the scan */
    val scanDate: LocalDate,
)
|
||||
|
||||
/** [MriMetadata] extended with the actual scan data, stored as multik ndarrays. */
@DataSchema
data class MriResults(
    /** Unique patient ID. */
    val patientId: Long,
    /** Patient age. */
    val age: Int,
    /** Clinical diagnosis (e.g. "Healthy", "Tuberculosis") */
    val diagnosis: String,
    /** Date of the scan */
    val scanDate: LocalDate,
    /**
     * T1-weighted anatomical MRI scan.
     *
     * Dimensions: (256 x 256 x 180)
     * - 256 width x 256 height
     * - 180 slices
     */
    val t1WeightedMri: D3Array<Float>,
    /**
     * Blood oxygenation level-dependent (BOLD) time series from an fMRI scan.
     *
     * Dimensions: (64 x 64 x 30 x 200)
     * - 64 width x 64 height
     * - 30 slices
     * - 200 timepoints
     */
    val fMriBoldSeries: D4Array<Float>,
)
|
||||
|
||||
/** Pretends to read the T1-weighted MRI scan for patient [id]; returns dummy random data. */
fun readT1WeightedMri(id: Long): D3Array<Float> =
    // a real implementation would load the patient's scan from storage
    mk.rand(256, 256, 180)
|
||||
|
||||
/** Pretends to read the fMRI BOLD series for patient [id]; returns dummy random data. */
fun readFMRiBoldSeries(id: Long): D4Array<Float> =
    // a real implementation would load the patient's scan from storage
    mk.rand(64, 64, 30, 200)
|
||||
|
||||
/** Placeholder for a real visualization of the given [scan]; intentionally a no-op here. */
fun visualize(scan: D3Array<Float>) {
    // This would then actually visualize the scan
}
|
||||
Vendored
BIN
Binary file not shown.
+77
@@ -0,0 +1,77 @@
|
||||
import org.jetbrains.kotlin.gradle.dsl.JvmTarget
|
||||
|
||||
plugins {
    application
    kotlin("jvm")

    // uses the 'old' Gradle plugin instead of the compiler plugin for now
    id("org.jetbrains.kotlinx.dataframe")

    // only mandatory if `kotlin.dataframe.add.ksp=false` in gradle.properties
    id("com.google.devtools.ksp")
}

repositories {
    mavenLocal() // in case of local dataframe development
    mavenCentral()
}

dependencies {
    // depends on the dataframe project itself; for a published release use instead:
    // implementation("org.jetbrains.kotlinx:dataframe:X.Y.Z")
    implementation(project(":"))

    // (kotlin) spark support; spark itself is provided at runtime, hence compileOnly
    implementation(libs.kotlin.spark)
    compileOnly(libs.spark)
    implementation(libs.log4j.core)
    implementation(libs.log4j.api)
}
|
||||
|
||||
/**
 * The Spark examples must run on Java 11 regardless of the JVM running Gradle itself,
 * so each example gets a dedicated [JavaExec] task pinned to a Java 11 launcher.
 * Registered from a map to avoid repeating the identical configuration four times.
 */
val sparkExampleTasks = mapOf(
    // Kotlin Spark API examples
    "runKotlinSparkTypedDataset" to "org.jetbrains.kotlinx.dataframe.examples.kotlinSpark.TypedDatasetKt",
    "runKotlinSparkUntypedDataset" to "org.jetbrains.kotlinx.dataframe.examples.kotlinSpark.UntypedDatasetKt",
    // plain Spark examples
    "runSparkTypedDataset" to "org.jetbrains.kotlinx.dataframe.examples.spark.TypedDatasetKt",
    "runSparkUntypedDataset" to "org.jetbrains.kotlinx.dataframe.examples.spark.UntypedDatasetKt",
)

sparkExampleTasks.forEach { (taskName, exampleMainClass) ->
    // task names are identical to the previous hand-written registrations
    tasks.register<JavaExec>(taskName) {
        classpath = sourceSets["main"].runtimeClasspath
        javaLauncher = javaToolchains.launcherFor { languageVersion = JavaLanguageVersion.of(11) }
        mainClass = exampleMainClass
    }
}
|
||||
|
||||
// compile the Kotlin sources against JVM 11 (the version Spark supports here)
kotlin {
    compilerOptions {
        jvmTarget = JvmTarget.JVM_11
        // also validates that only JDK 11 APIs are referenced
        freeCompilerArgs.add("-Xjdk-release=11")
    }
}

// keep any Java sources on the same JDK 11 target as the Kotlin code
tasks.withType<JavaCompile> {
    sourceCompatibility = JavaVersion.VERSION_11.toString()
    targetCompatibility = JavaVersion.VERSION_11.toString()
    options.release.set(11)
}
|
||||
+8
@@ -0,0 +1,8 @@
|
||||
@file:Suppress("ktlint:standard:no-empty-file")
|
||||
|
||||
package org.jetbrains.kotlinx.dataframe.examples.kotlinSpark
|
||||
|
||||
/*
|
||||
* See ../spark/compatibilityLayer.kt for the implementation.
|
||||
 * It's the same with and without the Kotlin Spark API.
|
||||
*/
|
||||
+78
@@ -0,0 +1,78 @@
|
||||
@file:Suppress("ktlint:standard:function-signature")
|
||||
|
||||
package org.jetbrains.kotlinx.dataframe.examples.kotlinSpark
|
||||
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
|
||||
import org.jetbrains.kotlinx.dataframe.api.aggregate
|
||||
import org.jetbrains.kotlinx.dataframe.api.groupBy
|
||||
import org.jetbrains.kotlinx.dataframe.api.max
|
||||
import org.jetbrains.kotlinx.dataframe.api.mean
|
||||
import org.jetbrains.kotlinx.dataframe.api.min
|
||||
import org.jetbrains.kotlinx.dataframe.api.print
|
||||
import org.jetbrains.kotlinx.dataframe.api.schema
|
||||
import org.jetbrains.kotlinx.dataframe.api.std
|
||||
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.api.toList
|
||||
import org.jetbrains.kotlinx.spark.api.withSpark
|
||||
|
||||
/**
 * With the Kotlin Spark API, normal Kotlin data classes are supported,
 * meaning we can reuse the same class for Spark and DataFrame!
 *
 * Also, since we use an actual class to define the schema, we need no type conversion!
 *
 * See [Person] and [Name] for an example.
 *
 * NOTE: You will likely need to run this function with Java 8 or 11 for it to work correctly.
 * Use the `runKotlinSparkTypedDataset` Gradle task to do so.
 */
fun main() = withSpark {
    // Creating a Spark Dataset. Usually, this is loaded from some server or database.
    val rawDataset: Dataset<Person> = listOf(
        Person(Name("Alice", "Cooper"), 15, "London", 54, true),
        Person(Name("Bob", "Dylan"), 45, "Dubai", 87, true),
        Person(Name("Charlie", "Daniels"), 20, "Moscow", null, false),
        Person(Name("Charlie", "Chaplin"), 40, "Milan", null, true),
        Person(Name("Bob", "Marley"), 30, "Tokyo", 68, true),
        Person(Name("Alice", "Wolf"), 20, null, 55, false),
        Person(Name("Charlie", "Byrd"), 30, "Moscow", 90, true),
    ).toDS()

    // we can perform large operations in Spark.
    // DataFrames are in-memory structures, so this is a good place to limit the number of rows if you don't have the RAM ;)
    val dataset = rawDataset.filter { it.age > 17 }

    // and convert it to DataFrame via a typed List
    val dataframe = dataset.collectAsList().toDataFrame()
    dataframe.schema().print()
    dataframe.print(columnTypes = true, borders = true)

    // now we can use DataFrame-specific functions,
    // e.g. per-city statistics over the age column:
    val ageStats = dataframe
        .groupBy { city }.aggregate {
            mean { age } into "meanAge"
            std { age } into "stdAge"
            min { age } into "minAge"
            max { age } into "maxAge"
        }

    ageStats.print(columnTypes = true, borders = true)

    // and when we want to convert a DataFrame back to Spark, we can do the same trick via a typed List
    val sparkDatasetAgain = dataframe.toList().toDS()
    sparkDatasetAgain.printSchema()
    sparkDatasetAgain.show()
}
|
||||
|
||||
/** A person's name, split into first and last name; shared by Spark and DataFrame. */
@DataSchema
data class Name(val firstName: String, val lastName: String)
|
||||
|
||||
/** A person record shared by Spark and DataFrame; nullable fields may be absent in the data. */
@DataSchema
data class Person(
    val name: Name,
    val age: Int,
    /** Null when the city is unknown. */
    val city: String?,
    /** Null when not measured. */
    val weight: Int?,
    val isHappy: Boolean,
)
|
||||
+74
@@ -0,0 +1,74 @@
|
||||
@file:Suppress("ktlint:standard:function-signature")
|
||||
|
||||
package org.jetbrains.kotlinx.dataframe.examples.kotlinSpark
|
||||
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.sql.Row
|
||||
import org.jetbrains.kotlinx.dataframe.api.aggregate
|
||||
import org.jetbrains.kotlinx.dataframe.api.groupBy
|
||||
import org.jetbrains.kotlinx.dataframe.api.max
|
||||
import org.jetbrains.kotlinx.dataframe.api.mean
|
||||
import org.jetbrains.kotlinx.dataframe.api.min
|
||||
import org.jetbrains.kotlinx.dataframe.api.print
|
||||
import org.jetbrains.kotlinx.dataframe.api.schema
|
||||
import org.jetbrains.kotlinx.dataframe.api.std
|
||||
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToDataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToDataFrameByInference
|
||||
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToSpark
|
||||
import org.jetbrains.kotlinx.spark.api.col
|
||||
import org.jetbrains.kotlinx.spark.api.gt
|
||||
import org.jetbrains.kotlinx.spark.api.withSpark
|
||||
|
||||
/**
|
||||
* Since we don't know the schema at compile time this time, we need to do
|
||||
* some schema mapping in between Spark and DataFrame.
|
||||
*
|
||||
* We will use spark/compatibilityLayer.kt to do this.
|
||||
* Take a look at that file for the implementation details!
|
||||
*
|
||||
* NOTE: You will likely need to run this function with Java 8 or 11 for it to work correctly.
|
||||
* Use the `runKotlinSparkUntypedDataset` Gradle task to do so.
|
||||
*/
|
||||
fun main() = withSpark {
    // Build a small untyped Spark Dataframe (Dataset<Row>); usually this would be
    // loaded from some server or database.
    val rawRows: Dataset<Row> = listOf(
        Person(Name("Alice", "Cooper"), 15, "London", 54, true),
        Person(Name("Bob", "Dylan"), 45, "Dubai", 87, true),
        Person(Name("Charlie", "Daniels"), 20, "Moscow", null, false),
        Person(Name("Charlie", "Chaplin"), 40, "Milan", null, true),
        Person(Name("Bob", "Marley"), 30, "Tokyo", 68, true),
        Person(Name("Alice", "Wolf"), 20, null, 55, false),
        Person(Name("Charlie", "Byrd"), 30, "Moscow", 90, true),
    ).toDF()

    // Keep the heavy filtering in Spark. DataFrames are in-memory structures,
    // so this is the place to limit the number of rows if RAM is a concern.
    val adults = rawRows.filter(col("age") gt 17)

    // Conversion option 1: let DataFrame infer the column types.
    val inferredDf = adults.convertToDataFrameByInference()
    inferredDf.schema().print()
    inferredDf.print(columnTypes = true, borders = true)

    // Conversion option 2: map the Spark schema explicitly.
    val mappedDf = adults.convertToDataFrame()
    mappedDf.schema().print()
    mappedDf.print(columnTypes = true, borders = true)

    // From here on, DataFrame-specific functions are available.
    val statsPerCity = inferredDf
        .groupBy("city").aggregate {
            mean("age") into "meanAge"
            std("age") into "stdAge"
            min("age") into "minAge"
            max("age") into "maxAge"
        }

    statsPerCity.print(columnTypes = true, borders = true)

    // Back to Spark via `convertToSpark()`, which performs the necessary
    // schema mapping under the hood.
    val sparkAgain = mappedDf.convertToSpark(spark, sc)
    sparkAgain.printSchema()
    sparkAgain.show()
}
|
||||
+330
@@ -0,0 +1,330 @@
|
||||
package org.jetbrains.kotlinx.dataframe.examples.spark
|
||||
|
||||
import org.apache.spark.api.java.JavaRDD
|
||||
import org.apache.spark.api.java.JavaSparkContext
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.RowFactory
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.apache.spark.sql.types.ArrayType
|
||||
import org.apache.spark.sql.types.DataType
|
||||
import org.apache.spark.sql.types.DataTypes
|
||||
import org.apache.spark.sql.types.Decimal
|
||||
import org.apache.spark.sql.types.DecimalType
|
||||
import org.apache.spark.sql.types.MapType
|
||||
import org.apache.spark.sql.types.StructType
|
||||
import org.apache.spark.unsafe.types.CalendarInterval
|
||||
import org.jetbrains.kotlinx.dataframe.AnyFrame
|
||||
import org.jetbrains.kotlinx.dataframe.DataColumn
|
||||
import org.jetbrains.kotlinx.dataframe.DataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.DataRow
|
||||
import org.jetbrains.kotlinx.dataframe.api.rows
|
||||
import org.jetbrains.kotlinx.dataframe.api.schema
|
||||
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
|
||||
import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion
|
||||
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
|
||||
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
|
||||
import java.math.BigDecimal
|
||||
import java.math.BigInteger
|
||||
import java.sql.Date
|
||||
import java.sql.Timestamp
|
||||
import java.time.Instant
|
||||
import java.time.LocalDate
|
||||
import kotlin.reflect.KType
|
||||
import kotlin.reflect.KTypeProjection
|
||||
import kotlin.reflect.full.createType
|
||||
import kotlin.reflect.full.isSubtypeOf
|
||||
import kotlin.reflect.full.withNullability
|
||||
import kotlin.reflect.typeOf
|
||||
|
||||
// region Spark to DataFrame
|
||||
|
||||
/**
|
||||
* Converts an untyped Spark [Dataset] (Dataframe) to a Kotlin [DataFrame].
|
||||
* [StructTypes][StructType] are converted to [ColumnGroups][ColumnGroup].
|
||||
*
|
||||
* DataFrame supports type inference to do the conversion automatically.
|
||||
* This is usually fine for smaller data sets, but when working with larger datasets, a type map might be a good idea.
|
||||
* See [convertToDataFrame] for more information.
|
||||
*/
|
||||
fun Dataset<Row>.convertToDataFrameByInference(
    schema: StructType = schema(),
    prefix: List<String> = emptyList(),
): AnyFrame {
    // One DataFrame column per Spark field; struct fields recurse into column groups.
    val columns = schema.fields().map { structField ->
        val columnName = structField.name()
        val sparkType = structField.dataType()
        if (sparkType is StructType) {
            // A column group is easily created from a (recursively converted) dataframe and a name.
            DataColumn.createColumnGroup(
                name = columnName,
                df = convertToDataFrameByInference(sparkType, prefix + columnName),
            )
        } else {
            // Pull the single (possibly nested) column out of Spark with `select()`,
            // collect its values as single-celled rows, and let DataFrame type
            // inference figure out the Kotlin type.
            DataColumn.createByInference(
                name = columnName,
                values = select((prefix + columnName).joinToString("."))
                    .collectAsList()
                    .map { row -> row[0] },
                suggestedType = TypeSuggestion.Infer,
                // Spark already knows nullability, so it does not have to be inferred.
                nullable = structField.nullable(),
            )
        }
    }
    return columns.toDataFrame()
}
|
||||
|
||||
/**
|
||||
* Converts an untyped Spark [Dataset] (Dataframe) to a Kotlin [DataFrame].
|
||||
* [StructTypes][StructType] are converted to [ColumnGroups][ColumnGroup].
|
||||
*
|
||||
* This version uses a [type-map][DataType.convertToDataFrame] to convert the schemas with a fallback to inference.
|
||||
* For smaller data sets, inference is usually fine too.
|
||||
* See [convertToDataFrameByInference] for more information.
|
||||
*/
|
||||
fun Dataset<Row>.convertToDataFrame(schema: StructType = schema(), prefix: List<String> = emptyList()): AnyFrame {
    // One DataFrame column per Spark field; struct fields recurse into column groups.
    val columns = schema.fields().map { structField ->
        val columnName = structField.name()
        val sparkType = structField.dataType()
        if (sparkType is StructType) {
            // A column group is easily created from a (recursively converted) dataframe and a name.
            DataColumn.createColumnGroup(
                name = columnName,
                df = convertToDataFrame(sparkType, prefix + columnName),
            )
        } else {
            // Pull the single (possibly nested) column out of Spark with `select()`
            // and collect its values as single-celled rows. The Kotlin type comes
            // from the Spark->Kotlin type map, falling back to inference for
            // anything the map does not cover.
            DataColumn.createByInference(
                name = columnName,
                values = select((prefix + columnName).joinToString("."))
                    .collectAsList()
                    .map { row -> row[0] },
                suggestedType = sparkType.convertToDataFrame()
                    ?.let(TypeSuggestion::Use)
                    ?: TypeSuggestion.Infer,
                // Spark already knows nullability, so it does not have to be inferred.
                nullable = structField.nullable(),
            )
        }
    }
    return columns.toDataFrame()
}
|
||||
|
||||
/**
|
||||
* Returns the corresponding [Kotlin type][KType] for a given Spark [DataType].
|
||||
*
|
||||
* This list may be incomplete, but it can at least give you a good start.
|
||||
*
|
||||
* @return The [KType] that corresponds to the Spark [DataType], or null if no matching [KType] is found.
|
||||
*/
|
||||
fun DataType.convertToDataFrame(): KType? {
    return when (this) {
        DataTypes.ByteType -> typeOf<Byte>()
        DataTypes.ShortType -> typeOf<Short>()
        DataTypes.IntegerType -> typeOf<Int>()
        DataTypes.LongType -> typeOf<Long>()
        DataTypes.BooleanType -> typeOf<Boolean>()
        DataTypes.FloatType -> typeOf<Float>()
        DataTypes.DoubleType -> typeOf<Double>()
        DataTypes.StringType -> typeOf<String>()
        DataTypes.DateType -> typeOf<Date>()
        DataTypes.TimestampType -> typeOf<Timestamp>()
        DataTypes.CalendarIntervalType -> typeOf<CalendarInterval>()
        DataTypes.NullType -> nullableNothingType
        DataTypes.BinaryType -> typeOf<ByteArray>()

        // any precision/scale of decimal maps to Spark's Decimal wrapper
        is DecimalType -> typeOf<Decimal>()

        // only these primitive element types have a Kotlin primitive-array counterpart
        is ArrayType ->
            when (elementType()) {
                DataTypes.ShortType -> typeOf<ShortArray>()
                DataTypes.IntegerType -> typeOf<IntArray>()
                DataTypes.LongType -> typeOf<LongArray>()
                DataTypes.FloatType -> typeOf<FloatArray>()
                DataTypes.DoubleType -> typeOf<DoubleArray>()
                DataTypes.BooleanType -> typeOf<BooleanArray>()
                else -> null
            }

        // both sides of a map must be mappable themselves, otherwise give up
        is MapType -> {
            val keyKType = keyType().convertToDataFrame() ?: return null
            val valueKType = valueType().convertToDataFrame() ?: return null
            Map::class.createType(
                listOf(
                    KTypeProjection.invariant(keyKType),
                    KTypeProjection.invariant(valueKType.withNullability(valueContainsNull())),
                ),
            )
        }

        else -> null
    }
}
|
||||
|
||||
// endregion
|
||||
|
||||
// region DataFrame to Spark
|
||||
|
||||
/**
|
||||
* Converts the [DataFrame] to a Spark [Dataset] of [Rows][Row] using the provided [SparkSession] and [JavaSparkContext].
|
||||
*
|
||||
* Spark needs both the data and the schema to be converted to create a correct [Dataset],
|
||||
* so we need to map our types somehow.
|
||||
*
|
||||
* @param spark The [SparkSession] object to use for creating the [DataFrame].
|
||||
* @param sc The [JavaSparkContext] object to use for converting the [DataFrame] to [RDD][JavaRDD].
|
||||
* @return A [Dataset] of [Rows][Row] representing the converted DataFrame.
|
||||
*/
|
||||
fun DataFrame<*>.convertToSpark(spark: SparkSession, sc: JavaSparkContext): Dataset<Row> {
    // Each DataRow becomes a Spark Row, distributed through the JavaSparkContext.
    val sparkRows = rows().map { dataRow -> dataRow.convertToSpark() }
    // Spark needs the schema next to the data, so map the DataFrame schema to a StructType too.
    return spark.createDataFrame(sc.parallelize(sparkRows), schema().convertToSpark())
}
|
||||
|
||||
/**
|
||||
* Converts a [DataRow] to a Spark [Row] object.
|
||||
*
|
||||
* @return The converted Spark [Row].
|
||||
*/
|
||||
fun DataRow<*>.convertToSpark(): Row {
    val cells = values().map { cell ->
        when (cell) {
            // column groups show up as rows nested inside this row, so recurse
            is DataRow<*> -> cell.convertToSpark()

            is DataFrame<*> -> error("nested dataframes are not supported")

            else -> cell
        }
    }
    return RowFactory.create(*cells.toTypedArray())
}
|
||||
|
||||
/**
|
||||
* Converts a [DataFrameSchema] to a Spark [StructType].
|
||||
*
|
||||
* @return The converted Spark [StructType].
|
||||
*/
|
||||
fun DataFrameSchema.convertToSpark(): StructType {
    // one StructField per DataFrame column, carrying name, mapped type, and nullability
    val fields = columns.map { (columnName, columnSchema) ->
        DataTypes.createStructField(columnName, columnSchema.convertToSpark(), columnSchema.nullable)
    }
    return DataTypes.createStructType(fields)
}
|
||||
|
||||
/**
|
||||
* Converts a [ColumnSchema] object to Spark [DataType].
|
||||
*
|
||||
* @return The Spark [DataType] corresponding to the given [ColumnSchema] object.
|
||||
* @throws IllegalArgumentException if the column type or kind is unknown.
|
||||
*/
|
||||
fun ColumnSchema.convertToSpark(): DataType =
    when (this) {
        // column groups map to a nested StructType
        is ColumnSchema.Group -> schema.convertToSpark()

        // plain value columns go through the Kotlin->Spark type map
        is ColumnSchema.Value -> type.convertToSpark() ?: error("unknown data type: $type")

        is ColumnSchema.Frame -> error("nested dataframes are not supported")

        else -> error("unknown column kind: $this")
    }
|
||||
|
||||
/**
|
||||
* Returns the corresponding Spark [DataType] for a given [Kotlin type][KType].
|
||||
*
|
||||
* This list may be incomplete, but it can at least give you a good start.
|
||||
*
|
||||
* @return The Spark [DataType] that corresponds to the [Kotlin type][KType], or null if no matching [DataType] is found.
|
||||
*/
|
||||
fun KType.convertToSpark(): DataType? =
    when {
        isSubtypeOf(typeOf<Byte?>()) -> DataTypes.ByteType

        isSubtypeOf(typeOf<Short?>()) -> DataTypes.ShortType

        isSubtypeOf(typeOf<Int?>()) -> DataTypes.IntegerType

        isSubtypeOf(typeOf<Long?>()) -> DataTypes.LongType

        isSubtypeOf(typeOf<Boolean?>()) -> DataTypes.BooleanType

        isSubtypeOf(typeOf<Float?>()) -> DataTypes.FloatType

        isSubtypeOf(typeOf<Double?>()) -> DataTypes.DoubleType

        isSubtypeOf(typeOf<String?>()) -> DataTypes.StringType

        isSubtypeOf(typeOf<LocalDate?>()) -> DataTypes.DateType

        isSubtypeOf(typeOf<Date?>()) -> DataTypes.DateType

        isSubtypeOf(typeOf<Timestamp?>()) -> DataTypes.TimestampType

        isSubtypeOf(typeOf<Instant?>()) -> DataTypes.TimestampType

        isSubtypeOf(typeOf<Decimal?>()) -> DecimalType.SYSTEM_DEFAULT()

        isSubtypeOf(typeOf<BigDecimal?>()) -> DecimalType.SYSTEM_DEFAULT()

        isSubtypeOf(typeOf<BigInteger?>()) -> DecimalType.SYSTEM_DEFAULT()

        isSubtypeOf(typeOf<CalendarInterval?>()) -> DataTypes.CalendarIntervalType

        isSubtypeOf(nullableNothingType) -> DataTypes.NullType

        isSubtypeOf(typeOf<ByteArray?>()) -> DataTypes.BinaryType

        isSubtypeOf(typeOf<ShortArray?>()) -> DataTypes.createArrayType(DataTypes.ShortType, false)

        isSubtypeOf(typeOf<IntArray?>()) -> DataTypes.createArrayType(DataTypes.IntegerType, false)

        isSubtypeOf(typeOf<LongArray?>()) -> DataTypes.createArrayType(DataTypes.LongType, false)

        isSubtypeOf(typeOf<FloatArray?>()) -> DataTypes.createArrayType(DataTypes.FloatType, false)

        isSubtypeOf(typeOf<DoubleArray?>()) -> DataTypes.createArrayType(DataTypes.DoubleType, false)

        isSubtypeOf(typeOf<BooleanArray?>()) -> DataTypes.createArrayType(DataTypes.BooleanType, false)

        isSubtypeOf(typeOf<Array<*>>()) ->
            error("non-primitive arrays are not supported for now, you can add it yourself")

        isSubtypeOf(typeOf<List<*>>()) -> error("lists are not supported for now, you can add it yourself")

        isSubtypeOf(typeOf<Set<*>>()) -> error("sets are not supported for now, you can add it yourself")

        classifier == Map::class -> {
            val (key, value) = arguments
            // FIX: an unmappable (or star-projected) key/value type used to be passed
            // straight into `DataTypes.createMapType` as `null`, producing an invalid
            // Spark MapType. Mirror the Spark->Kotlin direction instead and report
            // "no match" by returning null, which callers already handle.
            val keyType = key.type?.convertToSpark()
            val valueType = value.type?.convertToSpark()
            if (keyType == null || valueType == null) {
                null
            } else {
                DataTypes.createMapType(
                    keyType,
                    valueType,
                    value.type?.isMarkedNullable ?: true,
                )
            }
        }

        else -> null
    }
|
||||
|
||||
// The `Nothing?` KType cannot be obtained directly via `typeOf<Nothing?>()`, so it is
// extracted from a List's type argument instead. The `!!` is safe: `List<Nothing?>`
// has a concrete (non-star) type argument, so `type` is never null here.
private val nullableNothingType: KType = typeOf<List<Nothing?>>().arguments.first().type!!
|
||||
|
||||
// endregion
|
||||
+105
@@ -0,0 +1,105 @@
|
||||
@file:Suppress("ktlint:standard:function-signature")
|
||||
|
||||
package org.jetbrains.kotlinx.dataframe.examples.spark
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.api.java.JavaSparkContext
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.sql.Encoder
|
||||
import org.apache.spark.sql.Encoders
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
|
||||
import org.jetbrains.kotlinx.dataframe.api.aggregate
|
||||
import org.jetbrains.kotlinx.dataframe.api.groupBy
|
||||
import org.jetbrains.kotlinx.dataframe.api.max
|
||||
import org.jetbrains.kotlinx.dataframe.api.mean
|
||||
import org.jetbrains.kotlinx.dataframe.api.min
|
||||
import org.jetbrains.kotlinx.dataframe.api.print
|
||||
import org.jetbrains.kotlinx.dataframe.api.schema
|
||||
import org.jetbrains.kotlinx.dataframe.api.std
|
||||
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.api.toList
|
||||
import java.io.Serializable
|
||||
|
||||
/**
|
||||
* For Spark, Kotlin data classes are supported if we:
|
||||
* - Add [@JvmOverloads][JvmOverloads] to the constructor
|
||||
* - Make all parameter arguments mutable and with defaults
|
||||
* - Make them [Serializable]
|
||||
*
|
||||
* But by adding [@DataSchema][DataSchema] we can reuse the same class for Spark and DataFrame!
|
||||
*
|
||||
* See [Person] and [Name] for an example.
|
||||
*
|
||||
* Also, since we use an actual class to define the schema, we need no type conversion!
|
||||
*
|
||||
* NOTE: You will likely need to run this function with Java 8 or 11 for it to work correctly.
|
||||
* Use the `runSparkTypedDataset` Gradle task to do so.
|
||||
*/
|
||||
fun main() {
    val spark = SparkSession.builder()
        .master(SparkConf().get("spark.master", "local[*]"))
        .appName("Kotlin Spark Sample")
        .getOrCreate()
    // note: this typed example needs no JavaSparkContext, so none is created

    try {
        // Creating a Spark Dataset. Usually, this is loaded from some server or database.
        val rawDataset: Dataset<Person> = spark.createDataset(
            listOf(
                Person(Name("Alice", "Cooper"), 15, "London", 54, true),
                Person(Name("Bob", "Dylan"), 45, "Dubai", 87, true),
                Person(Name("Charlie", "Daniels"), 20, "Moscow", null, false),
                Person(Name("Charlie", "Chaplin"), 40, "Milan", null, true),
                Person(Name("Bob", "Marley"), 30, "Tokyo", 68, true),
                Person(Name("Alice", "Wolf"), 20, null, 55, false),
                Person(Name("Charlie", "Byrd"), 30, "Moscow", 90, true),
            ),
            beanEncoderOf(),
        )

        // we can perform large operations in Spark.
        // DataFrames are in-memory structures, so this is a good place to limit the number of rows if you don't have the RAM ;)
        val dataset = rawDataset.filter { it.age > 17 }

        // and convert it to DataFrame via a typed List
        val dataframe = dataset.collectAsList().toDataFrame()
        dataframe.schema().print()
        dataframe.print(columnTypes = true, borders = true)

        // now we can use DataFrame-specific functions
        val ageStats = dataframe
            .groupBy { city }.aggregate {
                mean { age } into "meanAge"
                std { age } into "stdAge"
                min { age } into "minAge"
                max { age } into "maxAge"
            }

        ageStats.print(columnTypes = true, borders = true)

        // and when we want to convert a DataFrame back to Spark, we can do the same trick via a typed List
        val sparkDatasetAgain = spark.createDataset(dataframe.toList(), beanEncoderOf())
        sparkDatasetAgain.printSchema()
        sparkDatasetAgain.show()
    } finally {
        // FIX: the session used to be stopped only on the happy path;
        // a finally block shuts Spark down even when the pipeline above throws.
        spark.stop()
    }
}
|
||||
|
||||
/** Reified shorthand for [Encoders.bean]: builds a bean [Encoder] for [T] without spelling out the class. */
inline fun <reified T : Serializable> beanEncoderOf(): Encoder<T> = Encoders.bean(T::class.java)
|
||||
|
||||
/**
 * Bean-style name holder: mutable properties with defaults plus a
 * [@JvmOverloads][JvmOverloads] constructor and [Serializable], as Spark's bean
 * encoder requires; [@DataSchema][DataSchema] makes it usable from DataFrame as well.
 */
@DataSchema
data class Name
@JvmOverloads
constructor(var firstName: String = "", var lastName: String = "") : Serializable

/**
 * Bean-style person record shared between Spark (bean encoder) and DataFrame
 * ([@DataSchema][DataSchema]). All properties are mutable with defaults so the
 * [@JvmOverloads][JvmOverloads] no-arg constructor exists for Spark.
 */
@DataSchema
data class Person
@JvmOverloads
constructor(
    var name: Name = Name(),
    var age: Int = -1,
    var city: String? = null,
    var weight: Int? = null,
    var isHappy: Boolean = false,
) : Serializable
|
||||
+87
@@ -0,0 +1,87 @@
|
||||
@file:Suppress("ktlint:standard:function-signature")
|
||||
|
||||
package org.jetbrains.kotlinx.dataframe.examples.spark
|
||||
|
||||
import org.apache.spark.SparkConf
|
||||
import org.apache.spark.api.java.JavaSparkContext
|
||||
import org.apache.spark.sql.Dataset
|
||||
import org.apache.spark.sql.Row
|
||||
import org.apache.spark.sql.SparkSession
|
||||
import org.jetbrains.kotlinx.dataframe.api.aggregate
|
||||
import org.jetbrains.kotlinx.dataframe.api.groupBy
|
||||
import org.jetbrains.kotlinx.dataframe.api.max
|
||||
import org.jetbrains.kotlinx.dataframe.api.mean
|
||||
import org.jetbrains.kotlinx.dataframe.api.min
|
||||
import org.jetbrains.kotlinx.dataframe.api.print
|
||||
import org.jetbrains.kotlinx.dataframe.api.schema
|
||||
import org.jetbrains.kotlinx.dataframe.api.std
|
||||
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToDataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToDataFrameByInference
|
||||
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToSpark
|
||||
import org.jetbrains.kotlinx.spark.api.col
|
||||
import org.jetbrains.kotlinx.spark.api.gt
|
||||
|
||||
/**
|
||||
* Since we don't know the schema at compile time this time, we need to do
|
||||
* some schema mapping in between Spark and DataFrame.
|
||||
*
|
||||
* We will use spark/compatibilityLayer.kt to do this.
|
||||
* Take a look at that file for the implementation details!
|
||||
*
|
||||
* NOTE: You will likely need to run this function with Java 8 or 11 for it to work correctly.
|
||||
* Use the `runSparkUntypedDataset` Gradle task to do so.
|
||||
*/
|
||||
fun main() {
    val spark = SparkSession.builder()
        .master(SparkConf().get("spark.master", "local[*]"))
        .appName("Kotlin Spark Sample")
        .getOrCreate()
    val sc = JavaSparkContext(spark.sparkContext())

    try {
        // Creating a Spark Dataframe (untyped Dataset). Usually, this is loaded from some server or database.
        val rawDataset: Dataset<Row> = spark.createDataset(
            listOf(
                Person(Name("Alice", "Cooper"), 15, "London", 54, true),
                Person(Name("Bob", "Dylan"), 45, "Dubai", 87, true),
                Person(Name("Charlie", "Daniels"), 20, "Moscow", null, false),
                Person(Name("Charlie", "Chaplin"), 40, "Milan", null, true),
                Person(Name("Bob", "Marley"), 30, "Tokyo", 68, true),
                Person(Name("Alice", "Wolf"), 20, null, 55, false),
                Person(Name("Charlie", "Byrd"), 30, "Moscow", 90, true),
            ),
            beanEncoderOf<Person>(),
        ).toDF()

        // we can perform large operations in Spark.
        // DataFrames are in-memory structures, so this is a good place to limit the number of rows if you don't have the RAM ;)
        val dataset = rawDataset.filter(col("age") gt 17)

        // Using inference
        val df1 = dataset.convertToDataFrameByInference()
        df1.schema().print()
        df1.print(columnTypes = true, borders = true)

        // Using full schema mapping
        val df2 = dataset.convertToDataFrame()
        df2.schema().print()
        df2.print(columnTypes = true, borders = true)

        // now we can use DataFrame-specific functions
        val ageStats = df1
            .groupBy("city").aggregate {
                mean("age") into "meanAge"
                std("age") into "stdAge"
                min("age") into "minAge"
                max("age") into "maxAge"
            }

        ageStats.print(columnTypes = true, borders = true)

        // and when we want to convert a DataFrame back to Spark, we will use the `convertToSpark()` extension function
        // This performs the necessary schema mapping under the hood.
        val sparkDataset = df2.convertToSpark(spark, sc)
        sparkDataset.printSchema()
        sparkDataset.show()
    } finally {
        // FIX: the session used to be stopped only on the happy path;
        // a finally block shuts Spark down even when the pipeline above throws.
        spark.stop()
    }
}
|
||||
Reference in New Issue
Block a user