90 lines
4.0 KiB
Clojure
Vendored
90 lines
4.0 KiB
Clojure
Vendored
(ns tech.v3.dataset.tribuo-test
|
|
(:require [tech.v3.datatype :as dtype]
|
|
[tech.v3.datatype.functional :as dfn]
|
|
[tech.v3.dataset :as ds]
|
|
[tech.v3.dataset.modelling :as ds-model]
|
|
[tech.v3.dataset.categorical :as ds-cat]
|
|
[tech.v3.libs.tribuo :as tribuo]
|
|
[clojure.test :refer [deftest is testing]])
|
|
(:import [org.tribuo.classification.xgboost XGBoostClassificationTrainer]
|
|
[org.tribuo.classification.sgd.linear LogisticRegressionTrainer]
|
|
[org.tribuo.regression.xgboost XGBoostRegressionTrainer]))
|
|
|
|
|
|
(defn classification-example-ds
  "Build a synthetic dataset with columns `:x`, `:y` and `:label` for the
  classification tests.

  `x` is either an integer row count - in which case that many values are
  drawn with `rand` - or an existing collection of numbers that is used
  directly as the `:x` column.  `:y` holds one fresh `rand` draw per `:x`
  value.  `:label` is \"green\" when the `:x` value lies strictly between
  0.25 and 0.75, and \"red\" otherwise."
  [x]
  (let [xs       (if-not (integer? x)
                   x
                   (->> (repeatedly x rand) vec))
        n-rows   (count xs)
        ys       (repeatedly n-rows rand)
        ;; Only :x decides the label; :y carries no signal.
        color-of (fn [v]
                   (if (< 0.25 v 0.75)
                     "green"
                     "red"))
        labels   (dtype/emap color-of :string xs)]
    (ds/->dataset {:x     xs
                   :y     ys
                   :label labels})))
|
|
|
|
|
|
(deftest classification-pathway
  ;; Trains two tribuo classifiers on the synthetic dataset produced by
  ;; `classification-example-ds` and checks the shape of the predictions.
  (let [ds (classification-example-ds 10000)]
    (testing "xgboost"
      (let [;;This is not necessary for tribuo's classification pathway. Below is just setup
            ;;to test that if someone has a pipeline that is already using categorical mapping
            ;;they can still classify without too many changes.
            cat-data (ds-cat/fit-categorical-map ds :label)
            ds (ds-cat/transform-categorical-map ds cat-data)
            {:keys [test-ds train-ds]} (ds-model/train-test-split ds)
            ;;You can pull in many different classification trainers
            model (tribuo/train-classification (XGBoostClassificationTrainer. 6) train-ds :label)
            predict-ds (tribuo/predict-classification model (ds/remove-columns test-ds [:label]))
            num-correct (dfn/sum (dfn/eq (predict-ds :prediction)
                                         ;;reverse map the categorical mapping to get back string
                                         ;;labels
                                         (-> (ds-cat/invert-categorical-map test-ds cat-data)
                                             (ds/column :label))))
            accuracy (/ num-correct (ds/row-count test-ds))]
        (is (not (nil? (ds/column predict-ds :prediction))))
        (is (== 2 (ds/column-count (ds/drop-columns predict-ds [:prediction]))))
        (is (> accuracy 0.9))))
    (testing "logistic"
      ;; Uses the original dataset with plain string labels (no categorical map).
      (let [{:keys [test-ds train-ds]} (ds-model/train-test-split ds)
            model (tribuo/train-classification (LogisticRegressionTrainer.) train-ds :label)
            predict-ds (tribuo/predict-classification model (ds/remove-columns test-ds [:label]))]
        ;; Previously this block asserted nothing, so it could only fail by
        ;; throwing.  No accuracy threshold is asserted here: the label is
        ;; "green" only for mid-range :x values, which a linear model is not
        ;; expected to separate well.
        (is (not (nil? (ds/column predict-ds :prediction))))
        (is (== (ds/row-count test-ds) (ds/row-count predict-ds)))))))
|
|
|
|
|
|
(deftest regression-pathway
  ;; Trains an XGBoost regressor on the red-wine quality CSV and checks that the
  ;; mean absolute error on the held-out split stays below 0.5.
  (let [ds (ds/->dataset "test/data/winequality-red.csv" {:separator \;})
        ;; Single source of truth for the target column; it was previously bound
        ;; but unused while the literal was repeated below.
        target-cname "quality"
        {:keys [test-ds train-ds]} (ds-model/train-test-split ds)
        model (tribuo/train-regression (XGBoostRegressionTrainer. 50) train-ds target-cname)
        ;; Drop the target before predicting so the model cannot see it.
        predictions (tribuo/predict-regression model (ds/remove-columns test-ds [target-cname]))
        mae (-> (dfn/- (predictions :prediction) (test-ds target-cname))
                (dfn/abs)
                (dfn/mean))]
    (is (< mae 0.5))))
|
|
|
|
(deftest tribuo-trainer
  ;; Builds a trainer from a one-component configuration and checks the
  ;; concrete class of the returned object.
  (let [cart-component    {:name "trainer"
                           :type "org.tribuo.classification.dtree.CARTClassificationTrainer"
                           :properties {:maxDepth "6"
                                        :seed "12345"
                                        :fractionFeaturesInSplit "0.5"}}
        built             (tribuo/trainer [cart-component] "trainer")
        ;; `str` on a java.lang.Class yields "class <fully.qualified.Name>".
        class-description (-> built class str)]
    (is (= "class org.tribuo.classification.dtree.CARTClassificationTrainer"
           class-description))))
|
|
|
|
|
|
(deftest test-keyword-name
  ;; Builds a regression datasource from a one-row dataset under two kinds of
  ;; dataset name and checks that each call returns a truthy value.
  (testing "string name (OK)"
    (is (tribuo/make-regression-datasource
         (ds/->dataset [{"a" 1}] {:dataset-name "string name"})
         "a")))
  ;; NOTE(review): the label says "(Error)" but the assertion expects success;
  ;; presumably the label records the behaviour before a fix - confirm.
  (testing "keyword name (Error)"
    (is (tribuo/make-regression-datasource
         (ds/->dataset [{"a" 1}] {:dataset-name :keyword/name})
         "a"))))
|