(ns tech.v3.dataset.mapseq-test
  (:require [tech.v3.dataset :as ds]
            [tech.v3.dataset.column :as ds-col]
            [tech.v3.dataset.column-filters :as cf]
            [tech.v3.dataset.math :as ds-math]
            [tech.v3.dataset.modelling :as ds-mod]
            [tech.v3.dataset.categorical :as ds-cat]
            [tech.v3.dataset.test-utils :as test-utils]
            [tech.v3.datatype :as dtype]
            [tech.v3.datatype.functional :as dtype-fn]
            [tech.v3.tensor :as dtt]
            [clojure.set :as set]
            [clojure.test :refer [deftest is testing]]))

;; NOTE: every `;;` comment must sit on its own physical line.  The reader
;; discards everything from `;` to end-of-line, so code placed after a comment
;; on the same line silently disappears and unbalances the enclosing forms.


;; Builds a dataset from the fruit sequence-of-maps fixture, runs it through a
;; small pipeline (drop columns, categorical->number, min-max scaling of the
;; non-categorical columns, set inference target) and checks that names,
;; labels, model types, grouping, filtering, value ranges, concatenation,
;; rolling windows, sorting and take-nth give the expected results.
(deftest mapseq-classification-test
  (let [src-ds (test-utils/mapseq-fruit-dataset)
        dataset (ds/bind-> src-ds ds
                  (ds/remove-columns [:fruit-subtype :fruit-label])
                  (ds/categorical->number cf/categorical)
                  (ds/update (cf/difference ds (cf/categorical ds))
                             #(ds-math/transform-minmax % (ds-math/fit-minmax %)))
                  (ds-mod/set-inference-target :fruit-name))
        ;; Plain sequence-of-maps view of the untouched fixture; used as the
        ;; reference for the assertions below.
        mapseq-ds (ds/mapseq-reader (test-utils/mapseq-fruit-dataset))
        src-keys (set (keys (first mapseq-ds)))
        result-keys (->> (ds/columns dataset)
                         (map ds-col/column-name)
                         (set))
        non-categorical (ds/column-names
                         (cf/difference dataset (cf/categorical dataset)))]

    ;; Every column of the processed dataset has the same element count.
    (is (= #{59}
           (->> (ds/columns dataset)
                (map dtype/ecount)
                set)))

    ;;Column names can be keywords.
    (is (= src-keys
           (set (->> (ds/columns src-ds)
                     (map ds-col/column-name)))))

    (is (= (set/difference src-keys #{:fruit-subtype :fruit-label})
           result-keys))

    ;; Map back from values to keys for labels.  For tablesaw, column values
    ;; are never keywords.
    (is (= (mapv :fruit-name mapseq-ds)
           (vec (first (vals (ds-mod/labels dataset))))))

    (is (= {:fruit-name :classification}
           (ds-mod/model-type dataset)))

    (is (= {:fruit-name :classification,
            :mass :regression,
            :width :regression,
            :height :regression,
            :color-score :regression}
           (ds-mod/model-type dataset (ds/column-names dataset))))

    ;;Does the post-transformation value of fruit-name map to the
    ;;pre-transformation value of fruit-name?
    (is (= (mapv :fruit-name mapseq-ds)
           (->> (ds-cat/reverse-map-categorical-xforms dataset)
                (ds/mapseq-reader)
                (mapv :fruit-name))))

    ;; Grouping the row maps with clojure.core/group-by must agree with
    ;; ds/group-by-column on the first ten rows.
    (is (= (as-> (ds/select dataset :all (range 10)) dataset
             (ds/mapseq-reader dataset)
             (group-by :fruit-name dataset))
           (as-> (ds/select dataset :all (range 10)) ds
             (ds/group-by-column ds :fruit-name)
             (map (fn [[k group-ds]]
                    [k (vec (ds/mapseq-reader group-ds))])
                  ds)
             (into {} ds))))

    ;;forward map from input value to encoded value.
    ;;After ETL, column values are all doubles
    (let [apple-value (get (ds-mod/inference-target-label-map dataset) :apple)]
      (is (= #{:apple}
             (as-> dataset ds
               (ds/filter ds #(= apple-value (:fruit-name %)))
               ;;Use full version of ->flyweight to do reverse mapping of numeric
               ;;fruit name back to input label.
               (ds-cat/reverse-map-categorical-xforms ds)
               (ds/mapseq-reader ds)
               (map :fruit-name ds)
               (set ds)))))

    ;; Ensure range map works
    (is (= (vec (repeat (count non-categorical) [-0.5 0.5]))
           (->> non-categorical
                (mapv (fn [colname]
                        (let [{col-min :min
                               col-max :max} (-> (ds/column dataset colname)
                                                 (ds-col/stats [:min :max]))]
                          [col-min col-max]))))))

    ;;Concatenation should work
    (is (= (mapv :fruit-name (concat mapseq-ds mapseq-ds))
           (->> (-> (ds/concat dataset dataset)
                    (ds-cat/reverse-map-categorical-xforms)
                    (ds/mapseq-reader))
                (mapv :fruit-name))))

    ;; Rolling mean over a fresh 20-row :mass dataset.  Note that `dataset`
    ;; below is the bind-> binding name and shadows the outer `dataset` only
    ;; inside this bind-> form.
    (let [new-ds (ds/bind-> (ds/->dataset (map hash-map (repeat :mass) (range 20)))
                     dataset
                   ;;The mean should happen in double or floating point space.
                   (assoc :mass-avg
                          (dtype-fn/fixed-rolling-window
                           (dtype/elemwise-cast (dataset :mass) :float64)
                           5
                           dtype-fn/mean)))]
      (is (= [{:mass 0, :mass-avg 0.6}
              {:mass 1, :mass-avg 1.2}
              {:mass 2, :mass-avg 2.0}
              {:mass 3, :mass-avg 3.0}
              {:mass 4, :mass-avg 4.0}
              {:mass 5, :mass-avg 5.0}
              {:mass 6, :mass-avg 6.0}
              {:mass 7, :mass-avg 7.0}
              {:mass 8, :mass-avg 8.0}
              {:mass 9, :mass-avg 9.0}]
             (-> (ds/select new-ds [:mass :mass-avg] (range 10))
                 ds/mapseq-reader)))

      ;; Descending sort on the rolling average.
      (let [sorted-ds (ds/sort-by-column new-ds :mass-avg >)]
        (is (= [{:mass 19, :mass-avg 18.4}
                {:mass 18, :mass-avg 17.8}
                {:mass 17, :mass-avg 17.0}
                {:mass 16, :mass-avg 16.0}
                {:mass 15, :mass-avg 15.0}
                {:mass 14, :mass-avg 14.0}
                {:mass 13, :mass-avg 13.0}
                {:mass 12, :mass-avg 12.0}
                {:mass 11, :mass-avg 11.0}
                {:mass 10, :mass-avg 10.0}]
               (-> (ds/select sorted-ds [:mass :mass-avg] (range 10))
                   ds/mapseq-reader)))))

    ;; Every fifth row of the source dataset.
    (let [nth-db (ds/take-nth src-ds 5)]
      (is (= [7 12] (dtype/shape nth-db)))
      (is (= [{:mass 192.0, :width 8}
              {:mass 80.0, :width 5}
              {:mass 166.0, :width 6}
              {:mass 156.0, :width 7}
              {:mass 160.0, :width 7}
              {:mass 356.0, :width 9}
              {:mass 158.0, :width 7}
              {:mass 150.0, :width 7}
              {:mass 154.0, :width 7}
              {:mass 186.0, :width 7}]
             (->> (-> (ds/select nth-db [:mass :width] (range 10))
                      ds/mapseq-reader)
                  ;; Widths are compared after truncation to int.
                  (map #(update % :width int))))))))


;; One-hot encodes :fruit-name after marking it as the inference target and
;; checks the recorded one-hot map, the resulting column names, the labels and
;; the per-column model types.
(deftest one-hot
  (testing "Testing one-hot into multiple column groups"
    (let [src-ds (test-utils/mapseq-fruit-dataset)
          dataset (-> src-ds
                      (ds/remove-columns [:fruit-subtype :fruit-label])
                      (ds-mod/set-inference-target :fruit-name)
                      (ds/categorical->one-hot [:fruit-name]))]
      (is (= {:one-hot-table {:orange :fruit-name-orange,
                              :mandarin :fruit-name-mandarin,
                              :apple :fruit-name-apple,
                              :lemon :fruit-name-lemon},
              :src-column :fruit-name,
              :result-datatype :int64}
             (into {} (first (ds-cat/dataset->one-hot-maps dataset)))))

      (is (= #{:mass :fruit-name-orange :fruit-name-mandarin :width
               :fruit-name-apple :color-score :fruit-name-lemon :height}
             (->> (ds/columns dataset)
                  (map ds-col/column-name)
                  set)))

      ;; First twenty labels must match the first twenty source rows.
      (is (= (->> (ds/mapseq-reader src-ds)
                  (take 20)
                  (mapv :fruit-name))
             (->> (first (vals (ds-mod/labels dataset)))
                  (take 20)
                  vec)))

      (is (= {:color-score :regression,
              :fruit-name-orange :classification,
              :fruit-name-lemon :classification,
              :fruit-name-mandarin :classification,
              :fruit-name-apple :classification,
              :height :regression
              :width :regression,
              :mass :regression}
             (ds-mod/model-type dataset (ds/column-names dataset)))))))


;; A sequence of maps whose values include a nested map (and a row missing
;; that key) yields one :int64 column and one :persistent-map column.
(deftest generalized-mapseq-ds
  (let [ds (ds/->dataset [{:a 1 :b {:a 1 :b 2}}
                          {:a 2}])]
    (is (= #{:int64 :persistent-map}
           (set (map dtype/get-datatype (vals ds)))))))


;; Tensor values in a sequence of maps yield a :tensor column alongside the
;; :string column.
(deftest tensors-in-mapseq
  (let [ds (ds/->dataset [{:a (dtt/->tensor (partition 3 (range 9)))
                           :b "hello"}
                          {:a (dtt/->tensor (partition 3 (range 9)))
                           :b "goodbye"}])]
    (is (= #{:tensor :string}
           (set (map dtype/get-datatype (vals ds)))))))


;; With :d parsed as :local-date, only the nil entry is reported as missing.
(deftest datetime-missing
  (let [ds (ds/->dataset [{:d "1971-01-01"}
                          {:d "1970-01-01"}
                          {:d nil}
                          {:d "0001-01-01"}]
                         {:parser-fn {:d :local-date}})]
    (is (= 1 (dtype/ecount (ds-col/missing (ds :d)))))))