(ns tech.v3.dataset.tribuo-test
  (:require [tech.v3.datatype :as dtype]
            [tech.v3.datatype.functional :as dfn]
            [tech.v3.dataset :as ds]
            [tech.v3.dataset.modelling :as ds-model]
            [tech.v3.dataset.categorical :as ds-cat]
            [tech.v3.libs.tribuo :as tribuo]
            [clojure.test :refer [deftest is testing]])
  (:import [org.tribuo.classification.xgboost XGBoostClassificationTrainer]
           [org.tribuo.classification.sgd.linear LogisticRegressionTrainer]
           [org.tribuo.regression.xgboost XGBoostRegressionTrainer]))


(defn classification-example-ds
  "Build a small synthetic classification dataset with columns :x, :y and :label.

  `x` is either an integer, in which case that many values are drawn with
  `rand`, or an existing sequence of numbers that is used as-is.  :y holds
  `(count x)` further `rand` draws.  :label is \"green\" when the matching x
  value lies strictly between 0.25 and 0.75 and \"red\" otherwise."
  [x]
  (let [x (if (integer? x)
            (vec (repeatedly x rand))
            x)
        y (repeatedly (count x) rand)
        label (dtype/emap #(if (< 0.25 % 0.75) "green" "red") :string x)]
    (ds/->dataset {:x x
                   :y y
                   :label label})))


(deftest classification-pathway
  (let [ds (classification-example-ds 10000)]
    (testing "xgboost"
      (let [;;This is not necessary for tribuo's classification pathway.  Below is just setup
            ;;to test that if someone has a pipeline that is already using categorical mapping
            ;;they can still classify without too many changes.
            cat-data (ds-cat/fit-categorical-map ds :label)
            ds (ds-cat/transform-categorical-map ds cat-data)
            {:keys [test-ds train-ds]} (ds-model/train-test-split ds)
            ;;You can pull in many different classification trainers
            model (tribuo/train-classification (XGBoostClassificationTrainer. 6)
                                               train-ds :label)
            predict-ds (tribuo/predict-classification
                        model
                        (ds/remove-columns test-ds [:label]))
            num-correct (dfn/sum
                         (dfn/eq (predict-ds :prediction)
                                 ;;reverse map the categorical mapping to get back string
                                 ;;labels
                                 (-> (ds-cat/invert-categorical-map test-ds cat-data)
                                     (ds/column :label))))
            accuracy (/ num-correct (ds/row-count test-ds))]
        (is (not (nil? (ds/column predict-ds :prediction))))
        ;;Everything besides :prediction should be exactly two columns.
        (is (== 2 (ds/column-count (ds/drop-columns predict-ds [:prediction]))))
        (is (> accuracy 0.9))))
    (testing "logistic"
      (let [{:keys [test-ds train-ds]} (ds-model/train-test-split ds)
            model (tribuo/train-classification (LogisticRegressionTrainer.)
                                               train-ds :label)
            predict-ds (tribuo/predict-classification
                        model
                        (ds/remove-columns test-ds [:label]))]
        ;;Smoke test only: the label is a band on :x, which is not linearly
        ;;separable, so no accuracy threshold is asserted for a linear model.
        ;;We do check that the pathway actually produced predictions.
        (is (not (nil? (ds/column predict-ds :prediction))))))))


(deftest regression-pathway
  (let [ds (ds/->dataset "test/data/winequality-red.csv" {:separator \;})
        ;;Single source of truth for the target column name.
        target-cname "quality"
        {:keys [test-ds train-ds]} (ds-model/train-test-split ds)
        model (tribuo/train-regression (XGBoostRegressionTrainer. 50)
                                       train-ds target-cname)
        predictions (tribuo/predict-regression
                     model
                     (ds/remove-columns test-ds [target-cname]))
        ;;Mean absolute error between predictions and the held-out target.
        mae (-> (dfn/- (predictions :prediction) (test-ds target-cname))
                (dfn/abs)
                (dfn/mean))]
    (is (< mae 0.5))))


(deftest tribuo-trainer
  ;;Builds a trainer from a configuration description and checks that the
  ;;resulting object has the requested concrete class.
  (let [config-components [{:name "trainer"
                            :type "org.tribuo.classification.dtree.CARTClassificationTrainer"
                            :properties {:maxDepth "6"
                                         :seed "12345"
                                         :fractionFeaturesInSplit "0.5"}}]
        trainer (tribuo/trainer config-components "trainer")]
    (is (= "class org.tribuo.classification.dtree.CARTClassificationTrainer"
           (str (class trainer))))))


(deftest test-keyword-name
  ;;Both a string and a keyword dataset name must yield a truthy datasource.
  (testing "string name (OK)"
    (is (-> (ds/->dataset [{"a" 1}] {:dataset-name "string name"})
            (tribuo/make-regression-datasource "a"))))
  (testing "keyword name (Error)"
    (is (-> (ds/->dataset [{"a" 1}] {:dataset-name :keyword/name})
            (tribuo/make-regression-datasource "a")))))