init research

This commit is contained in:
2026-02-08 11:20:43 -10:00
commit bdf064f54d
3041 changed files with 1592200 additions and 0 deletions
@@ -0,0 +1,43 @@
import org.jetbrains.kotlin.gradle.dsl.JvmTarget

// Build script for the Exposed + SQLite + DataFrame example module.
plugins {
    application
    kotlin("jvm")
    // uses the 'old' Gradle plugin instead of the compiler plugin for now
    id("org.jetbrains.kotlinx.dataframe")
    // only mandatory if `kotlin.dataframe.add.ksp=false` in gradle.properties
    id("com.google.devtools.ksp")
}

repositories {
    mavenLocal() // in case of local dataframe development
    mavenCentral()
}

dependencies {
    // implementation("org.jetbrains.kotlinx:dataframe:X.Y.Z")
    implementation(project(":"))

    // exposed + sqlite database support
    implementation(libs.sqlite)
    implementation(libs.exposed.core)
    implementation(libs.exposed.kotlin.datetime)
    implementation(libs.exposed.jdbc)
    implementation(libs.exposed.json)
    implementation(libs.exposed.money)
}

// Emit JVM 8 bytecode from both the Kotlin and Java compilers so the example
// stays compatible with the oldest supported JDK.
kotlin {
    compilerOptions {
        jvmTarget = JvmTarget.JVM_1_8
        freeCompilerArgs.add("-Xjdk-release=8")
    }
}

tasks.withType<JavaCompile> {
    sourceCompatibility = JavaVersion.VERSION_1_8.toString()
    targetCompatibility = JavaVersion.VERSION_1_8.toString()
    options.release.set(8)
}
@@ -0,0 +1,107 @@
package org.jetbrains.kotlinx.dataframe.examples.exposed
import org.jetbrains.exposed.v1.core.BiCompositeColumn
import org.jetbrains.exposed.v1.core.Column
import org.jetbrains.exposed.v1.core.Expression
import org.jetbrains.exposed.v1.core.ExpressionAlias
import org.jetbrains.exposed.v1.core.ResultRow
import org.jetbrains.exposed.v1.core.Table
import org.jetbrains.exposed.v1.jdbc.Query
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.convertTo
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.codeGen.NameNormalizer
import org.jetbrains.kotlinx.dataframe.impl.schema.DataFrameSchemaImpl
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
import kotlin.reflect.KProperty1
import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.full.memberProperties
import kotlin.reflect.typeOf
/**
 * Retrieves all columns of any [Iterable][Iterable]`<`[ResultRow][ResultRow]`>`, like [Query][Query],
 * from Exposed row by row and converts the resulting [Map] into a [DataFrame], cast to type [T].
 *
 * Delegates to the untyped overload below, then converts the frame to the typed schema [T].
 *
 * In notebooks, the untyped version works just as well due to runtime inference :)
 */
inline fun <reified T : Any> Iterable<ResultRow>.convertToDataFrame(): DataFrame<T> =
    convertToDataFrame().convertTo<T>()
/**
 * Retrieves all columns of an [Iterable][Iterable]`<`[ResultRow][ResultRow]`>` from Exposed, like [Query][Query],
 * row by row, accumulating each expression's values into a per-column list,
 * and finally turns the resulting [Map] of lists into a [DataFrame] via [Map.toDataFrame].
 */
@JvmName("convertToAnyFrame")
fun Iterable<ResultRow>.convertToDataFrame(): AnyFrame {
    // column name -> all values seen for that column, in row order
    val columnValues = mutableMapOf<String, MutableList<Any?>>()
    forEach { row ->
        row.fieldIndex.keys.forEach { expression ->
            columnValues
                .getOrPut(expression.readableName, ::mutableListOf)
                .add(row[expression])
        }
    }
    return columnValues.toDataFrame()
}
/**
 * Derives a simple, human-readable column name from [this] [Expression].
 *
 * Plain columns use their SQL name, aliases use the alias, and composite
 * columns join the names of their underlying real columns with `_`.
 * Anything else falls back to the expression's string representation.
 *
 * Might need to be expanded with more [Expression] subtypes over time.
 */
val Expression<*>.readableName: String
    get() {
        return when (this) {
            is Column<*> -> this.name
            is ExpressionAlias<*> -> this.alias
            is BiCompositeColumn<*, *, *> ->
                getRealColumns().joinToString(separator = "_") { part -> part.readableName }
            else -> this.toString()
        }
    }
/**
 * Creates a [DataFrameSchema] from the declared [Table] instance.
 *
 * This is not needed for conversion, but it can be useful to create a DataFrame [@DataSchema][DataSchema] instance.
 *
 * @param columnNameToAccessor Optional [MutableMap] which will be filled with entries mapping
 *   the SQL column name to the accessor name from the [Table].
 *   This can be used to define a [NameNormalizer] later.
 * @see toDataFrameSchemaWithNameNormalizer
 */
@Suppress("UNCHECKED_CAST")
fun Table.toDataFrameSchema(columnNameToAccessor: MutableMap<String, String> = mutableMapOf()): DataFrameSchema {
    // we use reflection to go over all `Column<*>` properties in the Table object
    val columns = this::class.memberProperties
        .filter { it.returnType.isSubtypeOf(typeOf<Column<*>>()) }
        .associate { prop ->
            prop as KProperty1<Table, Column<*>>

            // retrieve the SQL column name
            val columnName = prop.get(this).name
            // store the SQL column name together with the accessor name in the map
            columnNameToAccessor[columnName] = prop.name

            // get the column type from `val a: Column<Type>`
            val type = prop.returnType.arguments.first().type!!

            // and we add the name and column schema type to the `columns` map :)
            columnName to ColumnSchema.Value(type)
        }
    return DataFrameSchemaImpl(columns)
}
/**
 * Creates a [DataFrameSchema] from the declared [Table] instance with a [NameNormalizer] to
 * convert the SQL column names to the corresponding Kotlin property names.
 *
 * This is not needed for conversion, but it can be useful to create a DataFrame [@DataSchema][DataSchema] instance.
 *
 * @return a [Pair] of the schema and a [NameNormalizer] mapping SQL column names to Table accessor names.
 * @see toDataFrameSchema
 */
fun Table.toDataFrameSchemaWithNameNormalizer(): Pair<DataFrameSchema, NameNormalizer> {
    val columnNameToAccessor = mutableMapOf<String, String>()
    // Bug fix: the map must be passed to toDataFrameSchema(), which fills it.
    // Previously it was called without arguments, so the map stayed empty and the
    // NameNormalizer always fell back to the raw SQL column name.
    val schema = toDataFrameSchema(columnNameToAccessor)
    return Pair(schema, NameNormalizer { columnNameToAccessor[it] ?: it })
}
@@ -0,0 +1,96 @@
package org.jetbrains.kotlinx.dataframe.examples.exposed
import org.jetbrains.exposed.v1.core.Column
import org.jetbrains.exposed.v1.core.SortOrder
import org.jetbrains.exposed.v1.core.count
import org.jetbrains.exposed.v1.jdbc.Database
import org.jetbrains.exposed.v1.jdbc.SchemaUtils
import org.jetbrains.exposed.v1.jdbc.batchInsert
import org.jetbrains.exposed.v1.jdbc.deleteAll
import org.jetbrains.exposed.v1.jdbc.select
import org.jetbrains.exposed.v1.jdbc.selectAll
import org.jetbrains.exposed.v1.jdbc.transactions.transaction
import org.jetbrains.kotlinx.dataframe.api.asSequence
import org.jetbrains.kotlinx.dataframe.api.count
import org.jetbrains.kotlinx.dataframe.api.describe
import org.jetbrains.kotlinx.dataframe.api.groupBy
import org.jetbrains.kotlinx.dataframe.api.print
import org.jetbrains.kotlinx.dataframe.api.sortByDesc
import org.jetbrains.kotlinx.dataframe.size
import java.io.File
/**
 * Describes a simple bridge between [Exposed](https://www.jetbrains.com/exposed/) and DataFrame!
 *
 * Reads the bundled Chinook SQLite database through Exposed, converts a query to a
 * typed DataFrame, runs some DataFrame analytics, and writes the rows back.
 */
fun main() {
    // defining where to find our SQLite database for Exposed
    val resourceDb = "chinook.db"
    // resolve the DB resource on the classpath to an absolute file path for the JDBC URL
    val dbPath = File(object {}.javaClass.classLoader.getResource(resourceDb)!!.toURI()).absolutePath
    val db = Database.connect(url = "jdbc:sqlite:$dbPath", driver = "org.sqlite.JDBC")

    // let's read the database!
    val df = transaction(db) {
        // addLogger(StdOutSqlLogger) // enable if you want to see verbose logs

        // tables in Exposed need to be defined, see tables.kt
        SchemaUtils.create(Customers, Artists, Albums)
        println()

        // In Exposed, we can write queries like this.
        // Here, we count per country how many customers there are and print the results:
        Customers
            .select(Customers.country, Customers.customerId.count())
            .groupBy(Customers.country)
            .orderBy(Customers.customerId.count() to SortOrder.DESC)
            .forEach {
                println("${it[Customers.country]}: ${it[Customers.customerId.count()]} customers")
            }
        println()

        // Perform the specific query you want to read into the DataFrame.
        // Note: DataFrames are in-memory structures, so don't make it too large if you don't have the RAM ;)
        val query = Customers.selectAll() // .where { Customers.company.isNotNull() }
        println()

        // read and convert the query to a typed DataFrame
        // see compatibilityLayer.kt for how we created convertToDataFrame<>()
        // and see tables.kt for how we created DfCustomers!
        query.convertToDataFrame<DfCustomers>()
    }
    println(df.size())

    // now we have a DataFrame, we can perform DataFrame operations,
    // like doing the same operation as we did in Exposed above
    df.groupBy { country }.count()
        .sortByDesc { "count"<Int>() }
        .print(columnTypes = true, borders = true)

    // or just general statistics
    df.describe()
        .print(columnTypes = true, borders = true)

    // or make plots using Kandy! It's all up to you

    // writing a DataFrame back into an SQL database with Exposed can also be done easily!
    transaction(db) {
        // addLogger(StdOutSqlLogger) // enable if you want to see verbose logs

        // first delete the original contents
        Customers.deleteAll()
        println()

        // batch-insert our dataframe back into the SQL database as a sequence of rows
        Customers.batchInsert(df.asSequence()) { dfRow ->
            // we simply go over each value in the row and put it in the right place in the Exposed statement
            for (column in Customers.columns) {
                @Suppress("UNCHECKED_CAST")
                this[column as Column<Any?>] = dfRow[column.name]
            }
        }
    }
}
@@ -0,0 +1,97 @@
package org.jetbrains.kotlinx.dataframe.examples.exposed
import org.jetbrains.exposed.v1.core.Column
import org.jetbrains.exposed.v1.core.Table
import org.jetbrains.kotlinx.dataframe.annotations.ColumnName
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.generateDataClasses
import org.jetbrains.kotlinx.dataframe.api.print
// Exposed table definition mirroring the Chinook `Albums` table.
object Albums : Table() {
    val albumId: Column<Int> = integer("AlbumId").autoIncrement()
    val title: Column<String> = varchar("Title", 160)
    val artistId: Column<Int> = integer("ArtistId") // FK to Artists.artistId (not enforced here)
    override val primaryKey = PrimaryKey(albumId)
}
// Exposed table definition mirroring the Chinook `Artists` table.
object Artists : Table() {
    val artistId: Column<Int> = integer("ArtistId").autoIncrement()
    val name: Column<String> = varchar("Name", 120)
    override val primaryKey = PrimaryKey(artistId)
}
// Exposed table definition mirroring the Chinook `Customers` table.
// Nullable varchar columns correspond to optional fields in the source schema.
object Customers : Table() {
    val customerId: Column<Int> = integer("CustomerId").autoIncrement()
    val firstName: Column<String> = varchar("FirstName", 40)
    val lastName: Column<String> = varchar("LastName", 20)
    val company: Column<String?> = varchar("Company", 80).nullable()
    val address: Column<String?> = varchar("Address", 70).nullable()
    val city: Column<String?> = varchar("City", 40).nullable()
    val state: Column<String?> = varchar("State", 40).nullable()
    val country: Column<String?> = varchar("Country", 40).nullable()
    val postalCode: Column<String?> = varchar("PostalCode", 10).nullable()
    val phone: Column<String?> = varchar("Phone", 24).nullable()
    val fax: Column<String?> = varchar("Fax", 24).nullable()
    val email: Column<String> = varchar("Email", 60)
    val supportRepId: Column<Int?> = integer("SupportRepId").nullable()
    override val primaryKey = PrimaryKey(customerId)
}
/**
 * Exposed requires you to provide [Table] instances to
 * provide type-safe access to your columns and data.
 *
 * While DataFrame can infer types at runtime, which is enough for Kotlin Notebook,
 * to get type safe access at compile time, we need to define a [@DataSchema][DataSchema].
 *
 * This is what we created the [toDataFrameSchema] function for!
 */
fun main() {
    val (schema, nameNormalizer) = Customers.toDataFrameSchemaWithNameNormalizer()

    // checking whether the schema is converted correctly.
    // schema.print()

    // printing a @DataSchema data class to copy-paste into the code.
    // we use a NameNormalizer to let DataFrame generate the same accessors as in the Table
    // while keeping the correct column names
    schema.generateDataClasses(
        markerName = "DfCustomers",
        nameNormalizer = nameNormalizer,
    ).print()
}
// created by Customers.toDataFrameSchema()
// The same can be done for the other tables.
// @ColumnName keeps the original SQL column name while exposing camelCase accessors.
@DataSchema
data class DfCustomers(
    @ColumnName("Address")
    val address: String?,
    @ColumnName("City")
    val city: String?,
    @ColumnName("Company")
    val company: String?,
    @ColumnName("Country")
    val country: String?,
    @ColumnName("CustomerId")
    val customerId: Int,
    @ColumnName("Email")
    val email: String,
    @ColumnName("Fax")
    val fax: String?,
    @ColumnName("FirstName")
    val firstName: String,
    @ColumnName("LastName")
    val lastName: String,
    @ColumnName("Phone")
    val phone: String?,
    @ColumnName("PostalCode")
    val postalCode: String?,
    @ColumnName("State")
    val state: String?,
    @ColumnName("SupportRepId")
    val supportRepId: Int?,
)
@@ -0,0 +1,43 @@
import org.jetbrains.kotlin.gradle.dsl.JvmTarget

// Build script for the Hibernate + H2 + DataFrame example module.
plugins {
    application
    kotlin("jvm")
    // uses the 'old' Gradle plugin instead of the compiler plugin for now
    id("org.jetbrains.kotlinx.dataframe")
    // only mandatory if `kotlin.dataframe.add.ksp=false` in gradle.properties
    id("com.google.devtools.ksp")
}

repositories {
    mavenLocal() // in case of local dataframe development
    mavenCentral()
}

dependencies {
    // implementation("org.jetbrains.kotlinx:dataframe:X.Y.Z")
    implementation(project(":"))

    // Hibernate + H2 + HikariCP (for Hibernate example)
    implementation(libs.hibernate.core)
    implementation(libs.hibernate.hikaricp)
    implementation(libs.hikaricp)
    implementation(libs.h2db)
    implementation(libs.sl4jsimple)
}

// Hibernate requires a newer JVM than the other examples; target JVM 11.
kotlin {
    compilerOptions {
        jvmTarget = JvmTarget.JVM_11
        freeCompilerArgs.add("-Xjdk-release=11")
    }
}

tasks.withType<JavaCompile> {
    sourceCompatibility = JavaVersion.VERSION_11.toString()
    targetCompatibility = JavaVersion.VERSION_11.toString()
    options.release.set(11)
}
@@ -0,0 +1,100 @@
package org.jetbrains.kotlinx.dataframe.examples.hibernate
import jakarta.persistence.Column
import jakarta.persistence.Entity
import jakarta.persistence.GeneratedValue
import jakarta.persistence.GenerationType
import jakarta.persistence.Id
import jakarta.persistence.Table
import org.jetbrains.kotlinx.dataframe.annotations.ColumnName
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
// JPA entity for the `Albums` table. Mutable vars with defaults are required
// so Hibernate can instantiate and populate the entity.
@Entity
@Table(name = "Albums")
class AlbumsEntity(
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    @Column(name = "AlbumId")
    var albumId: Int? = null, // null until the database assigns an id
    @Column(name = "Title", length = 160, nullable = false)
    var title: String = "",
    @Column(name = "ArtistId", nullable = false)
    var artistId: Int = 0,
)
// JPA entity for the `Artists` table.
@Entity
@Table(name = "Artists")
class ArtistsEntity(
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    @Column(name = "ArtistId")
    var artistId: Int? = null, // null until the database assigns an id
    @Column(name = "Name", length = 120, nullable = false)
    var name: String = "",
)
// JPA entity for the `Customers` table; nullable vars mirror optional columns.
@Entity
@Table(name = "Customers")
class CustomersEntity(
    @Id
    @GeneratedValue(strategy = GenerationType.IDENTITY)
    @Column(name = "CustomerId")
    var customerId: Int? = null, // null until the database assigns an id
    @Column(name = "FirstName", length = 40, nullable = false)
    var firstName: String = "",
    @Column(name = "LastName", length = 20, nullable = false)
    var lastName: String = "",
    @Column(name = "Company", length = 80)
    var company: String? = null,
    @Column(name = "Address", length = 70)
    var address: String? = null,
    @Column(name = "City", length = 40)
    var city: String? = null,
    @Column(name = "State", length = 40)
    var state: String? = null,
    @Column(name = "Country", length = 40)
    var country: String? = null,
    @Column(name = "PostalCode", length = 10)
    var postalCode: String? = null,
    @Column(name = "Phone", length = 24)
    var phone: String? = null,
    @Column(name = "Fax", length = 24)
    var fax: String? = null,
    @Column(name = "Email", length = 60, nullable = false)
    var email: String = "",
    @Column(name = "SupportRepId")
    var supportRepId: Int? = null,
)
// DataFrame schema to get typed accessors similar to the Exposed example.
// @ColumnName keeps the original SQL column name while exposing camelCase accessors.
@DataSchema
data class DfCustomers(
    @ColumnName("Address") val address: String?,
    @ColumnName("City") val city: String?,
    @ColumnName("Company") val company: String?,
    @ColumnName("Country") val country: String?,
    @ColumnName("CustomerId") val customerId: Int,
    @ColumnName("Email") val email: String,
    @ColumnName("Fax") val fax: String?,
    @ColumnName("FirstName") val firstName: String,
    @ColumnName("LastName") val lastName: String,
    @ColumnName("Phone") val phone: String?,
    @ColumnName("PostalCode") val postalCode: String?,
    @ColumnName("State") val state: String?,
    @ColumnName("SupportRepId") val supportRepId: Int?,
)
@@ -0,0 +1,251 @@
package org.jetbrains.kotlinx.dataframe.examples.hibernate
import jakarta.persistence.Tuple
import jakarta.persistence.criteria.CriteriaBuilder
import jakarta.persistence.criteria.CriteriaDelete
import jakarta.persistence.criteria.CriteriaQuery
import jakarta.persistence.criteria.Expression
import jakarta.persistence.criteria.Root
import org.hibernate.FlushMode
import org.hibernate.SessionFactory
import org.hibernate.cfg.Configuration
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.api.asSequence
import org.jetbrains.kotlinx.dataframe.api.count
import org.jetbrains.kotlinx.dataframe.api.describe
import org.jetbrains.kotlinx.dataframe.api.groupBy
import org.jetbrains.kotlinx.dataframe.api.print
import org.jetbrains.kotlinx.dataframe.api.sortByDesc
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.size
/**
 * Example showing Kotlin DataFrame with Hibernate ORM + H2 in-memory DB.
 * Mirrors logic from the Exposed example: load data, convert to DataFrame, group/describe, write back.
 */
fun main() {
    val sessionFactory: SessionFactory = buildSessionFactory()
    // seed the in-memory database so there is something to analyze
    sessionFactory.insertSampleData()
    val df = sessionFactory.loadCustomersAsDataFrame()

    // Pure Hibernate + Criteria API approach for counting customers per country
    println("=== Hibernate + Criteria API Approach ===")
    sessionFactory.countCustomersPerCountryWithHibernate()

    println("\n=== DataFrame Approach ===")
    df.analyzeAndPrintResults()

    // round-trip: write the DataFrame back into the database
    sessionFactory.replaceCustomersFromDataFrame(df)
    sessionFactory.close()
}
/** Seeds the in-memory database with a few artists, albums, and customers. */
private fun SessionFactory.insertSampleData() {
    withTransaction { session ->
        // a few artists and albums (minimal, not used further; just demo schema)
        val artist1 = ArtistsEntity(name = "AC/DC")
        val artist2 = ArtistsEntity(name = "Queen")
        session.persist(artist1)
        session.persist(artist2)
        // flush so the generated artist ids are available for the albums below
        session.flush()
        session.persist(AlbumsEntity(title = "High Voltage", artistId = artist1.artistId!!))
        session.persist(AlbumsEntity(title = "Back in Black", artistId = artist1.artistId!!))
        session.persist(AlbumsEntity(title = "A Night at the Opera", artistId = artist2.artistId!!))

        // customers we'll analyze using DataFrame
        session.persist(
            CustomersEntity(
                firstName = "John",
                lastName = "Doe",
                email = "john.doe@example.com",
                country = "USA",
            ),
        )
        session.persist(
            CustomersEntity(
                firstName = "Jane",
                lastName = "Smith",
                email = "jane.smith@example.com",
                country = "USA",
            ),
        )
        session.persist(
            CustomersEntity(
                firstName = "Alice",
                lastName = "Wang",
                email = "alice.wang@example.com",
                country = "Canada",
            ),
        )
    }
}
/**
 * Loads every customer via a Criteria "select all" query and converts the
 * resulting entities into a typed [DataFrame] of [DfCustomers] rows.
 * Missing ids are mapped to -1 since [DfCustomers.customerId] is non-null.
 */
private fun SessionFactory.loadCustomersAsDataFrame(): DataFrame<DfCustomers> =
    withReadOnlyTransaction { session ->
        val builder: CriteriaBuilder = session.criteriaBuilder
        val query: CriteriaQuery<CustomersEntity> = builder.createQuery(CustomersEntity::class.java)
        val root: Root<CustomersEntity> = query.from(CustomersEntity::class.java)
        query.select(root)

        val entities = session.createQuery(query).resultList
        entities
            .map { entity ->
                DfCustomers(
                    address = entity.address,
                    city = entity.city,
                    company = entity.company,
                    country = entity.country,
                    customerId = entity.customerId ?: -1,
                    email = entity.email,
                    fax = entity.fax,
                    firstName = entity.firstName,
                    lastName = entity.lastName,
                    phone = entity.phone,
                    postalCode = entity.postalCode,
                    state = entity.state,
                    supportRepId = entity.supportRepId,
                )
            }
            .toDataFrame()
    }
/** DTO used for aggregation projection (constructed via CriteriaBuilder.construct). */
private data class CountryCountDto(
    val country: String,
    val customerCount: Long,
)
/**
 * Counts customers per country entirely in the database and prints the results,
 * ordered by count descending.
 *
 * **Hibernate + Criteria API:**
 * - ✅ Database-level aggregation (efficient)
 * - ✅ Type-safe queries
 * - ❌ Verbose syntax
 * - ❌ Limited to SQL-like operations
 */
private fun SessionFactory.countCustomersPerCountryWithHibernate() {
    withReadOnlyTransaction { session ->
        val cb = session.criteriaBuilder
        val cq: CriteriaQuery<CountryCountDto> = cb.createQuery(CountryCountDto::class.java)
        val root: Root<CustomersEntity> = cq.from(CustomersEntity::class.java)

        val countryPath = root.get<String>("country")
        val idPath = root.get<Long>("customerId")
        val countExpr = cb.count(idPath)

        // project directly into the DTO: SELECT country, COUNT(customerId)
        cq.select(
            cb.construct(
                CountryCountDto::class.java,
                countryPath, // country
                countExpr, // customerCount
            ),
        )
        cq.groupBy(countryPath)
        cq.orderBy(cb.desc(countExpr))

        val results = session.createQuery(cq).resultList
        results.forEach { dto ->
            println("${dto.country}: ${dto.customerCount} customers")
        }
    }
}
/**
 * Prints the frame size, the per-country customer counts (descending),
 * and general statistics for the whole frame.
 *
 * **DataFrame approach: **
 * - ✅ Rich analytical operations
 * - ✅ Fluent, readable API
 * - ✅ Flexible data transformations
 * - ❌ In-memory processing (less efficient for large datasets)
 */
private fun DataFrame<DfCustomers>.analyzeAndPrintResults() {
    println(size())

    // same operation as Exposed example: customers per country
    groupBy { country }.count()
        .sortByDesc { "count"<Int>() }
        .print(columnTypes = true, borders = true)

    // general statistics
    describe()
        .print(columnTypes = true, borders = true)
}
/**
 * Replaces the contents of the Customers table with the rows of [df]:
 * first a bulk criteria delete, then one persist per DataFrame row.
 * Runs as two separate transactions (delete, then insert).
 */
private fun SessionFactory.replaceCustomersFromDataFrame(df: DataFrame<DfCustomers>) {
    withTransaction { session ->
        val criteriaBuilder: CriteriaBuilder = session.criteriaBuilder
        val criteriaDelete: CriteriaDelete<CustomersEntity> =
            criteriaBuilder.createCriteriaDelete(CustomersEntity::class.java)
        criteriaDelete.from(CustomersEntity::class.java)
        session.createMutationQuery(criteriaDelete).executeUpdate()
    }
    withTransaction { session ->
        df.asSequence().forEach { row ->
            session.persist(row.toCustomersEntity())
        }
    }
}
/**
 * Maps a single DataFrame row back onto a fresh [CustomersEntity].
 * The primary key is left null so the database generates a new id on insert.
 */
private fun DataRow<DfCustomers>.toCustomersEntity(): CustomersEntity =
    CustomersEntity(
        customerId = null, // let DB generate
        firstName = firstName,
        lastName = lastName,
        company = company,
        address = address,
        city = city,
        state = state,
        country = country,
        postalCode = postalCode,
        phone = phone,
        fax = fax,
        email = email,
        supportRepId = supportRepId,
    )
/** Opens a session, runs [block], and always closes the session via [use]. */
private inline fun <T> SessionFactory.withSession(block: (session: org.hibernate.Session) -> T): T {
    return openSession().use(block)
}
/**
 * Runs [block] inside a transaction: commits on success,
 * rolls back and rethrows on any exception.
 */
private inline fun SessionFactory.withTransaction(block: (session: org.hibernate.Session) -> Unit) {
    withSession { session ->
        session.beginTransaction()
        try {
            block(session)
            session.transaction.commit()
        } catch (e: Exception) {
            session.transaction.rollback()
            throw e
        }
    }
}
/**
 * Read-only transaction helper for SELECT queries to minimize overhead.
 * Commits on success; rolls back and rethrows on any exception.
 */
private inline fun <T> SessionFactory.withReadOnlyTransaction(block: (session: org.hibernate.Session) -> T): T {
    return withSession { session ->
        session.beginTransaction()
        // Minimize overhead for read operations: no dirty checking, no auto-flush
        session.isDefaultReadOnly = true
        session.hibernateFlushMode = FlushMode.MANUAL
        try {
            val result = block(session)
            session.transaction.commit()
            result
        } catch (e: Exception) {
            session.transaction.rollback()
            throw e
        }
    }
}
/** Builds the [SessionFactory] from the XML configuration on the classpath. */
private fun buildSessionFactory(): SessionFactory {
    // Load configuration from resources/hibernate/hibernate.cfg.xml
    return Configuration().configure("hibernate/hibernate.cfg.xml").buildSessionFactory()
}
@@ -0,0 +1,32 @@
<?xml version="1.0" encoding="utf-8"?>
<!DOCTYPE hibernate-configuration PUBLIC
        "-//Hibernate/Hibernate Configuration DTD 5.3//EN"
        "http://hibernate.org/dtd/hibernate-configuration-5.3.dtd">
<!-- Hibernate configuration for the DataFrame example:
     H2 in-memory database, HikariCP pooling, schema created and dropped per run. -->
<hibernate-configuration>
    <session-factory>
        <!-- H2 in-memory; DB_CLOSE_DELAY=-1 keeps the DB alive between connections -->
        <property name="hibernate.connection.driver_class">org.h2.Driver</property>
        <property name="hibernate.connection.url">jdbc:h2:mem:testdb;DB_CLOSE_DELAY=-1</property>
        <property name="hibernate.connection.username">sa</property>
        <property name="hibernate.connection.password"></property>

        <!-- Connection pool: HikariCP via Hibernate integration -->
        <property name="hibernate.connection.provider_class">org.hibernate.hikaricp.internal.HikariCPConnectionProvider</property>
        <property name="hibernate.hikari.maximumPoolSize">5</property>

        <!-- Hibernate Dialect -->
        <property name="hibernate.dialect">org.hibernate.dialect.H2Dialect</property>

        <!-- Automatic schema generation -->
        <property name="hibernate.hbm2ddl.auto">create-drop</property>

        <!-- Logging -->
        <property name="hibernate.show_sql">true</property>
        <property name="hibernate.format_sql">true</property>

        <!-- Mappings -->
        <mapping class="org.jetbrains.kotlinx.dataframe.examples.hibernate.CustomersEntity"/>
        <mapping class="org.jetbrains.kotlinx.dataframe.examples.hibernate.ArtistsEntity"/>
        <mapping class="org.jetbrains.kotlinx.dataframe.examples.hibernate.AlbumsEntity"/>
    </session-factory>
</hibernate-configuration>
@@ -0,0 +1,39 @@
import org.jetbrains.kotlin.gradle.dsl.JvmTarget

// Build script for the multik + DataFrame example module.
plugins {
    application
    kotlin("jvm")
    // uses the 'old' Gradle plugin instead of the compiler plugin for now
    id("org.jetbrains.kotlinx.dataframe")
    // only mandatory if `kotlin.dataframe.add.ksp=false` in gradle.properties
    id("com.google.devtools.ksp")
}

repositories {
    mavenLocal() // in case of local dataframe development
    mavenCentral()
}

dependencies {
    // implementation("org.jetbrains.kotlinx:dataframe:X.Y.Z")
    implementation(project(":"))

    // multik support
    implementation(libs.multik.core)
    implementation(libs.multik.default)
}

// Emit JVM 8 bytecode from both the Kotlin and Java compilers.
kotlin {
    compilerOptions {
        jvmTarget = JvmTarget.JVM_1_8
        freeCompilerArgs.add("-Xjdk-release=8")
    }
}

tasks.withType<JavaCompile> {
    sourceCompatibility = JavaVersion.VERSION_1_8.toString()
    targetCompatibility = JavaVersion.VERSION_1_8.toString()
    options.release.set(8)
}
@@ -0,0 +1,374 @@
@file:OptIn(ExperimentalTypeInference::class)
package org.jetbrains.kotlinx.dataframe.examples.multik
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.ColumnSelector
import org.jetbrains.kotlinx.dataframe.ColumnsSelector
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.api.ValueProperty
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.colsOf
import org.jetbrains.kotlinx.dataframe.api.column
import org.jetbrains.kotlinx.dataframe.api.dataFrameOf
import org.jetbrains.kotlinx.dataframe.api.getColumn
import org.jetbrains.kotlinx.dataframe.api.getColumns
import org.jetbrains.kotlinx.dataframe.api.map
import org.jetbrains.kotlinx.dataframe.api.named
import org.jetbrains.kotlinx.dataframe.api.toColumn
import org.jetbrains.kotlinx.dataframe.api.toColumnGroup
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.columns.BaseColumn
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.ndarray
import org.jetbrains.kotlinx.multik.ndarray.complex.Complex
import org.jetbrains.kotlinx.multik.ndarray.data.D1Array
import org.jetbrains.kotlinx.multik.ndarray.data.D2Array
import org.jetbrains.kotlinx.multik.ndarray.data.D3Array
import org.jetbrains.kotlinx.multik.ndarray.data.MultiArray
import org.jetbrains.kotlinx.multik.ndarray.data.NDArray
import org.jetbrains.kotlinx.multik.ndarray.data.get
import org.jetbrains.kotlinx.multik.ndarray.operations.toList
import org.jetbrains.kotlinx.multik.ndarray.operations.toListD2
import kotlin.experimental.ExperimentalTypeInference
import kotlin.reflect.KClass
import kotlin.reflect.KType
import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.typeOf
// region 1D

/**
 * Converts a one-dimensional array ([D1Array]) to a [DataColumn] with optional [name].
 *
 * The array is materialized as a typed [List] first; thanks to the reified
 * type parameter, DataFrame performs no runtime type inference.
 */
inline fun <reified N> D1Array<N>.convertToColumn(name: String = ""): DataColumn<N> =
    column<N>(toList()) named name
/**
 * Converts a one-dimensional array ([D1Array]) of type [N] into a DataFrame.
 * Each element of the array becomes a row in a single column named "value".
 *
 * @return a DataFrame with one "value" column, typed with the [ValueProperty] schema.
 */
@JvmName("convert1dArrayToDataFrame")
inline fun <reified N> D1Array<N>.convertToDataFrame(): DataFrame<ValueProperty<N>> {
    // name the column after ValueProperty.value so the cast below is valid
    val valueColumn = convertToColumn(name = ValueProperty<*>::value.name)
    return dataFrameOf(valueColumn).cast<ValueProperty<N>>()
}
/** Converts a numeric [DataColumn] to a one-dimensional array ([D1Array]). */
@JvmName("convertNumberColumnToMultik")
inline fun <reified N> DataColumn<N>.convertToMultik(): D1Array<N> where N : Number, N : Comparable<N> {
    // we can convert our column to a typed list again to convert it to a multik array
    val values = this.toList()
    return mk.ndarray(values)
}
/** Converts a [Complex]-valued [DataColumn] to a one-dimensional array ([D1Array]). */
@JvmName("convertComplexColumnToMultik")
inline fun <reified N : Complex> DataColumn<N>.convertToMultik(): D1Array<N> {
    // we can convert our column to a typed list again to convert it to a multik array
    val values = this.toList()
    return mk.ndarray(values)
}
/**
 * Converts a numeric [DataColumn] selected by [column] to a one-dimensional array ([D1Array]).
 * The selector picks a single column of this DataFrame; the column is then converted
 * via the [DataColumn.convertToMultik] overload.
 */
@JvmName("convertNumberColumnFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N> DataFrame<T>.convertToMultik(
    crossinline column: ColumnSelector<T, N>,
): D1Array<N>
    where N : Number, N : Comparable<N> = getColumn { column(it) }.convertToMultik()
/**
 * Converts a [Complex]-valued [DataColumn] selected by [column] to a one-dimensional
 * array ([D1Array]), via the [DataColumn.convertToMultik] overload.
 */
@JvmName("convertComplexColumnFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N : Complex> DataFrame<T>.convertToMultik(crossinline column: ColumnSelector<T, N>): D1Array<N> =
    getColumn { column(it) }.convertToMultik()
// endregion
// region 2D

/**
 * Converts a two-dimensional array ([D2Array]) to a DataFrame.
 * It will contain `shape[0]` rows and `shape[1]` columns.
 *
 * Column names can be specified using the [columnNameGenerator] lambda.
 *
 * The conversion enforces that `multikArray[x][y] == dataframe[x][y]`
 */
@JvmName("convert2dArrayToDataFrame")
inline fun <reified N> D2Array<N>.convertToDataFrame(columnNameGenerator: (Int) -> String = { "col$it" }): AnyFrame {
    // Turning the 2D array into a list of typed columns first, no inference needed
    val columns: List<DataColumn<N>> = List(shape[1]) { i ->
        this[0..<shape[0], i] // get all cells of column i (slice over the row axis)
            .toList()
            .toColumn<N>(name = columnNameGenerator(i))
    }
    // and make a DataFrame from it
    return columns.toDataFrame()
}
/**
 * Converts a [DataFrame] to a two-dimensional array ([D2Array]).
 * You'll need to specify which columns to convert using the [columns] selector.
 *
 * All columns need to be of the same type. If no columns are supplied, the function
 * will only succeed if all columns are of the same type.
 *
 * @see convertToMultikOf
 */
@JvmName("convertNumberColumnsFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N> DataFrame<T>.convertToMultik(
    crossinline columns: ColumnsSelector<T, N>,
): D2Array<N>
    where N : Number, N : Comparable<N> {
    // use the selector to get the columns from this DataFrame and convert them
    val cols = this.getColumns { columns(it) }
    return cols.convertToMultik()
}
/**
 * Converts a [DataFrame] to a two-dimensional array ([D2Array]).
 * You'll need to specify which columns to convert using the [columns] selector.
 *
 * All columns need to be of the same type. If no columns are supplied, the function
 * will only succeed if all columns are of the same type.
 *
 * @see convertToMultikOf
 */
@JvmName("convertComplexColumnsFromDfToMultik")
@OverloadResolutionByLambdaReturnType
inline fun <T, reified N : Complex> DataFrame<T>.convertToMultik(
    crossinline columns: ColumnsSelector<T, N>,
): D2Array<N> {
    // use the selector to get the columns from this DataFrame and convert them
    val cols = this.getColumns { columns(it) }
    return cols.convertToMultik()
}
/**
 * Converts a [DataFrame] to a two-dimensional array ([D2Array]) by guessing
 * the shared column type at runtime.
 *
 * All columns in [this] must be of the same type; otherwise an error is thrown.
 *
 * @throws IllegalStateException if the columns have multiple distinct types,
 *   or if the single column type is not a supported number/complex type.
 * @see convertToMultikOf
 */
@JvmName("convertToMultikGuess")
fun AnyFrame.convertToMultik(): D2Array<*> {
    val columnTypes = this.columnTypes().distinct()
    val type = columnTypes.singleOrNull() ?: error("found multiple column types: $columnTypes")
    return when {
        // subtype check so ComplexFloat/ComplexDouble columns match too
        // (an equality check against Complex itself would miss them)
        type.isSubtypeOf(typeOf<Complex>()) -> convertToMultik { colsOf<Complex>() }
        type.isSubtypeOf(typeOf<Byte>()) -> convertToMultik { colsOf<Byte>() }
        type.isSubtypeOf(typeOf<Short>()) -> convertToMultik { colsOf<Short>() }
        type.isSubtypeOf(typeOf<Int>()) -> convertToMultik { colsOf<Int>() }
        type.isSubtypeOf(typeOf<Long>()) -> convertToMultik { colsOf<Long>() }
        type.isSubtypeOf(typeOf<Float>()) -> convertToMultik { colsOf<Float>() }
        type.isSubtypeOf(typeOf<Double>()) -> convertToMultik { colsOf<Double>() }
        // bug fix: there is exactly one type here — the old message claimed
        // "multiple column types", which is misleading for unsupported types
        else -> error("unsupported column type: $type")
    }
}
/**
 * Converts a [DataFrame] to a two-dimensional array ([D2Array]) by taking all
 * columns of type [N].
 *
 * Allows you to write `df.convertToMultikOf<Complex>()`.
 *
 * @see convertToMultik
 */
@JvmName("convertToMultikOfComplex")
@Suppress("LocalVariableName")
inline fun <reified N : Complex> AnyFrame.convertToMultikOf(
    // unused param to avoid overload resolution ambiguity
    _klass: KClass<Complex> = Complex::class,
): D2Array<N> = convertToMultik { colsOf<N>() }
/**
 * Converts a [DataFrame] to a two-dimensional array ([D2Array]) by taking all
 * columns of type [N].
 *
 * Allows you to write `df.convertToMultikOf<Int>()`.
 *
 * @see convertToMultik
 */
@JvmName("convertToMultikOfNumber")
@Suppress("LocalVariableName")
inline fun <reified N> AnyFrame.convertToMultikOf(
    // unused param to avoid overload resolution ambiguity
    _klass: KClass<Number> = Number::class,
): D2Array<N> where N : Number, N : Comparable<N> =
    convertToMultik { colsOf<N>() }
/**
 * Helper function to convert a list of same-typed [DataColumn]s to a two-dimensional array ([D2Array]).
 * We cannot enforce all columns have the same type if we require just a [DataFrame].
 */
@Suppress("UNCHECKED_CAST")
@JvmName("convertNumberColumnsToMultik")
inline fun <reified N> List<DataColumn<N>>.convertToMultik(): D2Array<N> where N : Number, N : Comparable<N> {
    // a dataframe built from the columns lets us read the data back row-by-row,
    // which is exactly the orientation mk.ndarray() expects
    val rowValues = toDataFrame().map { it.values() as List<N> }
    return mk.ndarray(rowValues)
}
/**
 * Helper function to convert a list of same-typed [DataColumn]s to a two-dimensional array ([D2Array]).
 * We cannot enforce all columns have the same type if we require just a [DataFrame].
 */
@Suppress("UNCHECKED_CAST")
@JvmName("convertComplexColumnsToMultik")
inline fun <reified N : Complex> List<DataColumn<N>>.convertToMultik(): D2Array<N> {
    // a dataframe built from the columns lets us read the data back row-by-row,
    // which is exactly the orientation mk.ndarray() expects
    val rowValues = toDataFrame().map { it.values() as List<N> }
    return mk.ndarray(rowValues)
}
// endregion
// region higher dimensions
/**
 * Converts a three-dimensional array ([D3Array]) to a DataFrame.
 * It will contain `shape[0]` rows and `shape[1]` columns containing lists of size `shape[2]`.
 *
 * Column names can be specified using the [columnNameGenerator] lambda.
 *
 * The conversion enforces that `multikArray[x][y][z] == dataframe[x][y][z]`
 */
inline fun <reified N> D3Array<N>.convertToDataFrameWithLists(
    columnNameGenerator: (Int) -> String = { "col$it" },
): AnyFrame {
    val columns: List<DataColumn<List<N>>> = List(shape[1]) { colIndex ->
        // slice out column colIndex: a shape[0] x shape[2] 2d view of all its cells
        val cellLists = this[0..<shape[0], colIndex, 0..<shape[2]]
            .toListD2() // a shape[0]-sized list/column filled with lists of size shape[2]
        cellLists.toColumn<List<N>>(name = columnNameGenerator(colIndex))
    }
    return columns.toDataFrame()
}
/**
 * Converts a three-dimensional array ([D3Array]) to a DataFrame.
 * It will contain `shape[0]` rows and `shape[1]` column groups containing `shape[2]` columns each.
 *
 * Column names can be specified using the [columnNameGenerator] lambda
 * (used both for the column groups and for the columns inside them).
 *
 * The conversion enforces that `multikArray[x][y][z] == dataframe[x][y][z]`
 */
@JvmName("convert3dArrayToDataFrame")
inline fun <reified N> D3Array<N>.convertToDataFrame(columnNameGenerator: (Int) -> String = { "col$it" }): AnyFrame {
    val columns: List<ColumnGroup<*>> = List(shape[1]) { y ->
        this[0..<shape[0], y, 0..<shape[2]] // get all cells of column y, each is a 2d array of size shape[0] x shape[2]
            .transpose(1, 0) // flip, so we get shape[2] x shape[0]
            .toListD2() // get a shape[2]-sized list filled with lists of size shape[0]
            .mapIndexed { z, list ->
                list.toColumn<N>(name = columnNameGenerator(z))
            } // we get shape[2] columns inside each column group
            .toColumnGroup(name = columnNameGenerator(y))
    }
    return columns.toDataFrame()
}
/**
 * Exploratory recursive function to convert a [MultiArray] of any number of dimensions
 * to a `List<List<...>>` of the same number of dimensions.
 */
fun <T> MultiArray<T, *>.toListDn(): List<*> {
    // walks one dimension per call; `prefix` holds the indices fixed so far
    fun collect(prefix: IntArray): List<*> {
        val dimension = prefix.size
        return if (dimension == shape.lastIndex) {
            // innermost dimension: gather the actual values
            List(shape[dimension]) { i -> this[intArrayOf(*prefix, i)] }
        } else {
            // outer dimension: recurse with one more index fixed
            List(shape[dimension]) { i -> collect(prefix + i) }
        }
    }
    return collect(intArrayOf())
}
/**
 * Converts a multidimensional array ([NDArray]) to a DataFrame.
 * Inspired by [toListDn].
 *
 * For a single-dimensional array, it will call [D1Array.convertToDataFrame].
 *
 * Column names can be specified using the [columnNameGenerator] lambda.
 *
 * The conversion enforces that `multikArray[a][b][c][d]... == dataframe[a][b][c][d]...`
 */
@Suppress("UNCHECKED_CAST")
inline fun <reified N> NDArray<N, *>.convertToDataFrameNestedGroups(
    noinline columnNameGenerator: (Int) -> String = { "col$it" },
): AnyFrame {
    // 1D arrays have a simpler, dedicated conversion
    if (shape.size == 1) return (this as D1Array<N>).convertToDataFrame()

    // push the first dimension to the end, because this represents the rows in DataFrame,
    // and they are accessed by []'s first
    return transpose(*(1..<dim.d).toList().toIntArray(), 0)
        .convertToDataFrameNestedGroupsRecursive(
            indices = intArrayOf(),
            type = typeOf<N>(), // cannot inline a recursive function, so pass the type explicitly
            columnNameGenerator = columnNameGenerator,
        ).let {
            // we could just cast this to a DataFrame<*>, because a ColumnGroup<*>: DataFrame
            // however, this can sometimes cause issues where instance checks are done at runtime
            // this converts it to an actual DataFrame instance
            dataFrameOf((it as ColumnGroup<*>).columns())
        }
}
/**
 * Recursive helper function to handle traversal across dimensions. Do not call directly,
 * use [convertToDataFrameNestedGroups] instead.
 *
 * @param indices the indices fixed so far; one more is fixed per recursion level.
 * @param type the element type of the array, used to type the leaf columns.
 * @param columnNameGenerator names the children of each column group.
 * @return a [DataColumn] at the innermost dimension, otherwise a column group of
 *   recursively-built children. Names are assigned by the CALLER (via `rename`),
 *   so each level starts out with the empty name "".
 */
@PublishedApi
internal fun NDArray<*, *>.convertToDataFrameNestedGroupsRecursive(
    indices: IntArray,
    type: KType,
    columnNameGenerator: (Int) -> String,
): BaseColumn<*> {
    // If we are at the last dimension (1D case)
    if (indices.size == shape.lastIndex) {
        return List(shape[indices.size]) { i ->
            this[intArrayOf(*indices, i)] // Collect values for this dimension
        }.let {
            DataColumn.createByType(name = "", values = it, type = type)
        }
    }

    // For higher dimensions, recursively process smaller dimensions
    return List(shape[indices.size]) { i ->
        convertToDataFrameNestedGroupsRecursive(
            indices = indices + i, // Add `i` to the current index array
            type = type,
            columnNameGenerator = columnNameGenerator,
        ).rename(columnNameGenerator(i))
    }.toColumnGroup("")
}
// endregion
@@ -0,0 +1,23 @@
package org.jetbrains.kotlinx.dataframe.examples.multik
import org.jetbrains.kotlinx.dataframe.api.print
import org.jetbrains.kotlinx.multik.api.io.readNPY
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.ndarray.data.D1
import java.io.File
/**
 * Multik can read/write data from NPY/NPZ files.
 * We can use this from DataFrame too!
 *
 * We use compatibilityLayer.kt for the conversions, check it out for the implementation details of the conversion!
 */
fun main() {
    // locate the bundled NPY resource on the classpath
    val resourceName = "a1d.npy"
    val resourceUrl = object {}.javaClass.classLoader.getResource(resourceName)!!
    val npyFile = File(resourceUrl.toURI())

    // read a 1-dimensional Long array with Multik, then convert it to a DataFrame
    val ndArray = mk.readNPY<Long, D1>(npyFile)
    val dataFrame = ndArray.convertToDataFrame()
    dataFrame.print(borders = true, columnTypes = true)
}
@@ -0,0 +1,99 @@
package org.jetbrains.kotlinx.dataframe.examples.multik
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.colsOf
import org.jetbrains.kotlinx.dataframe.api.describe
import org.jetbrains.kotlinx.dataframe.api.mean
import org.jetbrains.kotlinx.dataframe.api.meanFor
import org.jetbrains.kotlinx.dataframe.api.print
import org.jetbrains.kotlinx.dataframe.api.value
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.rand
import org.jetbrains.kotlinx.multik.ndarray.data.get
/**
 * Let's explore some ways we can combine Multik with Kotlin DataFrame.
 *
 * We will use compatibilityLayer.kt for the conversions.
 * Take a look at that file for the implementation details!
 */
fun main() {
    // run the examples dimension by dimension
    oneDimension()
    twoDimensions()
    higherDimensions()
}
/** Demonstrates converting a 1D ndarray to/from a DataFrame column. */
fun oneDimension() {
    // a 1D ndarray can become a single column of a DataFrame
    val randomVector = mk.rand<Double>(50)
    // NOTE: the delegate name `col1` doubles as the column name
    val col1 by randomVector.convertToColumn()
    println(col1)

    // or a whole DataFrame; the data lands in the `value` column
    val frame = randomVector.convertToDataFrame()
    println(frame)

    // any DataFrame operation is now available:
    println(frame.mean { value })
    frame.describe().print(borders = true)

    // converting back to Multik works from the frame...
    val backToMultik = frame.convertToMultik { value }
    // or
    frame.value.convertToMultik()
    println(backToMultik)
}
/** Demonstrates converting a 2D ndarray to/from a DataFrame. */
fun twoDimensions() {
    // a 2D ndarray becomes a DataFrame with columns named "col0", "col1", etc.
    // (careful, when the number of columns is too large, this can cause problems)
    // the layout matches multik access: `multikArray[x][y] == dataframe[x][y]`
    val matrix = mk.rand<Int>(5, 10)
    println(matrix)
    val frame = matrix.convertToDataFrame()
    frame.print()

    // any DataFrame operation is now available:
    val columnMeans = frame.meanFor { ("col0".."col9").cast<Int>() }
    columnMeans.print()

    // Converting back to Multik can be done in multiple ways.
    // Multik can only store one type of data, so we need to specify the type or select
    // only the columns we want:
    val backToMultik = frame.convertToMultik { colsOf<Int>() }
    // or
    frame.convertToMultikOf<Int>()
    // or if all columns are of the same type:
    frame.convertToMultik()
    println(backToMultik)
}
/** Demonstrates the different strategies for 3D+ ndarrays. */
fun higherDimensions() {
    // Multik can store higher dimensions as well,
    // but converting them to a DataFrame requires choosing a representation.
    // Option 1 (3D): store a list in each cell for the extra dimension:
    val cube = mk.rand<Int>(5, 4, 3)
    println(cube)
    val withLists = cube.convertToDataFrameWithLists()
    withLists.print()

    // Option 2: column groups. Each column is subdivided into more columns,
    // keeping `multikArray[x][y][z] == dataframe[x][y][z]`
    val withGroups = cube.convertToDataFrame()
    withGroups.print()

    // For even higher dimensions, we can keep nesting column groups
    val hyperCube = mk.rand<Int>(5, 4, 3, 2)
    val nested = hyperCube.convertToDataFrameNestedGroups()
    nested.print()

    // ...or use nested DataFrames (in FrameColumns)
    // (for instance, a 4D matrix could be stored in a 2D DataFrame where each cell is another DataFrame)
    // but, we'll leave that as an exercise for the reader :)
}
@@ -0,0 +1,115 @@
package org.jetbrains.kotlinx.dataframe.examples.multik
import kotlinx.datetime.LocalDate
import kotlinx.datetime.Month
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.append
import org.jetbrains.kotlinx.dataframe.api.cast
import org.jetbrains.kotlinx.dataframe.api.mapToFrame
import org.jetbrains.kotlinx.dataframe.api.print
import org.jetbrains.kotlinx.dataframe.api.single
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.multik.api.mk
import org.jetbrains.kotlinx.multik.api.rand
import org.jetbrains.kotlinx.multik.ndarray.data.D3Array
import org.jetbrains.kotlinx.multik.ndarray.data.D4Array
/**
 * DataFrames can store anything inside, including Multik ndarrays.
 * This can be useful for storing matrices for easier access later or to simply organize data read from other files.
 * For example, MRI data is often stored as 3D arrays and sometimes even 4D arrays.
 */
fun main() {
    // imaginary list of patient data
    @Suppress("ktlint:standard:argument-list-wrapping")
    val metadata = listOf(
        MriMetadata(10012L, 25, "Healthy", LocalDate(2023, 1, 1)),
        MriMetadata(10013L, 45, "Tuberculosis", LocalDate(2023, 2, 15)),
        MriMetadata(10014L, 32, "Healthy", LocalDate(2023, 3, 22)),
        MriMetadata(10015L, 58, "Pneumonia", LocalDate(2023, 4, 8)),
        MriMetadata(10016L, 29, "Tuberculosis", LocalDate(2023, 5, 30)),
        MriMetadata(10017L, 42, "Healthy", LocalDate(2023, 6, 15)),
        MriMetadata(10018L, 37, "Healthy", LocalDate(2023, 7, 1)),
        MriMetadata(10019L, 55, "Healthy", LocalDate(2023, 8, 15)),
        MriMetadata(10020L, 28, "Healthy", LocalDate(2023, 9, 1)),
        MriMetadata(10021L, 44, "Healthy", LocalDate(2023, 10, 15)),
        MriMetadata(10022L, 31, "Healthy", LocalDate(2023, 11, 1)),
    ).toDataFrame()

    // "reading" the results from "files": keep the metadata columns and add one
    // column with the 3D scan and one with the 4D BOLD series per patient
    val results = metadata.mapToFrame {
        +patientId
        +age
        +diagnosis
        +scanDate
        "t1WeightedMri" from { readT1WeightedMri(patientId) }
        "fMriBoldSeries" from { readFMRiBoldSeries(patientId) }
    }.cast<MriResults>(verify = true)
        // NOTE(review): `.append()` with no arguments looks like a leftover — confirm it's intended
        .append()

    results.print(borders = true)

    // now when we want to check and visualize the T1-weighted MRI scan
    // for that one healthy patient in July, we can do:
    val scan = results
        .single { scanDate.month == Month.JULY && diagnosis == "Healthy" }
        .t1WeightedMri

    // easy :)
    visualize(scan)
}
/** Per-patient scan metadata (no image data); see [MriResults] for the full record. */
@DataSchema
data class MriMetadata(
    /** Unique patient ID. */
    val patientId: Long,
    /** Patient age. */
    val age: Int,
    /** Clinical diagnosis (e.g. "Healthy", "Tuberculosis") */
    val diagnosis: String,
    /** Date of the scan */
    val scanDate: LocalDate,
)
/** Full per-patient record: the [MriMetadata] fields plus the Multik arrays holding the scan data. */
@DataSchema
data class MriResults(
    /** Unique patient ID. */
    val patientId: Long,
    /** Patient age. */
    val age: Int,
    /** Clinical diagnosis (e.g. "Healthy", "Tuberculosis") */
    val diagnosis: String,
    /** Date of the scan */
    val scanDate: LocalDate,
    /**
     * T1-weighted anatomical MRI scan.
     *
     * Dimensions: (256 x 256 x 180)
     * - 256 width x 256 height
     * - 180 slices
     */
    val t1WeightedMri: D3Array<Float>,
    /**
     * Blood oxygenation level-dependent (BOLD) time series from an fMRI scan.
     *
     * Dimensions: (64 x 64 x 30 x 200)
     * - 64 width x 64 height
     * - 30 slices
     * - 200 timepoints
     */
    val fMriBoldSeries: D4Array<Float>,
)
/**
 * "Reads" the T1-weighted MRI scan for patient [id].
 * In practice this would, of course, load the actual data;
 * for this example we just return a dummy random array.
 */
fun readT1WeightedMri(id: Long): D3Array<Float> = mk.rand(256, 256, 180)
/**
 * "Reads" the fMRI BOLD series for patient [id].
 * In practice this would, of course, load the actual data;
 * for this example we just return a dummy random array.
 */
fun readFMRiBoldSeries(id: Long): D4Array<Float> = mk.rand(64, 64, 30, 200)
/** Placeholder: a real implementation would render the given [scan]. */
fun visualize(scan: D3Array<Float>) {
    // This would then actually visualize the scan
}
@@ -0,0 +1,77 @@
import org.jetbrains.kotlin.gradle.dsl.JvmTarget
plugins {
    application
    kotlin("jvm")

    // uses the 'old' Gradle plugin instead of the compiler plugin for now
    id("org.jetbrains.kotlinx.dataframe")

    // only mandatory if `kotlin.dataframe.add.ksp=false` in gradle.properties
    id("com.google.devtools.ksp")
}

repositories {
    mavenLocal() // in case of local dataframe development
    mavenCentral()
}

dependencies {
    // implementation("org.jetbrains.kotlinx:dataframe:X.Y.Z")
    implementation(project(":"))

    // (kotlin) spark support
    implementation(libs.kotlin.spark)
    // compileOnly: spark itself is expected on the runtime classpath of the run tasks
    compileOnly(libs.spark)
    implementation(libs.log4j.core)
    implementation(libs.log4j.api)
}
/**
 * Registers a [JavaExec] task named [taskName] that runs [mainClassName] on Java 11,
 * since the (Kotlin) Spark examples need Java 8/11 to run correctly.
 *
 * Replaces four copy-pasted task registrations that differed only in name and main class.
 */
fun registerSparkExampleTask(taskName: String, mainClassName: String) =
    tasks.register<JavaExec>(taskName) {
        classpath = sourceSets["main"].runtimeClasspath
        javaLauncher = javaToolchains.launcherFor { languageVersion = JavaLanguageVersion.of(11) }
        mainClass = mainClassName
    }

// Runs the kotlinSpark/typedDataset example with java 11.
registerSparkExampleTask(
    taskName = "runKotlinSparkTypedDataset",
    mainClassName = "org.jetbrains.kotlinx.dataframe.examples.kotlinSpark.TypedDatasetKt",
)

// Runs the kotlinSpark/untypedDataset example with java 11.
registerSparkExampleTask(
    taskName = "runKotlinSparkUntypedDataset",
    mainClassName = "org.jetbrains.kotlinx.dataframe.examples.kotlinSpark.UntypedDatasetKt",
)

// Runs the spark/typedDataset example with java 11.
registerSparkExampleTask(
    taskName = "runSparkTypedDataset",
    mainClassName = "org.jetbrains.kotlinx.dataframe.examples.spark.TypedDatasetKt",
)

// Runs the spark/untypedDataset example with java 11.
registerSparkExampleTask(
    taskName = "runSparkUntypedDataset",
    mainClassName = "org.jetbrains.kotlinx.dataframe.examples.spark.UntypedDatasetKt",
)
kotlin {
    compilerOptions {
        // target Java 11 bytecode; the examples are meant to run on Java 8/11 (see the run* tasks)
        jvmTarget = JvmTarget.JVM_11
        freeCompilerArgs.add("-Xjdk-release=11")
    }
}

// keep Java sources on the same release as the Kotlin target
tasks.withType<JavaCompile> {
    sourceCompatibility = JavaVersion.VERSION_11.toString()
    targetCompatibility = JavaVersion.VERSION_11.toString()
    options.release.set(11)
}
@@ -0,0 +1,8 @@
@file:Suppress("ktlint:standard:no-empty-file")
package org.jetbrains.kotlinx.dataframe.examples.kotlinSpark
/*
* See ../spark/compatibilityLayer.kt for the implementation.
* It's the same with- and without the Kotlin Spark API.
*/
@@ -0,0 +1,78 @@
@file:Suppress("ktlint:standard:function-signature")
package org.jetbrains.kotlinx.dataframe.examples.kotlinSpark
import org.apache.spark.sql.Dataset
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.aggregate
import org.jetbrains.kotlinx.dataframe.api.groupBy
import org.jetbrains.kotlinx.dataframe.api.max
import org.jetbrains.kotlinx.dataframe.api.mean
import org.jetbrains.kotlinx.dataframe.api.min
import org.jetbrains.kotlinx.dataframe.api.print
import org.jetbrains.kotlinx.dataframe.api.schema
import org.jetbrains.kotlinx.dataframe.api.std
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.api.toList
import org.jetbrains.kotlinx.spark.api.withSpark
/**
 * With the Kotlin Spark API, normal Kotlin data classes are supported,
 * meaning we can reuse the same class for Spark and DataFrame!
 *
 * Also, since we use an actual class to define the schema, we need no type conversion!
 *
 * See [Person] and [Name] for an example.
 *
 * NOTE: You will likely need to run this function with Java 8 or 11 for it to work correctly.
 * Use the `runKotlinSparkTypedDataset` Gradle task to do so.
 */
fun main() = withSpark {
    // A typed Spark Dataset; in practice this would be loaded from some server or database.
    val people: Dataset<Person> = listOf(
        Person(Name("Alice", "Cooper"), 15, "London", 54, true),
        Person(Name("Bob", "Dylan"), 45, "Dubai", 87, true),
        Person(Name("Charlie", "Daniels"), 20, "Moscow", null, false),
        Person(Name("Charlie", "Chaplin"), 40, "Milan", null, true),
        Person(Name("Bob", "Marley"), 30, "Tokyo", 68, true),
        Person(Name("Alice", "Wolf"), 20, null, 55, false),
        Person(Name("Charlie", "Byrd"), 30, "Moscow", 90, true),
    ).toDS()

    // Heavy filtering is best done on the Spark side; DataFrames are in-memory structures,
    // so this is also the place to limit the number of rows if you don't have the RAM ;)
    val adults = people.filter { it.age > 17 }

    // Collecting to a typed List is all that's needed to obtain a Kotlin DataFrame.
    val df = adults.collectAsList().toDataFrame()
    df.schema().print()
    df.print(columnTypes = true, borders = true)

    // From here, every DataFrame-specific function is available:
    val ageStats = df
        .groupBy { city }.aggregate {
            mean { age } into "meanAge"
            std { age } into "stdAge"
            min { age } into "minAge"
            max { age } into "maxAge"
        }
    ageStats.print(columnTypes = true, borders = true)

    // Going back to Spark uses the same trick in reverse: DataFrame -> typed List -> Dataset.
    val roundTripped = df.toList().toDS()
    roundTripped.printSchema()
    roundTripped.show()
}
/** A person's full name; becomes a nested column group when used in a DataFrame. */
@DataSchema
data class Name(val firstName: String, val lastName: String)
/** Schema shared between the Spark Dataset and the Kotlin DataFrame. */
@DataSchema
data class Person(
    // nested value -> column group in the DataFrame
    val name: Name,
    val age: Int,
    val city: String?, // nullable: city may be unknown
    val weight: Int?, // nullable: weight may be unknown
    val isHappy: Boolean,
)
@@ -0,0 +1,74 @@
@file:Suppress("ktlint:standard:function-signature")
package org.jetbrains.kotlinx.dataframe.examples.kotlinSpark
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.jetbrains.kotlinx.dataframe.api.aggregate
import org.jetbrains.kotlinx.dataframe.api.groupBy
import org.jetbrains.kotlinx.dataframe.api.max
import org.jetbrains.kotlinx.dataframe.api.mean
import org.jetbrains.kotlinx.dataframe.api.min
import org.jetbrains.kotlinx.dataframe.api.print
import org.jetbrains.kotlinx.dataframe.api.schema
import org.jetbrains.kotlinx.dataframe.api.std
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToDataFrame
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToDataFrameByInference
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToSpark
import org.jetbrains.kotlinx.spark.api.col
import org.jetbrains.kotlinx.spark.api.gt
import org.jetbrains.kotlinx.spark.api.withSpark
/**
 * Since we don't know the schema at compile time this time, we need to do
 * some schema mapping in between Spark and DataFrame.
 *
 * We will use spark/compatibilityLayer.kt to do this.
 * Take a look at that file for the implementation details!
 *
 * NOTE: You will likely need to run this function with Java 8 or 11 for it to work correctly.
 * Use the `runKotlinSparkUntypedDataset` Gradle task to do so.
 */
fun main() = withSpark {
    // An untyped Spark Dataframe; in practice this would be loaded from some server or database.
    val rawRows: Dataset<Row> = listOf(
        Person(Name("Alice", "Cooper"), 15, "London", 54, true),
        Person(Name("Bob", "Dylan"), 45, "Dubai", 87, true),
        Person(Name("Charlie", "Daniels"), 20, "Moscow", null, false),
        Person(Name("Charlie", "Chaplin"), 40, "Milan", null, true),
        Person(Name("Bob", "Marley"), 30, "Tokyo", 68, true),
        Person(Name("Alice", "Wolf"), 20, null, 55, false),
        Person(Name("Charlie", "Byrd"), 30, "Moscow", 90, true),
    ).toDF()

    // Heavy filtering is best done on the Spark side; DataFrames are in-memory structures,
    // so this is also the place to limit the number of rows if you don't have the RAM ;)
    val adults = rawRows.filter(col("age") gt 17)

    // Conversion route 1: type inference
    val inferredDf = adults.convertToDataFrameByInference()
    inferredDf.schema().print()
    inferredDf.print(columnTypes = true, borders = true)

    // Conversion route 2: full schema mapping
    val mappedDf = adults.convertToDataFrame()
    mappedDf.schema().print()
    mappedDf.print(columnTypes = true, borders = true)

    // From here, every DataFrame-specific function is available:
    val ageStats = inferredDf
        .groupBy("city").aggregate {
            mean("age") into "meanAge"
            std("age") into "stdAge"
            min("age") into "minAge"
            max("age") into "maxAge"
        }
    ageStats.print(columnTypes = true, borders = true)

    // Going back to Spark uses the `convertToSpark()` extension function,
    // which performs the necessary schema mapping under the hood.
    val backToSpark = mappedDf.convertToSpark(spark, sc)
    backToSpark.printSchema()
    backToSpark.show()
}
@@ -0,0 +1,330 @@
package org.jetbrains.kotlinx.dataframe.examples.spark
import org.apache.spark.api.java.JavaRDD
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.RowFactory
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types.ArrayType
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.types.DataTypes
import org.apache.spark.sql.types.Decimal
import org.apache.spark.sql.types.DecimalType
import org.apache.spark.sql.types.MapType
import org.apache.spark.sql.types.StructType
import org.apache.spark.unsafe.types.CalendarInterval
import org.jetbrains.kotlinx.dataframe.AnyFrame
import org.jetbrains.kotlinx.dataframe.DataColumn
import org.jetbrains.kotlinx.dataframe.DataFrame
import org.jetbrains.kotlinx.dataframe.DataRow
import org.jetbrains.kotlinx.dataframe.api.rows
import org.jetbrains.kotlinx.dataframe.api.schema
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup
import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion
import org.jetbrains.kotlinx.dataframe.schema.ColumnSchema
import org.jetbrains.kotlinx.dataframe.schema.DataFrameSchema
import java.math.BigDecimal
import java.math.BigInteger
import java.sql.Date
import java.sql.Timestamp
import java.time.Instant
import java.time.LocalDate
import kotlin.reflect.KType
import kotlin.reflect.KTypeProjection
import kotlin.reflect.full.createType
import kotlin.reflect.full.isSubtypeOf
import kotlin.reflect.full.withNullability
import kotlin.reflect.typeOf
// region Spark to DataFrame
/**
 * Converts an untyped Spark [Dataset] (Dataframe) to a Kotlin [DataFrame].
 * [StructTypes][StructType] are converted to [ColumnGroups][ColumnGroup].
 *
 * DataFrame supports type inference to do the conversion automatically.
 * This is usually fine for smaller data sets, but when working with larger datasets, a type map might be a good idea.
 * See [convertToDataFrame] for more information.
 *
 * @param schema the (sub-)schema to convert; defaults to this Dataset's full schema.
 *   Recursive calls pass the nested [StructType] here.
 * @param prefix the struct-field path leading to [schema], used to `select()` nested
 *   columns by their fully-qualified, dot-separated name.
 */
fun Dataset<Row>.convertToDataFrameByInference(
    schema: StructType = schema(),
    prefix: List<String> = emptyList(),
): AnyFrame {
    val columns = schema.fields().map { field ->
        val name = field.name()
        when (val dataType = field.dataType()) {
            is StructType ->
                // a column group can be easily created from a dataframe and a name
                DataColumn.createColumnGroup(
                    name = name,
                    df = this.convertToDataFrameByInference(dataType, prefix + name),
                )

            else ->
                // we can use DataFrame type inference to create a column with the correct type
                // from Spark we use `select()` to select a single column
                // and `collectAsList()` to get all the values in a list of single-celled rows
                DataColumn.createByInference(
                    name = name,
                    values = this.select((prefix + name).joinToString("."))
                        .collectAsList()
                        .map { it[0] },
                    suggestedType = TypeSuggestion.Infer,
                    // Spark provides nullability :) you can leave this out if you want this to be inferred too
                    nullable = field.nullable(),
                )
        }
    }
    return columns.toDataFrame()
}
/**
 * Converts an untyped Spark [Dataset] (Dataframe) to a Kotlin [DataFrame].
 * [StructTypes][StructType] are converted to [ColumnGroups][ColumnGroup].
 *
 * This version uses a [type-map][DataType.convertToDataFrame] to convert the schemas with a fallback to inference.
 * For smaller data sets, inference is usually fine too.
 * See [convertToDataFrameByInference] for more information.
 *
 * @param schema the (sub-)schema to convert; defaults to this Dataset's full schema.
 *   Recursive calls pass the nested [StructType] here.
 * @param prefix the struct-field path leading to [schema], used to `select()` nested
 *   columns by their fully-qualified, dot-separated name.
 */
fun Dataset<Row>.convertToDataFrame(schema: StructType = schema(), prefix: List<String> = emptyList()): AnyFrame {
    val columns = schema.fields().map { field ->
        val name = field.name()
        when (val dataType = field.dataType()) {
            is StructType ->
                // a column group can be easily created from a dataframe and a name
                DataColumn.createColumnGroup(
                    name = name,
                    df = convertToDataFrame(dataType, prefix + name),
                )

            else ->
                // we create a column with the correct type using our type-map with fallback to inference
                // from Spark we use `select()` to select a single column
                // and `collectAsList()` to get all the values in a list of single-celled rows
                DataColumn.createByInference(
                    name = name,
                    values = select((prefix + name).joinToString("."))
                        .collectAsList()
                        .map { it[0] },
                    suggestedType =
                        dataType.convertToDataFrame()
                            ?.let(TypeSuggestion::Use)
                            ?: TypeSuggestion.Infer, // fallback to inference if needed
                    nullable = field.nullable(),
                )
        }
    }
    return columns.toDataFrame()
}
/**
 * Returns the corresponding [Kotlin type][KType] for a given Spark [DataType].
 *
 * This list may be incomplete, but it can at least give you a good start.
 *
 * @return The [KType] that corresponds to the Spark [DataType], or null if no matching [KType] is found
 *   (callers then fall back to type inference).
 */
fun DataType.convertToDataFrame(): KType? =
    when {
        this == DataTypes.ByteType -> typeOf<Byte>()
        this == DataTypes.ShortType -> typeOf<Short>()
        this == DataTypes.IntegerType -> typeOf<Int>()
        this == DataTypes.LongType -> typeOf<Long>()
        this == DataTypes.BooleanType -> typeOf<Boolean>()
        this == DataTypes.FloatType -> typeOf<Float>()
        this == DataTypes.DoubleType -> typeOf<Double>()
        this == DataTypes.StringType -> typeOf<String>()
        this == DataTypes.DateType -> typeOf<Date>()
        this == DataTypes.TimestampType -> typeOf<Timestamp>()
        this is DecimalType -> typeOf<Decimal>()
        this == DataTypes.CalendarIntervalType -> typeOf<CalendarInterval>()
        this == DataTypes.NullType -> nullableNothingType
        this == DataTypes.BinaryType -> typeOf<ByteArray>()

        this is ArrayType ->
            // Fix: Kotlin primitive arrays cannot hold nulls, so only map to them when
            // Spark guarantees non-null elements; otherwise return null and let the
            // caller fall back to inference.
            if (containsNull()) {
                null
            } else {
                when (elementType()) {
                    DataTypes.ShortType -> typeOf<ShortArray>()
                    DataTypes.IntegerType -> typeOf<IntArray>()
                    DataTypes.LongType -> typeOf<LongArray>()
                    DataTypes.FloatType -> typeOf<FloatArray>()
                    DataTypes.DoubleType -> typeOf<DoubleArray>()
                    DataTypes.BooleanType -> typeOf<BooleanArray>()
                    else -> null
                }
            }

        this is MapType -> {
            val key = keyType().convertToDataFrame() ?: return null
            val value = valueType().convertToDataFrame() ?: return null
            Map::class.createType(
                listOf(
                    KTypeProjection.invariant(key),
                    KTypeProjection.invariant(value.withNullability(valueContainsNull())),
                ),
            )
        }

        else -> null
    }
// endregion
// region DataFrame to Spark
/**
 * Converts the [DataFrame] to a Spark [Dataset] of [Rows][Row] using the provided [SparkSession] and [JavaSparkContext].
 *
 * Spark needs both the data and the schema to be converted to create a correct [Dataset],
 * so we need to map our types somehow.
 *
 * @param spark The [SparkSession] object to use for creating the [DataFrame].
 * @param sc The [JavaSparkContext] object to use for converting the [DataFrame] to [RDD][JavaRDD].
 * @return A [Dataset] of [Rows][Row] representing the converted DataFrame.
 */
fun DataFrame<*>.convertToSpark(spark: SparkSession, sc: JavaSparkContext): Dataset<Row> {
    // map the dataframe schema to a spark StructType
    val sparkSchema = schema().convertToSpark()
    // map every dataframe row to a spark Row and distribute them as an RDD
    val sparkRows = sc.parallelize(rows().map { it.convertToSpark() })
    return spark.createDataFrame(sparkRows, sparkSchema)
}
/**
 * Converts a [DataRow] to a Spark [Row] object.
 *
 * @return The converted Spark [Row].
 */
fun DataRow<*>.convertToSpark(): Row {
    val cells = values().map { cell ->
        when (cell) {
            // a row can be nested inside another row if it's a column group
            is DataRow<*> -> cell.convertToSpark()
            is DataFrame<*> -> error("nested dataframes are not supported")
            else -> cell
        }
    }
    return RowFactory.create(*cells.toTypedArray())
}
/**
 * Converts a [DataFrameSchema] to a Spark [StructType].
 *
 * @return The converted Spark [StructType].
 */
fun DataFrameSchema.convertToSpark(): StructType {
    // One StructField per column, preserving name, converted type, and nullability.
    val fields = columns.map { (columnName, columnSchema) ->
        DataTypes.createStructField(columnName, columnSchema.convertToSpark(), columnSchema.nullable)
    }
    return DataTypes.createStructType(fields)
}
/**
 * Converts a [ColumnSchema] object to Spark [DataType].
 *
 * @return The Spark [DataType] corresponding to the given [ColumnSchema] object.
 * @throws IllegalArgumentException if the column type or kind is unknown.
 */
fun ColumnSchema.convertToSpark(): DataType {
    return when (this) {
        // column groups map to nested struct types
        is ColumnSchema.Group -> schema.convertToSpark()
        is ColumnSchema.Value -> type.convertToSpark() ?: error("unknown data type: $type")
        is ColumnSchema.Frame -> error("nested dataframes are not supported")
        else -> error("unknown column kind: $this")
    }
}
/**
 * Returns the corresponding Spark [DataType] for a given [Kotlin type][KType].
 *
 * This list may be incomplete, but it can at least give you a good start.
 *
 * @return The Spark [DataType] that corresponds to the [Kotlin type][KType], or null if no matching [DataType] is found.
 * @throws IllegalStateException for explicitly unsupported types (non-primitive arrays, lists, sets,
 *   star-projected maps, or maps whose key/value types cannot be converted).
 */
fun KType.convertToSpark(): DataType? =
    when {
        isSubtypeOf(typeOf<Byte?>()) -> DataTypes.ByteType
        isSubtypeOf(typeOf<Short?>()) -> DataTypes.ShortType
        isSubtypeOf(typeOf<Int?>()) -> DataTypes.IntegerType
        isSubtypeOf(typeOf<Long?>()) -> DataTypes.LongType
        isSubtypeOf(typeOf<Boolean?>()) -> DataTypes.BooleanType
        isSubtypeOf(typeOf<Float?>()) -> DataTypes.FloatType
        isSubtypeOf(typeOf<Double?>()) -> DataTypes.DoubleType
        isSubtypeOf(typeOf<String?>()) -> DataTypes.StringType
        isSubtypeOf(typeOf<LocalDate?>()) -> DataTypes.DateType
        isSubtypeOf(typeOf<Date?>()) -> DataTypes.DateType
        isSubtypeOf(typeOf<Timestamp?>()) -> DataTypes.TimestampType
        isSubtypeOf(typeOf<Instant?>()) -> DataTypes.TimestampType
        isSubtypeOf(typeOf<Decimal?>()) -> DecimalType.SYSTEM_DEFAULT()
        isSubtypeOf(typeOf<BigDecimal?>()) -> DecimalType.SYSTEM_DEFAULT()
        isSubtypeOf(typeOf<BigInteger?>()) -> DecimalType.SYSTEM_DEFAULT()
        isSubtypeOf(typeOf<CalendarInterval?>()) -> DataTypes.CalendarIntervalType
        isSubtypeOf(nullableNothingType) -> DataTypes.NullType
        isSubtypeOf(typeOf<ByteArray?>()) -> DataTypes.BinaryType
        // primitive arrays map to non-nullable-element Spark arrays
        isSubtypeOf(typeOf<ShortArray?>()) -> DataTypes.createArrayType(DataTypes.ShortType, false)
        isSubtypeOf(typeOf<IntArray?>()) -> DataTypes.createArrayType(DataTypes.IntegerType, false)
        isSubtypeOf(typeOf<LongArray?>()) -> DataTypes.createArrayType(DataTypes.LongType, false)
        isSubtypeOf(typeOf<FloatArray?>()) -> DataTypes.createArrayType(DataTypes.FloatType, false)
        isSubtypeOf(typeOf<DoubleArray?>()) -> DataTypes.createArrayType(DataTypes.DoubleType, false)
        isSubtypeOf(typeOf<BooleanArray?>()) -> DataTypes.createArrayType(DataTypes.BooleanType, false)
        isSubtypeOf(typeOf<Array<*>>()) ->
            error("non-primitive arrays are not supported for now, you can add it yourself")
        isSubtypeOf(typeOf<List<*>>()) -> error("lists are not supported for now, you can add it yourself")
        isSubtypeOf(typeOf<Set<*>>()) -> error("sets are not supported for now, you can add it yourself")
        classifier == Map::class -> {
            val (key, value) = arguments
            // Fail fast here instead of silently feeding a null DataType into createMapType,
            // which would only blow up later inside Spark with a much less helpful error.
            val keyKType = key.type ?: error("star-projected map keys are not supported")
            val valueKType = value.type ?: error("star-projected map values are not supported")
            DataTypes.createMapType(
                keyKType.convertToSpark() ?: error("unsupported map key type: $keyKType"),
                valueKType.convertToSpark() ?: error("unsupported map value type: $valueKType"),
                valueKType.isMarkedNullable,
            )
        }
        else -> null
    }
// `typeOf<Nothing?>()` cannot be written directly, so the `Nothing?` KType is extracted
// from the type argument of `List<Nothing?>` instead. Used to map null-only columns to Spark's NullType.
private val nullableNothingType: KType = typeOf<List<Nothing?>>().arguments.first().type!!
// endregion
@@ -0,0 +1,105 @@
@file:Suppress("ktlint:standard:function-signature")
package org.jetbrains.kotlinx.dataframe.examples.spark
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession
import org.jetbrains.kotlinx.dataframe.annotations.DataSchema
import org.jetbrains.kotlinx.dataframe.api.aggregate
import org.jetbrains.kotlinx.dataframe.api.groupBy
import org.jetbrains.kotlinx.dataframe.api.max
import org.jetbrains.kotlinx.dataframe.api.mean
import org.jetbrains.kotlinx.dataframe.api.min
import org.jetbrains.kotlinx.dataframe.api.print
import org.jetbrains.kotlinx.dataframe.api.schema
import org.jetbrains.kotlinx.dataframe.api.std
import org.jetbrains.kotlinx.dataframe.api.toDataFrame
import org.jetbrains.kotlinx.dataframe.api.toList
import java.io.Serializable
/**
 * For Spark, Kotlin data classes are supported if we:
 * - Add [@JvmOverloads][JvmOverloads] to the constructor
 * - Make all parameter arguments mutable and with defaults
 * - Make them [Serializable]
 *
 * But by adding [@DataSchema][DataSchema] we can reuse the same class for Spark and DataFrame!
 *
 * See [Person] and [Name] for an example.
 *
 * Also, since we use an actual class to define the schema, we need no type conversion!
 *
 * NOTE: You will likely need to run this function with Java 8 or 11 for it to work correctly.
 * Use the `runSparkTypedDataset` Gradle task to do so.
 */
fun main() {
    val spark = SparkSession.builder()
        .master(SparkConf().get("spark.master", "local[*]"))
        .appName("Kotlin Spark Sample")
        .getOrCreate()
    // NOTE: unlike the untyped example, no JavaSparkContext is needed here:
    // createDataset works directly from a typed List, so the unused `sc` local was removed.

    // Creating a Spark Dataset. Usually, this is loaded from some server or database.
    val rawDataset: Dataset<Person> = spark.createDataset(
        listOf(
            Person(Name("Alice", "Cooper"), 15, "London", 54, true),
            Person(Name("Bob", "Dylan"), 45, "Dubai", 87, true),
            Person(Name("Charlie", "Daniels"), 20, "Moscow", null, false),
            Person(Name("Charlie", "Chaplin"), 40, "Milan", null, true),
            Person(Name("Bob", "Marley"), 30, "Tokyo", 68, true),
            Person(Name("Alice", "Wolf"), 20, null, 55, false),
            Person(Name("Charlie", "Byrd"), 30, "Moscow", 90, true),
        ),
        beanEncoderOf(),
    )

    // we can perform large operations in Spark.
    // DataFrames are in-memory structures, so this is a good place to limit the number of rows if you don't have the RAM ;)
    val dataset = rawDataset.filter { it.age > 17 }

    // and convert it to DataFrame via a typed List
    val dataframe = dataset.collectAsList().toDataFrame()
    dataframe.schema().print()
    dataframe.print(columnTypes = true, borders = true)

    // now we can use DataFrame-specific functions
    val ageStats = dataframe
        .groupBy { city }.aggregate {
            mean { age } into "meanAge"
            std { age } into "stdAge"
            min { age } into "minAge"
            max { age } into "maxAge"
        }
    ageStats.print(columnTypes = true, borders = true)

    // and when we want to convert a DataFrame back to Spark, we can do the same trick via a typed List
    val sparkDatasetAgain = spark.createDataset(dataframe.toList(), beanEncoderOf())
    sparkDatasetAgain.printSchema()
    sparkDatasetAgain.show()

    spark.stop()
}
/** Creates a [bean encoder][Encoders.bean] for the given [T] instance. */
inline fun <reified T : Serializable> beanEncoderOf(): Encoder<T> {
    // The reified type parameter supplies the Class token Spark's bean encoder requires.
    val beanClass = T::class.java
    return Encoders.bean(beanClass)
}
// Serves double duty: a Spark-compatible Java bean (mutable vars, defaults, @JvmOverloads,
// Serializable) AND a DataFrame schema via @DataSchema.
@DataSchema
data class Name
    @JvmOverloads
    constructor(var firstName: String = "", var lastName: String = "") : Serializable
// Serves double duty: a Spark-compatible Java bean (mutable vars, defaults, @JvmOverloads,
// Serializable) AND a DataFrame schema via @DataSchema. Nullable fields (city, weight)
// become nullable columns in both worlds.
@DataSchema
data class Person
    @JvmOverloads
    constructor(
        var name: Name = Name(),
        var age: Int = -1,
        var city: String? = null,
        var weight: Int? = null,
        var isHappy: Boolean = false,
    ) : Serializable
@@ -0,0 +1,87 @@
@file:Suppress("ktlint:standard:function-signature")
package org.jetbrains.kotlinx.dataframe.examples.spark
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession
import org.jetbrains.kotlinx.dataframe.api.aggregate
import org.jetbrains.kotlinx.dataframe.api.groupBy
import org.jetbrains.kotlinx.dataframe.api.max
import org.jetbrains.kotlinx.dataframe.api.mean
import org.jetbrains.kotlinx.dataframe.api.min
import org.jetbrains.kotlinx.dataframe.api.print
import org.jetbrains.kotlinx.dataframe.api.schema
import org.jetbrains.kotlinx.dataframe.api.std
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToDataFrame
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToDataFrameByInference
import org.jetbrains.kotlinx.dataframe.examples.spark.convertToSpark
import org.jetbrains.kotlinx.spark.api.col
import org.jetbrains.kotlinx.spark.api.gt
/**
 * Since we don't know the schema at compile time this time, we need to do
 * some schema mapping in between Spark and DataFrame.
 *
 * We will use spark/compatibilityLayer.kt to do this.
 * Take a look at that file for the implementation details!
 *
 * NOTE: You will likely need to run this function with Java 8 or 11 for it to work correctly.
 * Use the `runSparkUntypedDataset` Gradle task to do so.
 */
fun main() {
    val spark = SparkSession.builder()
        .master(SparkConf().get("spark.master", "local[*]"))
        .appName("Kotlin Spark Sample")
        .getOrCreate()
    val sc = JavaSparkContext(spark.sparkContext())

    // Creating a Spark Dataframe (untyped Dataset). Usually, this is loaded from some server or database.
    // `.toDF()` drops the Person typing, leaving a plain Dataset<Row>.
    val rawDataset: Dataset<Row> = spark.createDataset(
        listOf(
            Person(Name("Alice", "Cooper"), 15, "London", 54, true),
            Person(Name("Bob", "Dylan"), 45, "Dubai", 87, true),
            Person(Name("Charlie", "Daniels"), 20, "Moscow", null, false),
            Person(Name("Charlie", "Chaplin"), 40, "Milan", null, true),
            Person(Name("Bob", "Marley"), 30, "Tokyo", 68, true),
            Person(Name("Alice", "Wolf"), 20, null, 55, false),
            Person(Name("Charlie", "Byrd"), 30, "Moscow", 90, true),
        ),
        beanEncoderOf<Person>(),
    ).toDF()

    // we can perform large operations in Spark.
    // DataFrames are in-memory structures, so this is a good place to limit the number of rows if you don't have the RAM ;)
    val adultsOnly = rawDataset.filter(col("age") gt 17)

    // Conversion route 1: let DataFrame infer the schema from the collected rows.
    val inferredDf = adultsOnly.convertToDataFrameByInference()
    inferredDf.schema().print()
    inferredDf.print(columnTypes = true, borders = true)

    // Conversion route 2: map the full Spark schema explicitly.
    val mappedDf = adultsOnly.convertToDataFrame()
    mappedDf.schema().print()
    mappedDf.print(columnTypes = true, borders = true)

    // now we can use DataFrame-specific functions
    val ageStats = inferredDf
        .groupBy("city").aggregate {
            mean("age") into "meanAge"
            std("age") into "stdAge"
            min("age") into "minAge"
            max("age") into "maxAge"
        }
    ageStats.print(columnTypes = true, borders = true)

    // and when we want to convert a DataFrame back to Spark, we will use the `convertToSpark()` extension function
    // This performs the necessary schema mapping under the hood.
    val roundTripped = mappedDf.convertToSpark(spark, sc)
    roundTripped.printSchema()
    roundTripped.show()

    spark.stop()
}