Files
df-research/tech.ml.dataset/test/tech/v3/dataset/modelling_test.clj
2026-02-08 11:20:43 -10:00

46 lines
1.8 KiB
Clojure
Vendored

(ns tech.v3.dataset.modelling-test
(:require [tech.v3.dataset.modelling :as modelling]
[tech.v3.dataset :as ds]
[tech.v3.dataset.categorical :as ds-cat]
[tech.v3.dataset.test-utils :as test-utils]
[tech.v3.datatype :as dtype]
[clojure.test :refer [deftest is]]))
(deftest k-fold-sanity
  ;; Splits the fruit dataset into 5 folds and checks the shape of every
  ;; fold's training and test datasets. The first shape entry (7) is the same
  ;; everywhere. For each fold, the second entries of train and test sum to 59
  ;; (47+12, 48+11), so each fold's test rows complement its training rows.
  ;; Presumably dtype/shape reports [n-columns n-rows]; confirm against tech.v3.dataset.
  (let [folds     (modelling/k-fold-datasets (test-utils/mapseq-fruit-dataset) 5 {})
        shapes-of (fn [split-key]
                    (mapv (fn [fold] (dtype/shape (split-key fold))) folds))]
    (is (= 5 (count folds)))
    (is (= [[7 47] [7 47] [7 47] [7 47] [7 48]]
           (shapes-of :train-ds)))
    (is (= [[7 12] [7 12] [7 12] [7 12] [7 11]]
           (shapes-of :test-ds)))))
(deftest train-test-split-sanity
  ;; Splits the fruit dataset with default options ({}) and checks the shapes
  ;; of both halves. The first shape entry (7) is shared. The second entries
  ;; (41 and 18) add up to 59, the same total seen in the k-fold test.
  (let [{:keys [train-ds test-ds]} (modelling/train-test-split
                                     (test-utils/mapseq-fruit-dataset) {})]
    (is (= [7 41] (dtype/shape train-ds)))
    (is (= [7 18] (dtype/shape test-ds)))))
(deftest prob-dist->label-col
  ;; Builds a dataset with one probability column per candidate label
  ;; (:y-0, :y-1). It collapses them into a :y label column with
  ;; probability-distributions->label-column, then undoes the categorical
  ;; mapping. The expected labels are the per-row argmax column. Row 2
  ;; (0.3 vs 0.2) is the only row where :y-0 wins.
  ;; A single aliased ->dataset call is enough. The local is named prob-ds
  ;; so it is not confused with the `ds` namespace alias.
  (let [prob-ds      (ds/->dataset {:y-0 [0.0 0.5 0.3 0.1]
                                    :y-1 [0.3 0.8 0.2 0.3]})
        prob-dist-ds (modelling/probability-distributions->label-column prob-ds :y)
        label-ds     (ds-cat/reverse-map-categorical-xforms prob-dist-ds)]
    (is (= [:y-1 :y-1 :y-0 :y-1]
           (label-ds :y)))))
(deftest issue-267-prob-dist-fail-on-nan-missing
  ;; Regression test for issue 267. probability-distributions->label-column
  ;; must throw in two cases:
  ;;   1. a probability column contains NaN;
  ;;   2. it contains a nil entry (presumably stored as a missing value).
  ;; Throwable is kept deliberately; the concrete exception type is not
  ;; pinned by this test.
  (is (thrown? Throwable
               (-> (ds/->dataset {:y-0 [Double/NaN] :y-1 [0.3]})
                   (modelling/probability-distributions->label-column :y))))
  (is (thrown? Throwable
               (-> (ds/->dataset {:y-0 [nil] :y-1 [0.3]})
                   (modelling/probability-distributions->label-column :y)))))