Files
2026-02-08 11:20:43 -10:00

223 lines
8.7 KiB
Clojure
Vendored

(ns tech.v3.dataset.mapseq-test
(:require [tech.v3.dataset :as ds]
[tech.v3.dataset.column :as ds-col]
[tech.v3.dataset.column-filters :as cf]
[tech.v3.dataset.math :as ds-math]
[tech.v3.dataset.modelling :as ds-mod]
[tech.v3.dataset.categorical :as ds-cat]
[tech.v3.dataset.test-utils :as test-utils]
[tech.v3.datatype :as dtype]
[tech.v3.datatype.functional :as dtype-fn]
[tech.v3.tensor :as dtt]
[clojure.set :as set]
[clojure.test :refer [deftest is testing]]))
;; End-to-end test over the fruit fixture from test-utils/mapseq-fruit-dataset.
;; The `dataset` binding is the fixture with :fruit-subtype and :fruit-label
;; removed, categorical columns converted to numbers, the remaining
;; (non-categorical) columns passed through ds-math/transform-minmax, and
;; :fruit-name set as the inference target.  The assertions then cover column
;; names, labels, model types, reverse mapping of the categorical encoding,
;; grouping, filtering, value ranges, concatenation, a rolling-window column,
;; sorting and take-nth.
(deftest mapseq-classification-test
(let [src-ds (test-utils/mapseq-fruit-dataset)
;; NOTE(review): ds/bind-> presumably threads the value like -> while also
;; binding it to the symbol `ds`, which the ds/update step below refers to --
;; confirm against the macro's documentation.
dataset (ds/bind-> src-ds ds
(ds/remove-columns [:fruit-subtype :fruit-label])
(ds/categorical->number cf/categorical)
(ds/update (cf/difference ds (cf/categorical ds))
#(ds-math/transform-minmax % (ds-math/fit-minmax %)))
(ds-mod/set-inference-target :fruit-name))
;; Row-map view of a second, untouched instance of the fixture; used as the
;; reference for the original (pre-transformation) values.
mapseq-ds (ds/mapseq-reader (test-utils/mapseq-fruit-dataset))
src-keys (set (keys (first mapseq-ds)))
result-keys (->> (ds/columns dataset)
(map ds-col/column-name)
(set))
;; Names of the columns of `dataset` that cf/categorical does not select.
non-categorical (ds/column-names
(cf/difference dataset (cf/categorical dataset)))]
;; Every column of the transformed dataset must have exactly 59 elements.
(is (= #{59}
(->> (ds/columns dataset)
(map dtype/ecount)
set)))
;;Column names can be keywords.
(is (= src-keys
(set (->> (ds/columns src-ds)
(map ds-col/column-name)))))
;; The transformed dataset has the source keys minus the two removed columns.
(is (= (set/difference src-keys #{:fruit-subtype :fruit-label})
result-keys))
;; The labels reported for the dataset must equal the original :fruit-name
;; values, in order.  NOTE(review): the earlier remark here that "for
;; tablesaw, column values are never keywords" cannot be confirmed from this
;; file -- verify whether it still applies.
(is (= (mapv :fruit-name mapseq-ds)
(vec (first (vals (ds-mod/labels dataset))))))
;; With no column list, model-type is expected to report only :fruit-name.
(is (= {:fruit-name :classification}
(ds-mod/model-type dataset)))
;; With an explicit column list, every listed column gets a model type.
(is (= {:fruit-name :classification,
:mass :regression,
:width :regression,
:height :regression,
:color-score :regression}
(ds-mod/model-type dataset (ds/column-names dataset))))
;;Does the post-transformation value of fruit-name map to the
;;pre-transformation value of fruit-name?
(is (= (mapv :fruit-name mapseq-ds)
(->> (ds-cat/reverse-map-categorical-xforms dataset)
(ds/mapseq-reader)
(mapv :fruit-name))))
;; Grouping the first ten rows with clojure.core/group-by over the row maps
;; must give the same result as ds/group-by-column followed by converting each
;; group dataset to a vector of row maps.
(is (= (as-> (ds/select dataset :all (range 10)) dataset
(ds/mapseq-reader dataset)
(group-by :fruit-name dataset))
(as-> (ds/select dataset :all (range 10)) ds
(ds/group-by-column ds :fruit-name)
(map (fn [[k group-ds]]
[k (vec (ds/mapseq-reader group-ds))])
ds)
(into {} ds))))
;;forward map from input value to encoded value.
;;After ETL, column values are all doubles
(let [apple-value (get (ds-mod/inference-target-label-map dataset) :apple)]
;; Filtering on the encoded apple value and then reversing the encoding must
;; leave only :apple as the fruit name.
(is (= #{:apple}
(as-> dataset ds
(ds/filter ds #(= apple-value (:fruit-name %)))
;;Reverse the categorical encoding so the numeric fruit name maps back
;;to the input label, then read the rows back as maps.
(ds-cat/reverse-map-categorical-xforms ds)
(ds/mapseq-reader ds)
(map :fruit-name ds)
(set ds)))))
;; Ensure range map works
;; Each non-categorical column is expected to have min -0.5 and max 0.5 after
;; the transform-minmax step above.
(is (= (vec (repeat (count non-categorical) [-0.5 0.5]))
(->> non-categorical
(mapv (fn [colname]
(let [{col-min :min
col-max :max} (-> (ds/column dataset colname)
(ds-col/stats [:min :max]))]
[col-min col-max]))))))
;;Concatenation should work
(is (= (mapv :fruit-name
(concat mapseq-ds mapseq-ds))
(->> (-> (ds/concat dataset dataset)
(ds-cat/reverse-map-categorical-xforms)
(ds/mapseq-reader))
(mapv :fruit-name))))
;; Builds a fresh 20-row dataset whose only column :mass holds 0..19.  Inside
;; this bind-> form the symbol `dataset` refers to that new dataset and shadows
;; the outer `dataset` binding.
(let [new-ds (ds/bind-> (ds/->dataset (map hash-map (repeat :mass) (range 20))) dataset
;;The mean should happen in double or floating point space.
(assoc :mass-avg
(dtype-fn/fixed-rolling-window
(dtype/elemwise-cast (dataset :mass) :float64)
5 dtype-fn/mean)))]
;; NOTE(review): the expected edge values (0.6, 1.2 here and 18.4, 17.8 below)
;; are consistent with a centered 5-wide window that repeats the first/last
;; element at the boundaries -- confirm the edge mode against the library docs.
(is (= [{:mass 0, :mass-avg 0.6}
{:mass 1, :mass-avg 1.2}
{:mass 2, :mass-avg 2.0}
{:mass 3, :mass-avg 3.0}
{:mass 4, :mass-avg 4.0}
{:mass 5, :mass-avg 5.0}
{:mass 6, :mass-avg 6.0}
{:mass 7, :mass-avg 7.0}
{:mass 8, :mass-avg 8.0}
{:mass 9, :mass-avg 9.0}]
(-> (ds/select new-ds [:mass :mass-avg] (range 10))
ds/mapseq-reader)))
;; Sorting by :mass-avg with the > comparator is expected to put the largest
;; averages first.
(let [sorted-ds (ds/sort-by-column new-ds :mass-avg >)]
(is (= [{:mass 19, :mass-avg 18.4}
{:mass 18, :mass-avg 17.8}
{:mass 17, :mass-avg 17.0}
{:mass 16, :mass-avg 16.0}
{:mass 15, :mass-avg 15.0}
{:mass 14, :mass-avg 14.0}
{:mass 13, :mass-avg 13.0}
{:mass 12, :mass-avg 12.0}
{:mass 11, :mass-avg 11.0}
{:mass 10, :mass-avg 10.0}]
(-> (ds/select sorted-ds [:mass :mass-avg] (range 10))
ds/mapseq-reader)))))
;; Takes every fifth row of the untransformed source dataset.
;; NOTE(review): the expected shape [7 12] presumably reads as
;; [column-count row-count] -- confirm.
(let [nth-db (ds/take-nth src-ds 5)]
(is (= [7 12] (dtype/shape nth-db)))
;; :width is coerced with int before comparison so the expected literals can
;; be written as integers.
(is (= [{:mass 192.0, :width 8}
{:mass 80.0, :width 5}
{:mass 166.0, :width 6}
{:mass 156.0, :width 7}
{:mass 160.0, :width 7}
{:mass 356.0, :width 9}
{:mass 158.0, :width 7}
{:mass 150.0, :width 7}
{:mass 154.0, :width 7}
{:mass 186.0, :width 7}]
(->> (-> (ds/select nth-db [:mass :width] (range 10))
ds/mapseq-reader)
(map #(update % :width int))))))))
;; Checks one-hot encoding of :fruit-name on the fruit fixture: the recorded
;; one-hot map, the resulting column names, the labels, and the model type
;; reported for every column.
(deftest one-hot
  (testing "Testing one-hot into multiple column groups"
    (let [fruit-ds   (test-utils/mapseq-fruit-dataset)
          ;; Same pipeline as before, spelled out one step per binding.
          trimmed-ds (ds/remove-columns fruit-ds [:fruit-subtype :fruit-label])
          target-ds  (ds-mod/set-inference-target trimmed-ds :fruit-name)
          encoded-ds (ds/categorical->one-hot target-ds [:fruit-name])]
      ;; The first recorded one-hot map describes the :fruit-name expansion.
      (is (= {:src-column      :fruit-name
              :result-datatype :int64
              :one-hot-table   {:apple    :fruit-name-apple
                                :lemon    :fruit-name-lemon
                                :mandarin :fruit-name-mandarin
                                :orange   :fruit-name-orange}}
             (->> encoded-ds
                  ds-cat/dataset->one-hot-maps
                  first
                  (into {}))))
      ;; :fruit-name is replaced by one column per fruit; the numeric columns stay.
      (is (= #{:mass :width :height :color-score
               :fruit-name-apple :fruit-name-lemon
               :fruit-name-mandarin :fruit-name-orange}
             (set (map ds-col/column-name (ds/columns encoded-ds)))))
      ;; The first twenty labels must match the first twenty source fruit names.
      (is (= (mapv :fruit-name (take 20 (ds/mapseq-reader fruit-ds)))
             (vec (take 20 (first (vals (ds-mod/labels encoded-ds)))))))
      ;; One-hot columns are classification targets; the rest are regression.
      (is (= {:mass                :regression
              :width               :regression
              :height              :regression
              :color-score         :regression
              :fruit-name-apple    :classification
              :fruit-name-lemon    :classification
              :fruit-name-mandarin :classification
              :fruit-name-orange   :classification}
             (ds-mod/model-type encoded-ds (ds/column-names encoded-ds)))))))
;; A map sequence whose :b value is itself a map (and is absent from the second
;; row) is expected to yield an :int64 column and a :persistent-map column.
(deftest generalized-mapseq-ds
  (let [rows    [{:a 1 :b {:a 1 :b 2}}
                 {:a 2}]
        dataset (ds/->dataset rows)]
    (is (= #{:int64 :persistent-map}
           (->> (vals dataset)
                (map dtype/get-datatype)
                set)))))
;; Rows that carry a tensor under :a and a string under :b are expected to
;; produce one :tensor column and one :string column.
(deftest tensors-in-mapseq
  (let [;; Builds a fresh tensor from (range 9) partitioned into threes for each row.
        make-tensor (fn [] (dtt/->tensor (partition 3 (range 9))))
        dataset     (ds/->dataset [{:a (make-tensor) :b "hello"}
                                   {:a (make-tensor) :b "goodbye"}])]
    (is (= #{:tensor :string}
           (into #{} (map dtype/get-datatype) (vals dataset))))))
;; With :d parsed as :local-date, exactly one entry (the nil row) is expected to
;; be recorded as missing.
(deftest datetime-missing
  (let [rows       [{:d "1971-01-01"}
                    {:d "1970-01-01"}
                    {:d nil}
                    {:d "0001-01-01"}]
        parse-opts {:parser-fn {:d :local-date}}
        dataset    (ds/->dataset rows parse-opts)]
    (is (= 1 (-> (dataset :d)
                 ds-col/missing
                 dtype/ecount)))))