(ns tech.v3.dataset.categorical-test (:require [tech.v3.dataset.categorical :as ds-cat] [tech.v3.dataset.modelling :as ds-mod] [tech.v3.datatype :as dtype] [tech.v3.dataset.column-filters :as cf] [clojure.test :refer [deftest is] :as t] [tech.v3.dataset :as ds])) (deftest prediction (is (= [:no :yes] (-> (ds/->dataset {:yes [0.3 0.5] :no [0.7 0.5]}) (ds-mod/probability-distributions->label-column :val) (ds-cat/reverse-map-categorical-xforms) :val)))) (deftest prob-dist (let [prob (-> (ds/->dataset {:yes [0.3 0.5] :no [0.7 0.5]}) (ds-mod/probability-distributions->label-column :val) (ds-cat/reverse-map-categorical-xforms))] (is (= (:yes prob) [0.3 0.5])) (is (= (:no prob) [0.7 0.5])) (is (= (:val prob) [:no :yes])))) (deftest cat-to-number (is (= (set (-> (ds/->dataset {:x [:a :b] :y ["1" "0"]}) (ds/categorical->number [:y]) :y)) (set [0 1])))) (defn- cat->num [table-args] (-> (ds/->dataset {:y [:a :b :c :d]}) (ds/categorical->number [:y] table-args) :y meta :categorical-map :lookup-table clojure.set/map-invert)) (deftest test-categorical->number [] (is (= {5 :a, 2 :b, 0 :d, 1 :c} (cat->num [[:a 5] [:b 2]]))) (is (= {5 :a, 0 :b, 1 :d, 2 :c} (cat->num [[:a 5] [:b 0]]))) (is (= (cat->num []) {0 :d, 1 :c, 2 :a, 3 :b})) (is (= (cat->num [[:not-present 1]]) {1 :not-present, 0 :d, 2 :c, 3 :a, 4 :b})) (is (= (cat->num [[:a 1 :b 1]]) {1 :a, 0 :d, 2 :c, 3 :b}))) (deftest cat-map-regression (is (every? #(Double/isFinite %) (-> (ds/->dataset "test/data/titanic.csv") (ds/update-column "Survived" (fn [col] (let [val-map {0 :drowned 1 :survived}] (dtype/emap val-map :keyword col)))) (ds/categorical->number cf/categorical) (ds/column "Survived"))))) (deftest categorical-assignments-are-integers (is (= #{0 1 2 3} (-> (ds/->dataset {:x1 [1 2 4 5 6 5 6 7] :x2 [5 6 6 7 8 2 4 6] :y [:a :b :b :a :c :a :b :d]}) (ds/categorical->number [:y]) (get :y) distinct set)))) (defn- =-invert-cat [target-1 target-2 lookup-one lookup-two result-datatype expected-result ] (let [ds (ds/->dataset {:target [target-1 target-2]}) inverted (ds-cat/invert-categorical-map ds {:lookup-table {:one lookup-one :two lookup-two}, :src-column :target, :result-datatype result-datatype}) inverted-target (-> inverted :target)] (= expected-result inverted-target))) ;(format "expected %s, found: %s" expected-result) (seq inverted-target))) (deftest invert-cat--works (is (=-invert-cat 1 2 1 2 :int [:one :two])) ; TODO - should pass ? (is (=-invert-cat 1.0 2.0 1 2 :int [:one :two])) ; TODO - should pass ? (is (=-invert-cat 1.99999 2.99999 1 2 :int [:one :two])) ; TODO - should pass ? (is (=-invert-cat 1.2 1.3 1 2 :int [:one :one]))) (deftest invert-cat--throws (is (thrown? Exception (=-invert-cat 1.0 2.0 1.0 2.0 :float [:one :two]) ;; => Execution error at tech.v3.dataset.categorical/invert-categorical-map$fn (categorical.clj:177). ;; Unable to find src value for numeric value 1.0 )) (is (thrown? Exception (=-invert-cat 1 2 4 5 :int [:one :two]))) ;; => Execution error at tech.v3.dataset.categorical/invert-categorical-map$fn (categorical.clj:177). ;; Unable to find src value for numeric value 1 (is (thrown? Exception (=-invert-cat 1 2 1.0 2.0 :int [:one :two])))) ;; => Execution error at tech.v3.dataset.categorical/invert-categorical-map$fn (categorical.clj:177). ;; Unable to find src value for numeric value 1 (defn- is-roundtrip-ok [raw-model-prediction] (let [ train-ds (-> (ds/->dataset {:target [:a :b :c]}) (ds/categorical->number [:target]) ) cat-map (-> train-ds :target meta :categorical-map) prediction-ds (-> (ds/->dataset {:target raw-model-prediction}) (ds/assoc-metadata [:target] :categorical-map cat-map) (ds-cat/reverse-map-categorical-xforms))] (is (= [:c :a :b] (:target prediction-ds))) )) (deftest round-trip ;; only this should pass (is-roundtrip-ok [0 1 2]) ;; currently these all pass, while I would like them to all fail (is-roundtrip-ok [0.0 1.2 2.2]) (is-roundtrip-ok [0.9 1.9 2.9]) (is-roundtrip-ok (float-array [0 1 2])) (is-roundtrip-ok (float-array [0 1.9 2.9])) (is-roundtrip-ok (double-array [0 1.5 2.2])))