init research
This commit is contained in:
@@ -0,0 +1,52 @@
|
||||
import org.jetbrains.kotlin.gradle.dsl.JvmTarget
|
||||
|
||||
// Gradle plugins applied to this example project, in application order.
plugins {
    // Provides the `run` task and the `application { }` extension.
    id("application")
    // Kotlin/JVM compilation.
    id("org.jetbrains.kotlin.jvm")

    // Kotlin DataFrame Gradle plugin; supplies the `dataframes { }` extension used in this script.
    id("org.jetbrains.kotlinx.dataframe")

    // only mandatory if `kotlin.dataframe.add.ksp=false` in gradle.properties
    id("com.google.devtools.ksp")
}
|
||||
|
||||
// Where dependencies are resolved from; repositories are consulted in declaration order.
repositories {
    mavenCentral()
    // in case of local dataframe development
    mavenLocal()
}
|
||||
|
||||
// Entry point launched by the `run` task: the class generated for the top-level `main` in Titanic.kt.
application {
    mainClass.set("org.jetbrains.kotlinx.dataframe.examples.titanic.ml.TitanicKt")
}
|
||||
|
||||
dependencies {
    // Outside of this repository, depend on a published artifact instead:
    // implementation("org.jetbrains.kotlinx:dataframe:X.Y.Z")
    implementation(project(":"))

    // KotlinDL modules, all on the same version.
    // note: needs to target java 11 for these dependencies
    listOf(
        "org.jetbrains.kotlinx:kotlin-deeplearning-api:0.5.2",
        "org.jetbrains.kotlinx:kotlin-deeplearning-impl:0.5.2",
        "org.jetbrains.kotlinx:kotlin-deeplearning-tensorflow:0.5.2",
        "org.jetbrains.kotlinx:kotlin-deeplearning-dataset:0.5.2",
    ).forEach { coordinates ->
        implementation(coordinates)
    }
}
|
||||
|
||||
// Schema generation for the DataFrame Gradle plugin: declares the `Passenger` schema
// from the sample CSV so that the example code can refer to it.
dataframes {
    schema {
        // Fully qualified name of the schema to generate.
        name = "org.jetbrains.kotlinx.dataframe.examples.titanic.ml.Passenger"
        // Sample file the schema is inferred from.
        data = "src/main/resources/titanic.csv"
        csvOptions {
            // The sample file is semicolon-separated.
            delimiter = ';'
        }
    }
}
|
||||
|
||||
// Kotlin compiler settings: emit Java 11 bytecode and compile against the JDK 11 API.
kotlin {
    compilerOptions {
        jvmTarget.set(JvmTarget.JVM_11)
        freeCompilerArgs.add("-Xjdk-release=11")
    }
}
|
||||
|
||||
// Pin every Java compilation task to Java 11, in line with the Kotlin `jvmTarget` set in this script.
// `configureEach` registers the configuration lazily, so it only runs for tasks that are
// actually realized, instead of eagerly creating every JavaCompile task at configuration time.
tasks.withType<JavaCompile>().configureEach {
    val javaVersion = JavaVersion.VERSION_11.toString()
    sourceCompatibility = javaVersion
    targetCompatibility = javaVersion
    // `--release` additionally restricts the visible JDK API to Java 11.
    options.release.set(11)
}
|
||||
+95
@@ -0,0 +1,95 @@
|
||||
package org.jetbrains.kotlinx.dataframe.examples.titanic.ml
|
||||
|
||||
import org.jetbrains.kotlinx.dataframe.ColumnSelector
|
||||
import org.jetbrains.kotlinx.dataframe.DataFrame
|
||||
import org.jetbrains.kotlinx.dataframe.api.*
|
||||
import org.jetbrains.kotlinx.dl.api.core.Sequential
|
||||
import org.jetbrains.kotlinx.dl.api.core.activation.Activations
|
||||
import org.jetbrains.kotlinx.dl.api.core.initializer.HeNormal
|
||||
import org.jetbrains.kotlinx.dl.api.core.initializer.Zeros
|
||||
import org.jetbrains.kotlinx.dl.api.core.layer.core.Dense
|
||||
import org.jetbrains.kotlinx.dl.api.core.layer.core.Input
|
||||
import org.jetbrains.kotlinx.dl.api.core.loss.Losses
|
||||
import org.jetbrains.kotlinx.dl.api.core.metric.Metrics
|
||||
import org.jetbrains.kotlinx.dl.api.core.optimizer.Adam
|
||||
import org.jetbrains.kotlinx.dl.dataset.OnHeapDataset
|
||||
import java.util.Locale
|
||||
|
||||
/** Seed handed to the [HeNormal] kernel initializer of every dense layer. */
private const val SEED = 12L

/** Number of epochs passed to `fit`. */
private const val EPOCHS = 50

/** Batch size used while fitting the model. */
private const val TRAINING_BATCH_SIZE = 50

/** Batch size used while evaluating the model on the test split. */
private const val TEST_BATCH_SIZE = 100
|
||||
|
||||
/**
 * Builds a [Dense] layer with the initializers shared by every layer of [model]:
 * a [HeNormal] kernel initializer seeded with [SEED] and a [Zeros] bias initializer.
 */
private fun seededDense(units: Int, activation: Activations): Dense =
    Dense(units, activation, kernelInitializer = HeNormal(SEED), biasInitializer = Zeros())

/**
 * Feed-forward network: an input of size 9, two ReLU layers of 50 units each and a
 * linear output layer of 2 units.
 */
private val model = Sequential.of(
    Input(9),
    seededDense(50, Activations.Relu),
    seededDense(50, Activations.Relu),
    seededDense(2, Activations.Linear),
)
|
||||
|
||||
/**
 * Reads the Titanic passengers, prepares the feature columns, splits them 70/30 into
 * train and test datasets, then trains [model] and prints its accuracy on the test split.
 */
fun main() {
    // Set Locale for correct number parsing
    Locale.setDefault(Locale.FRANCE)

    val passengers = Passenger.readCsv()

    val features = passengers
        // imputing: missing numeric values become the mean of their column,
        // a missing sex becomes "female"
        .fillNulls { sibsp and parch and age and fare }.perCol { column -> column.mean() }
        .fillNulls { sex }.with { "female" }
        // one hot encoding
        .pivotMatches { pclass and sex }
        // feature extraction
        .select { survived and pclass and sibsp and parch and age and fare and sex }

    // `survived` is the label column; everything else becomes the feature matrix.
    val (train, test) = features
        .shuffle()
        .toTrainTest(0.7) { survived }

    // `use` closes the model once training and evaluation are finished.
    model.use { network ->
        network.compile(
            optimizer = Adam(),
            loss = Losses.SOFT_MAX_CROSS_ENTROPY_WITH_LOGITS,
            metric = Metrics.ACCURACY,
        )
        network.summary()

        network.fit(dataset = train, epochs = EPOCHS, batchSize = TRAINING_BATCH_SIZE)

        val evaluation = network.evaluate(dataset = test, batchSize = TEST_BATCH_SIZE)
        val accuracy = evaluation.metrics[Metrics.ACCURACY]

        println("Accuracy: $accuracy")
    }
}
|
||||
|
||||
/**
 * Converts this dataframe into an [OnHeapDataset] and splits it in two.
 *
 * @param trainRatio ratio forwarded to [OnHeapDataset.split].
 * @param yColumn selects the label column; all remaining columns are used as features.
 * @return the pair of datasets produced by the split, destructured by the caller as (train, test).
 */
fun <T> DataFrame<T>.toTrainTest(
    trainRatio: Double,
    yColumn: ColumnSelector<T, Number>,
): Pair<OnHeapDataset, OnHeapDataset> {
    val fullDataset = toOnHeapDataset(yColumn)
    return fullDataset.split(trainRatio)
}
|
||||
|
||||
/**
 * Wraps this dataframe in an [OnHeapDataset], using [yColumn] as the labels
 * and every other column as the features.
 */
private fun <T> DataFrame<T>.toOnHeapDataset(yColumn: ColumnSelector<T, Number>): OnHeapDataset {
    val source = this
    return OnHeapDataset.create(
        dataframe = source,
        yColumn = yColumn,
    )
}
|
||||
|
||||
/**
 * Creates an [OnHeapDataset] from [dataframe].
 *
 * The column selected by [yColumn] supplies the labels; all other columns are converted
 * to `Float` and packed row by row into the feature matrix.
 */
private fun <T> OnHeapDataset.Companion.create(
    dataframe: DataFrame<T>,
    yColumn: ColumnSelector<T, Number>,
): OnHeapDataset {
    // Features: drop the label column, convert every non-group column to Float, then
    // merge the Float columns of each row into a single FloatArray held in column "X".
    val featuresGenerator: () -> Array<FloatArray> = {
        val floatFeatures = dataframe.remove(yColumn)
            .convert { colsAtAnyDepth().filter { !it.isColumnGroup() } }.toFloat()
        val packedRows = floatFeatures
            .merge { colsAtAnyDepth().colsOf<Float>() }.by { it.toFloatArray() }.into("X")
        packedRows.getColumn("X").cast<FloatArray>().toTypedArray()
    }

    // Labels: the selected column as a flat FloatArray.
    val labelGenerator: () -> FloatArray = {
        dataframe.get(yColumn).toFloatArray()
    }

    return create(featuresGenerator, labelGenerator)
}
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user