
Lesson 16: Case Studies

Pragmatic AI Labs


This notebook was produced by Pragmatic AI Labs. You can continue learning about these topics by:

16.4 Ludwig (Open Source AutoML)

Project URL: https://uber.github.io/ludwig/


Install Ludwig

!pip install --upgrade numpy  # restart the Colab runtime after this upgrade
!pip install --upgrade scikit-image
!pip install -q ludwig
!python -m spacy download en 
Requirement already up-to-date: numpy in /usr/local/lib/python3.6/dist-packages (1.16.1)
Collecting scikit-image
  Downloading https://files.pythonhosted.org/packages/24/06/d560630eb9e36d90d69fe57d9ff762d8f501664ce478b8a0ae132b3c3008/scikit_image-0.14.2-cp36-cp36m-manylinux1_x86_64.whl (25.3MB)
    100% |████████████████████████████████| 25.3MB 1.9MB/s 
Collecting pillow>=4.3.0 (from scikit-image)
  Downloading https://files.pythonhosted.org/packages/85/5e/e91792f198bbc5a0d7d3055ad552bc4062942d27eaf75c3e2783cf64eae5/Pillow-5.4.1-cp36-cp36m-manylinux1_x86_64.whl (2.0MB)
    100% |████████████████████████████████| 2.0MB 18.3MB/s 
Requirement already satisfied, skipping upgrade: scipy>=0.17.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image) (1.1.0)
Requirement already satisfied, skipping upgrade: matplotlib>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image) (3.0.2)
Requirement already satisfied, skipping upgrade: six>=1.10.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image) (1.11.0)
Requirement already satisfied, skipping upgrade: cloudpickle>=0.2.1 in /usr/local/lib/python3.6/dist-packages (from scikit-image) (0.6.1)
Requirement already satisfied, skipping upgrade: PyWavelets>=0.4.0 in /usr/local/lib/python3.6/dist-packages (from scikit-image) (1.0.1)
Requirement already satisfied, skipping upgrade: networkx>=1.8 in /usr/local/lib/python3.6/dist-packages (from scikit-image) (2.2)
Collecting dask[array]>=1.0.0 (from scikit-image)
  Downloading https://files.pythonhosted.org/packages/7c/2b/cf9e5477bec3bd3b4687719876ea38e9d8c9dc9d3526365c74e836e6a650/dask-1.1.1-py2.py3-none-any.whl (701kB)
    100% |████████████████████████████████| 706kB 25.2MB/s 
Requirement already satisfied, skipping upgrade: numpy>=1.8.2 in /usr/local/lib/python3.6/dist-packages (from scipy>=0.17.0->scikit-image) (1.16.1)
Requirement already satisfied, skipping upgrade: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.0.0->scikit-image) (0.10.0)
Requirement already satisfied, skipping upgrade: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.0.0->scikit-image) (2.3.1)
Requirement already satisfied, skipping upgrade: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.0.0->scikit-image) (1.0.1)
Requirement already satisfied, skipping upgrade: python-dateutil>=2.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib>=2.0.0->scikit-image) (2.5.3)
Requirement already satisfied, skipping upgrade: decorator>=4.3.0 in /usr/local/lib/python3.6/dist-packages (from networkx>=1.8->scikit-image) (4.3.2)
Requirement already satisfied, skipping upgrade: toolz>=0.7.3; extra == "array" in /usr/local/lib/python3.6/dist-packages (from dask[array]>=1.0.0->scikit-image) (0.9.0)
Requirement already satisfied, skipping upgrade: setuptools in /usr/local/lib/python3.6/dist-packages (from kiwisolver>=1.0.1->matplotlib>=2.0.0->scikit-image) (40.8.0)
featuretools 0.4.1 has requirement pandas>=0.23.0, but you'll have pandas 0.22.0 which is incompatible.
albumentations 0.1.12 has requirement imgaug<0.2.7,>=0.2.5, but you'll have imgaug 0.2.8 which is incompatible.
Installing collected packages: pillow, dask, scikit-image
  Found existing installation: Pillow 4.0.0
    Uninstalling Pillow-4.0.0:
      Successfully uninstalled Pillow-4.0.0
  Found existing installation: dask 0.20.2
    Uninstalling dask-0.20.2:
      Successfully uninstalled dask-0.20.2
  Found existing installation: scikit-image 0.13.1
    Uninstalling scikit-image-0.13.1:
      Successfully uninstalled scikit-image-0.13.1
Successfully installed dask-1.1.1 pillow-5.4.1 scikit-image-0.14.2

Requirement already satisfied: en_core_web_sm==2.0.0 from https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-2.0.0/en_core_web_sm-2.0.0.tar.gz#egg=en_core_web_sm==2.0.0 in /usr/local/lib/python3.6/dist-packages (2.0.0)

    Linking successful
    /usr/local/lib/python3.6/dist-packages/en_core_web_sm -->
    /usr/local/lib/python3.6/dist-packages/spacy/data/en

    You can now load the model via spacy.load('en')


Basic Ideas

  • Training Models
  • Prediction (Inference)
  • Datatypes
      • binary
      • numerical
      • category
      • set
      • bag
      • sequence
      • text
      • timeseries
      • image
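Each datatype above is declared per column in a Ludwig model definition, as a `name` plus a `type` (the YAML file used later in this lesson has exactly that shape). The sketch below builds such a definition as a plain Python dict and checks every declared type against the list above. The column names (`title`, `price`, `in_print`, `cover`, `genre`) are hypothetical, and the `invalid_features` helper is an illustration written for this lesson, not part of Ludwig.

```python
# Feature types listed above, using the Ludwig v0.1 names seen in this lesson.
LUDWIG_TYPES = {
    "binary", "numerical", "category", "set", "bag",
    "sequence", "text", "timeseries", "image",
}

# Hypothetical model definition: these column names are made up for illustration.
model_definition = {
    "input_features": [
        {"name": "title", "type": "text", "encoder": "parallel_cnn", "level": "word"},
        {"name": "price", "type": "numerical"},
        {"name": "in_print", "type": "binary"},
        {"name": "cover", "type": "image"},
    ],
    "output_features": [
        {"name": "genre", "type": "category"},
    ],
}


def invalid_features(definition):
    """Return the names of features whose 'type' is not one of LUDWIG_TYPES."""
    bad = []
    for section in ("input_features", "output_features"):
        for feature in definition.get(section, []):
            if feature.get("type") not in LUDWIG_TYPES:
                bad.append(feature.get("name"))
    return bad


print(invalid_features(model_definition))  # → []
```

Keeping the definition as data makes it easy to catch a misspelled type (for example `"float"` instead of `"numerical"`) before launching a long training run.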

Book Category Classification Example

!wget https://raw.githubusercontent.com/uchidalab/book-dataset/master/Task1/book30-listing-train.csv
!wget https://raw.githubusercontent.com/noahgift/recommendations/master/model_definition.yaml
--2019-02-18 02:44:21--  https://raw.githubusercontent.com/uchidalab/book-dataset/master/Task1/book30-listing-train.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 9728786 (9.3M) [text/plain]
Saving to: ‘book30-listing-train.csv.3’

book30-listing-trai 100%[===================>]   9.28M  --.-KB/s    in 0.1s    

2019-02-18 02:44:23 (64.4 MB/s) - ‘book30-listing-train.csv.3’ saved [9728786/9728786]

--2019-02-18 02:44:24--  https://raw.githubusercontent.com/noahgift/recommendations/master/model_definition.yaml
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.0.133, 151.101.64.133, 151.101.128.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.0.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 180 [text/plain]
Saving to: ‘model_definition.yaml.2’

model_definition.ya 100%[===================>]     180  --.-KB/s    in 0s      

2019-02-18 02:44:25 (34.7 MB/s) - ‘model_definition.yaml.2’ saved [180/180]


Ingest

import pandas as pd
df = pd.read_csv("https://media.githubusercontent.com/media/noahgift/recommendations/master/data/book30-listing-train-with-headers.csv")
df = df.drop("Unnamed: 0", axis=1)
df.head()
  | ASIN       | FILENAME       | IMAGE URL                                         | TITLE                                             | AUTHOR              | CATEGORYID | CATEGORY
0 | 1404803335 | 1404803335.jpg | http://ecx.images-amazon.com/images/I/51UJnL3T... | Magnets: Pulling Together, Pushing Apart (Amaz... | Natalie M. Rosinsky | 4          | Children's Books
1 | 1446276082 | 1446276082.jpg | http://ecx.images-amazon.com/images/I/51MGUKhk... | Energy Security (SAGE Library of International... | NaN                 | 10         | Engineering & Transportation
2 | 1491522666 | 1491522666.jpg | http://ecx.images-amazon.com/images/I/51qKvjsi... | An Amish Gathering: Life in Lancaster County      | Beth Wiseman        | 9          | Christian Books & Bibles
3 | 970096410  | 0970096410.jpg | http://ecx.images-amazon.com/images/I/51qoUENb... | City of Rocks Idaho: A Climber's Guide (Region... | Dave Bingham        | 26         | Sports & Outdoors
4 | 8436808053 | 8436808053.jpg | http://ecx.images-amazon.com/images/I/41aDW5pz... | Como vencer el insomnio. Tecnicas, reglas y co... | Choliz Montanes     | 11         | Health, Fitness & Dieting
df.to_csv("book30-listing-train-with-headers.csv")

EDA

Columns

df.columns
Index(['ASIN', 'FILENAME', 'IMAGE URL', 'TITLE', 'AUTHOR', 'CATEGORYID',
       'CATEGORY'],
      dtype='object')

Shape

df.shape
(51299, 7)
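Before reading Ludwig's scores it helps to know what a model that learned nothing would get. `hits_at_k` counts a prediction as correct when the true category is among the model's top k guesses (`top_k` is 3 in the configuration Ludwig prints when the experiment starts). The helper below is a minimal sketch written for this lesson, not part of Ludwig or pandas: it reports the number of classes, the accuracy of always predicting the most frequent class, and the expected accuracy and hits@k of guessing uniformly at random. The dataset name indicates 30 categories, which would put chance accuracy near 1/30 ≈ 0.033 and chance hits@3 at 0.1; running the helper on `df["CATEGORY"]` confirms the actual figures.

```python
from collections import Counter


def chance_baselines(labels, k=3):
    """Reference scores for a multi-class task.

    labels: any iterable of class labels, e.g. a pandas Series like df["CATEGORY"].
    Returns (n_classes, majority_accuracy, uniform_accuracy, uniform_hits_at_k).
    """
    counts = Counter(labels)
    n_classes = len(counts)
    total = sum(counts.values())
    # Accuracy of always predicting the single most frequent class.
    majority_accuracy = max(counts.values()) / total
    # Expected accuracy of picking one class uniformly at random.
    uniform_accuracy = 1 / n_classes
    # Expected hits@k when guessing k distinct classes uniformly at random.
    uniform_hits_at_k = min(k, n_classes) / n_classes
    return n_classes, majority_accuracy, uniform_accuracy, uniform_hits_at_k
```

In the notebook, `chance_baselines(df["CATEGORY"])` gives the numbers to compare against the `accuracy` and `hits_at_k` columns in the training tables.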

Training with Ludwig

!head book30-listing-train-with-headers.csv
,ASIN,FILENAME,IMAGE URL,TITLE,AUTHOR,CATEGORYID,CATEGORY
0,1404803335,1404803335.jpg,http://ecx.images-amazon.com/images/I/51UJnL3Tx6L.jpg,"Magnets: Pulling Together, Pushing Apart (Amazing Science)",Natalie M. Rosinsky,4,Children's Books
1,1446276082,1446276082.jpg,http://ecx.images-amazon.com/images/I/51MGUKhkyhL.jpg,Energy Security (SAGE Library of International Security),,10,Engineering & Transportation
2,1491522666,1491522666.jpg,http://ecx.images-amazon.com/images/I/51qKvjsi3ML.jpg,An Amish Gathering: Life in Lancaster County,Beth Wiseman,9,Christian Books & Bibles
3,970096410,0970096410.jpg,http://ecx.images-amazon.com/images/I/51qoUENb1CL.jpg,City of Rocks Idaho: A Climber's Guide (Regional Rock Climbing Series),Dave Bingham,26,Sports & Outdoors
4,8436808053,8436808053.jpg,http://ecx.images-amazon.com/images/I/41aDW5pzZBL.jpg,"Como vencer el insomnio. Tecnicas, reglas y consejos practicos para dormir mejor (BIBLIOTECA PRACTICA) (Spanish Edition)",Choliz Montanes,11,"Health, Fitness & Dieting"
5,1848291388,1848291388.jpg,http://ecx.images-amazon.com/images/I/51Lpg7xmrBL.jpg,John Martin Littlejohn: An Enigma of Osteopathy,John O'Brien,16,Medical Books
6,73402656,0073402656.jpg,http://ecx.images-amazon.com/images/I/51WccSzFUrL.jpg,Chemistry: The Molecular Nature of Matter and Change,Martin Silberberg,23,Science & Math
7,323045979,0323045979.jpg,http://ecx.images-amazon.com/images/I/51rJir5EpnL.jpg,"Mosby's Oncology Nursing Advisor: A Comprehensive Guide to Clinical Practice, 1e",Susan Newton MS  RN  AOCN  AOCNS,16,Medical Books
8,1847176968,1847176968.jpg,http://ecx.images-amazon.com/images/I/61KoC743OzL.jpg,Ireland's Wild Atlantic Way,Carsten Krieger,29,Travel
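Ludwig trains on this single CSV by splitting it into training, validation and test sets according to `split_probabilities`, which is (0.7, 0.1, 0.2) in the configuration it prints when the experiment starts. The sketch below shows one way to produce such a split by assigning every row independently at random. This is an assumption about the general approach, not Ludwig's source code, and `random_split` is a hypothetical helper. Because rows are drawn independently, the set sizes only approximate 70/10/20, in line with the 36059 / 5042 / 10198 split Ludwig reports for these 51,299 rows.

```python
import random


def random_split(n_rows, probabilities=(0.7, 0.1, 0.2), seed=42):
    """Assign each row index to train (0), validation (1) or test (2).

    Every row is drawn independently with the given probabilities,
    so the set sizes are close to, but not exactly, the target fractions.
    Returns three lists of row indices.
    """
    rng = random.Random(seed)  # fixed seed so the split is reproducible
    splits = ([], [], [])
    for i in range(n_rows):
        split_id = rng.choices((0, 1, 2), weights=probabilities)[0]
        splits[split_id].append(i)
    return splits


train_idx, val_idx, test_idx = random_split(51299)
print(len(train_idx), len(val_idx), len(test_idx))
```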

!cat model_definition.yaml
input_features:
    -
        name: TITLE
        type: text
        encoder: parallel_cnn
        level: word

output_features:
    -
        name: CATEGORY
        type: category
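This definition asks Ludwig to read `TITLE` as word-level text. The `text` preprocessing settings Ludwig prints once the experiment starts (`lowercase: True`, `word_format: 'space_punct'`, `word_most_common: 20000`, `word_sequence_length_limit: 256`, right padding with `<PAD>` and `<UNK>`) describe how each title becomes a fixed-length sequence of integer ids before the `parallel_cnn` encoder sees it. The sketch below approximates those steps in plain Python. The regular-expression tokenizer and the choice of index 0 for `<PAD>` and 1 for `<UNK>` are assumptions of this sketch, not Ludwig's implementation.

```python
import re
from collections import Counter

PAD, UNK = "<PAD>", "<UNK>"


def tokenize_space_punct(text):
    """Lowercase, then split into runs of word characters and single punctuation marks.

    An approximation of a 'space_punct' style tokenizer, not Ludwig's code.
    """
    return re.findall(r"\w+|[^\w\s]", text.lower())


def build_vocab(texts, most_common=20000):
    """Map the most frequent tokens to ids, reserving 0 for PAD and 1 for UNK."""
    counts = Counter(tok for t in texts for tok in tokenize_space_punct(t))
    vocab = [PAD, UNK] + [w for w, _ in counts.most_common(most_common)]
    return {w: i for i, w in enumerate(vocab)}


def encode(text, vocab, max_len=256):
    """Convert text to ids, truncate to max_len, then pad on the right."""
    ids = [vocab.get(tok, vocab[UNK]) for tok in tokenize_space_punct(text)][:max_len]
    return ids + [vocab[PAD]] * (max_len - len(ids))


titles = [
    "Ireland's Wild Atlantic Way",
    "Energy Security (SAGE Library of International Security)",
]
vocab = build_vocab(titles)
print(encode("Wild zebra", vocab, max_len=4))  # "zebra" is out of vocabulary
```

Capping the vocabulary and the sequence length keeps the embedding table and the convolution input a fixed size, at the cost of mapping rare words to `<UNK>`.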
!ludwig experiment --data_csv book30-listing-train-with-headers.csv --model_definition_file model_definition.yaml

 _         _        _      
| |_  _ __| |_ __ _(_)__ _ 
| | || / _` \ V  V / / _` |
|_|\_,_\__,_|\_/\_/|_\__, |
                     |___/ 
ludwig v0.1.0 - Experiment

Experiment name: experiment
Model name: run
Output path: results/experiment_run_0

ludwig_version: '0.1.0'
command: ('/usr/local/bin/ludwig experiment --data_csv '
 'book30-listing-train-with-headers.csv --model_definition_file '
 'model_definition.yaml')
dataset_type: 'book30-listing-train-with-headers.csv'
model_definition: {   'combiner': {'type': 'concat'},
    'input_features': [   {   'encoder': 'parallel_cnn',
                              'level': 'word',
                              'name': 'TITLE',
                              'tied_weights': None,
                              'type': 'text'}],
    'output_features': [   {   'dependencies': [],
                               'loss': {   'class_distance_temperature': 0,
                                           'class_weights': 1,
                                           'confidence_penalty': 0,
                                           'distortion': 1,
                                           'labels_smoothing': 0,
                                           'negative_samples': 0,
                                           'robust_lambda': 0,
                                           'sampler': None,
                                           'type': 'softmax_cross_entropy',
                                           'unique': False,
                                           'weight': 1},
                               'name': 'CATEGORY',
                               'reduce_dependencies': 'sum',
                               'reduce_input': 'sum',
                               'top_k': 3,
                               'type': 'category'}],
    'preprocessing': {   'bag': {   'fill_value': '',
                                    'format': 'space',
                                    'lowercase': 10000,
                                    'missing_value_strategy': 'fill_with_const',
                                    'most_common': False},
                         'binary': {   'fill_value': 0,
                                       'missing_value_strategy': 'fill_with_const'},
                         'category': {   'fill_value': '<UNK>',
                                         'lowercase': False,
                                         'missing_value_strategy': 'fill_with_const',
                                         'most_common': 10000},
                         'force_split': False,
                         'image': {'missing_value_strategy': 'backfill'},
                         'numerical': {   'fill_value': 0,
                                          'missing_value_strategy': 'fill_with_const'},
                         'sequence': {   'fill_value': '',
                                         'format': 'space',
                                         'lowercase': False,
                                         'missing_value_strategy': 'fill_with_const',
                                         'most_common': 20000,
                                         'padding': 'right',
                                         'padding_symbol': '<PAD>',
                                         'sequence_length_limit': 256,
                                         'unknown_symbol': '<UNK>'},
                         'set': {   'fill_value': '',
                                    'format': 'space',
                                    'lowercase': False,
                                    'missing_value_strategy': 'fill_with_const',
                                    'most_common': 10000},
                         'split_probabilities': (0.7, 0.1, 0.2),
                         'stratify': None,
                         'text': {   'char_format': 'characters',
                                     'char_most_common': 70,
                                     'char_sequence_length_limit': 1024,
                                     'fill_value': '',
                                     'lowercase': True,
                                     'missing_value_strategy': 'fill_with_const',
                                     'padding': 'right',
                                     'padding_symbol': '<PAD>',
                                     'unknown_symbol': '<UNK>',
                                     'word_format': 'space_punct',
                                     'word_most_common': 20000,
                                     'word_sequence_length_limit': 256},
                         'timeseries': {   'fill_value': '',
                                           'format': 'space',
                                           'missing_value_strategy': 'fill_with_const',
                                           'padding': 'right',
                                           'padding_value': 0,
                                           'timeseries_length_limit': 256}},
    'training': {   'batch_size': 128,
                    'bucketing_field': None,
                    'decay': False,
                    'decay_rate': 0.96,
                    'decay_steps': 10000,
                    'dropout_rate': 0.0,
                    'early_stop': 3,
                    'epochs': 200,
                    'gradient_clipping': None,
                    'increase_batch_size_on_plateau': 0,
                    'increase_batch_size_on_plateau_max': 512,
                    'increase_batch_size_on_plateau_patience': 5,
                    'increase_batch_size_on_plateau_rate': 2,
                    'learning_rate': 0.001,
                    'learning_rate_warmup_epochs': 5,
                    'optimizer': {   'beta1': 0.9,
                                     'beta2': 0.999,
                                     'epsilon': 1e-08,
                                     'type': 'adam'},
                    'reduce_learning_rate_on_plateau': 0,
                    'reduce_learning_rate_on_plateau_patience': 5,
                    'reduce_learning_rate_on_plateau_rate': 0.5,
                    'regularization_lambda': 0,
                    'regularizer': 'l2',
                    'staircase': False,
                    'validation_field': 'combined',
                    'validation_measure': 'loss'}}

Using full raw csv, no hdf5 and json file with the same name have been found
Building dataset (it may take a while)
Loading NLP pipeline
Writing dataset
Writing train set metadata with vocabulary
Training set: 36059
Validation set: 5042
Test set: 10198
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/losses/losses_impl.py:209: to_float (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_ops.py:3066: to_int32 (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Use tf.cast instead.
WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/math_grad.py:102: div (from tensorflow.python.ops.math_ops) is deprecated and will be removed in a future version.
Instructions for updating:
Deprecated in favor of operator or tf.math.divide.

╒══════════╕
│ TRAINING │
╘══════════╛

2019-02-18 01:21:33.899464: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 2200000000 Hz
2019-02-18 01:21:33.899801: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x318ac00 executing computations on platform Host. Devices:
2019-02-18 01:21:33.899835: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): <undefined>, <undefined>
2019-02-18 01:21:34.055715: I tensorflow/stream_executor/cuda/cuda_gpu_executor.cc:998] successful NUMA node read from SysFS had negative value (-1), but there must be at least one NUMA node, so returning NUMA node zero
2019-02-18 01:21:34.056285: I tensorflow/compiler/xla/service/service.cc:150] XLA service 0x318a100 executing computations on platform CUDA. Devices:
2019-02-18 01:21:34.056320: I tensorflow/compiler/xla/service/service.cc:158]   StreamExecutor device (0): Tesla K80, Compute Capability 3.7
2019-02-18 01:21:34.056733: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1433] Found device 0 with properties: 
name: Tesla K80 major: 3 minor: 7 memoryClockRate(GHz): 0.8235
pciBusID: 0000:00:04.0
totalMemory: 11.17GiB freeMemory: 11.10GiB
2019-02-18 01:21:34.056767: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1512] Adding visible gpu devices: 0
2019-02-18 01:21:43.842054: I tensorflow/core/common_runtime/gpu/gpu_device.cc:984] Device interconnect StreamExecutor with strength 1 edge matrix:
2019-02-18 01:21:43.842116: I tensorflow/core/common_runtime/gpu/gpu_device.cc:990]      0 
2019-02-18 01:21:43.842133: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1003] 0:   N 
2019-02-18 01:21:43.842364: W tensorflow/core/common_runtime/gpu/gpu_bfc_allocator.cc:42] Overriding allow_growth setting because the TF_FORCE_GPU_ALLOW_GROWTH environment variable is set. Original config value was 0.
2019-02-18 01:21:43.842446: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1115] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 10752 MB memory) -> physical GPU (device: 0, name: Tesla K80, pci bus id: 0000:00:04.0, compute capability: 3.7)

Epoch   1
Training:   0% 0/282 [00:00<?, ?it/s]
2019-02-18 01:21:44.623868: I tensorflow/stream_executor/dso_loader.cc:152] successfully opened CUDA library libcublas.so.10.0 locally
Training: 100% 282/282 [00:20<00:00, 13.95it/s]
Evaluation train: 100% 282/282 [00:04<00:00, 59.52it/s]
Evaluation vali : 100% 40/40 [00:00<00:00, 59.33it/s]
Evaluation test : 100% 80/80 [00:01<00:00, 52.80it/s]
Took 27.3456s
╒════════════╤════════╤════════════╤═════════════╕
│ CATEGORY   │   loss │   accuracy │   hits_at_k │
╞════════════╪════════╪════════════╪═════════════╡
│ train      │ 3.3077 │     0.0791 │      0.1855 │
├────────────┼────────┼────────────┼─────────────┤
│ vali       │ 3.3093 │     0.0768 │      0.1813 │
├────────────┼────────┼────────────┼─────────────┤
│ test       │ 3.3204 │     0.0757 │      0.1757 │
╘════════════╧════════╧════════════╧═════════════╛
Validation loss on combined improved, model saved


Epoch   2
Training: 100% 282/282 [00:17<00:00, 16.85it/s]
Evaluation train: 100% 282/282 [00:04<00:00, 59.75it/s]
Evaluation vali : 100% 40/40 [00:00<00:00, 60.45it/s]
Evaluation test : 100% 80/80 [00:01<00:00, 59.97it/s]
Took 23.9509s
╒════════════╤════════╤════════════╤═════════════╕
│ CATEGORY   │   loss │   accuracy │   hits_at_k │
╞════════════╪════════╪════════════╪═════════════╡
│ train      │ 3.2823 │     0.0904 │      0.2009 │
├────────────┼────────┼────────────┼─────────────┤
│ vali       │ 3.2871 │     0.0851 │      0.1962 │
├────────────┼────────┼────────────┼─────────────┤
│ test       │ 3.3021 │     0.0828 │      0.1897 │
╘════════════╧════════╧════════════╧═════════════╛
Validation loss on combined improved, model saved


Epoch   3
Training: 100% 282/282 [00:17<00:00, 16.92it/s]
Evaluation train: 100% 282/282 [00:04<00:00, 59.96it/s]
Evaluation vali : 100% 40/40 [00:00<00:00, 59.63it/s]
Evaluation test : 100% 80/80 [00:01<00:00, 60.28it/s]
Took 23.9664s
╒════════════╤════════╤════════════╤═════════════╕
│ CATEGORY   │   loss │   accuracy │   hits_at_k │
╞════════════╪════════╪════════════╪═════════════╡
│ train      │ 3.2728 │     0.0940 │      0.2102 │
├────────────┼────────┼────────────┼─────────────┤
│ vali       │ 3.2773 │     0.0898 │      0.2071 │
├────────────┼────────┼────────────┼─────────────┤
│ test       │ 3.2966 │     0.0838 │      0.1968 │
╘════════════╧════════╧════════════╧═════════════╛
Validation loss on combined improved, model saved


Epoch   4
Training: 100% 282/282 [00:17<00:00, 16.86it/s]
Evaluation train: 100% 282/282 [00:04<00:00, 59.63it/s]
Evaluation vali : 100% 40/40 [00:00<00:00, 61.04it/s]
Evaluation test : 100% 80/80 [00:01<00:00, 59.69it/s]
Took 23.9503s
╒════════════╤════════╤════════════╤═════════════╕
│ CATEGORY   │   loss │   accuracy │   hits_at_k │
╞════════════╪════════╪════════════╪═════════════╡
│ train      │ 3.2530 │     0.0970 │      0.2159 │
├────────────┼────────┼────────────┼─────────────┤
│ vali       │ 3.2623 │     0.0926 │      0.2081 │
├────────────┼────────┼────────────┼─────────────┤
│ test       │ 3.2824 │     0.0884 │      0.2033 │
╘════════════╧════════╧════════════╧═════════════╛
Validation loss on combined improved, model saved


Epoch   5
Training: 100% 282/282 [00:17<00:00, 16.83it/s]
Evaluation train: 100% 282/282 [00:04<00:00, 60.61it/s]
Evaluation vali : 100% 40/40 [00:00<00:00, 59.69it/s]
Evaluation test : 100% 80/80 [00:01<00:00, 60.04it/s]
Took 23.9652s
╒════════════╤════════╤════════════╤═════════════╕
│ CATEGORY   │   loss │   accuracy │   hits_at_k │
╞════════════╪════════╪════════════╪═════════════╡
│ train      │ 3.2445 │     0.0983 │      0.2182 │
├────────────┼────────┼────────────┼─────────────┤
│ vali       │ 3.2562 │     0.0908 │      0.2130 │
├────────────┼────────┼────────────┼─────────────┤
│ test       │ 3.2762 │     0.0875 │      0.2024 │
╘════════════╧════════╧════════════╧═════════════╛
Validation loss on combined improved, model saved


Epoch   6
Training: 100% 282/282 [00:17<00:00, 16.85it/s]
Evaluation train: 100% 282/282 [00:04<00:00, 59.89it/s]
Evaluation vali : 100% 40/40 [00:00<00:00, 60.76it/s]
Evaluation test : 100% 80/80 [00:01<00:00, 60.00it/s]
Took 23.9497s
╒════════════╤════════╤════════════╤═════════════╕
│ CATEGORY   │   loss │   accuracy │   hits_at_k │
╞════════════╪════════╪════════════╪═════════════╡
│ train      │ 3.2367 │     0.1004 │      0.2211 │
├────────────┼────────┼────────────┼─────────────┤
│ vali       │ 3.2543 │     0.0898 │      0.2098 │
├────────────┼────────┼────────────┼─────────────┤
│ test       │ 3.2740 │     0.0868 │      0.2043 │
╘════════════╧════════╧════════════╧═════════════╛
Validation loss on combined improved, model saved


Epoch   7
Training: 100% 282/282 [00:17<00:00, 16.83it/s]
Evaluation train: 100% 282/282 [00:04<00:00, 60.08it/s]
Evaluation vali : 100% 40/40 [00:00<00:00, 61.33it/s]
Evaluation test : 100% 80/80 [00:01<00:00, 60.27it/s]
Took 23.9176s
╒════════════╤════════╤════════════╤═════════════╕
│ CATEGORY   │   loss │   accuracy │   hits_at_k │
╞════════════╪════════╪════════════╪═════════════╡
│ train      │ 3.2357 │     0.1012 │      0.2220 │
├────────────┼────────┼────────────┼─────────────┤
│ vali       │ 3.2567 │     0.0916 │      0.2108 │
├────────────┼────────┼────────────┼─────────────┤
│ test       │ 3.2771 │     0.0880 │      0.2010 │
╘════════════╧════════╧════════════╧═════════════╛
Last improvement of loss on combined happened 1 epoch ago


Epoch   8
Training: 100% 282/282 [00:17<00:00, 16.75it/s]
Evaluation train: 100% 282/282 [00:04<00:00, 59.96it/s]
Evaluation vali : 100% 40/40 [00:00<00:00, 60.53it/s]
Evaluation test : 100% 80/80 [00:01<00:00, 60.30it/s]
Took 23.9056s
╒════════════╤════════╤════════════╤═════════════╕
│ CATEGORY   │   loss │   accuracy │   hits_at_k │
╞════════════╪════════╪════════════╪═════════════╡
│ train      │ 3.2256 │     0.1046 │      0.2259 │
├────────────┼────────┼────────────┼─────────────┤
│ vali       │ 3.2541 │     0.0934 │      0.2114 │
├────────────┼────────┼────────────┼─────────────┤
│ test       │ 3.2751 │     0.0913 │      0.1995 │
╘════════════╧════════╧════════════╧═════════════╛
Validation loss on combined improved, model saved


Epoch   9
Training: 100% 282/282 [00:17<00:00, 16.76it/s]
Evaluation train: 100% 282/282 [00:04<00:00, 60.44it/s]
Evaluation vali : 100% 40/40 [00:00<00:00, 61.39it/s]
Evaluation test : 100% 80/80 [00:01<00:00, 60.05it/s]
Took 23.9047s
╒════════════╤════════╤════════════╤═════════════╕
│ CATEGORY   │   loss │   accuracy │   hits_at_k │
╞════════════╪════════╪════════════╪═════════════╡
│ train      │ 3.2222 │     0.1041 │      0.2277 │
├────────────┼────────┼────────────┼─────────────┤
│ vali       │ 3.2547 │     0.0962 │      0.2164 │
├────────────┼────────┼────────────┼─────────────┤
│ test       │ 3.2755 │     0.0917 │      0.2004 │
╘════════════╧════════╧════════════╧═════════════╛
Last improvement of loss on combined happened 1 epoch ago


Epoch  10
Training: 100% 282/282 [00:17<00:00, 16.97it/s]
Evaluation train: 100% 282/282 [00:04<00:00, 59.94it/s]
Evaluation vali : 100% 40/40 [00:00<00:00, 61.11it/s]
Evaluation test : 100% 80/80 [00:01<00:00, 59.82it/s]
Took 23.9255s
╒════════════╤════════╤════════════╤═════════════╕
│ CATEGORY   │   loss │   accuracy │   hits_at_k │
╞════════════╪════════╪════════════╪═════════════╡
│ train      │ 3.2181 │     0.1053 │      0.2331 │
├────────────┼────────┼────────────┼─────────────┤
│ vali       │ 3.2575 │     0.0958 │      0.2196 │
├────────────┼────────┼────────────┼─────────────┤
│ test       │ 3.2789 │     0.0886 │      0.2082 │
╘════════════╧════════╧════════════╧═════════════╛
Last improvement of loss on combined happened 2 epochs ago


Epoch  11
Training: 100% 282/282 [00:17<00:00, 16.88it/s]
Evaluation train: 100% 282/282 [00:04<00:00, 59.83it/s]
Evaluation vali : 100% 40/40 [00:00<00:00, 60.48it/s]
Evaluation test : 100% 80/80 [00:01<00:00, 60.10it/s]
Took 23.8798s
╒════════════╤════════╤════════════╤═════════════╕
│ CATEGORY   │   loss │   accuracy │   hits_at_k │
╞════════════╪════════╪════════════╪═════════════╡
│ train      │ 3.2211 │     0.1051 │      0.2338 │
├────────────┼────────┼────────────┼─────────────┤
│ vali       │ 3.2667 │     0.0936 │      0.2140 │
├────────────┼────────┼────────────┼─────────────┤
│ test       │ 3.2868 │     0.0891 │      0.2045 │
╘════════════╧════════╧════════════╧═════════════╛
Last improvement of loss on combined happened 3 epochs ago

EARLY STOPPING due to lack of validation improvement, it has been 3 epochs since last validation accuracy improvement

Best validation model epoch: 8
Best validation model loss on validation set combined: 3.2541212318720016
Best validation model loss on test set combined: 3.275094079606602

╒═════════╕
│ PREDICT │
╘═════════╛

Evaluation: 100% 80/80 [00:01<00:00, 57.96it/s]

===== CATEGORY =====
accuracy: 0.0891351245342224
hits_at_k: 0.20445185330456953
loss: 3.286845474856628
overall_stats: { 'avg_f1_score_macro': 0.06812071846149517,
  'avg_f1_score_micro': 0.0891351245342224,
  'avg_f1_score_weighted': 0.06790552679270521,
  'avg_precision_macro': 0.09177260758729153,
  'avg_precision_micro': 0.0891351245342224,
  'avg_precision_weighted': 0.0891351245342224,
  'avg_recall_macro': 0.09056387599530688,
  'avg_recall_micro': 0.0891351245342224,
  'avg_recall_weighted': 0.0891351245342224,
  'kappa_score': 0.058034041734078334,
  'overall_accuracy': 0.0891351245342224}
per_class_stats: {<UNK>: {   'accuracy': 1.0,
    'f1_score': 0,
    'fall_out': 0.0,
    'false_discovery_rate': 1.0,
    'false_negative_rate': 1.0,
    'false_negatives': 0,
    'false_omission_rate': 0.0,
    'false_positive_rate': 0.0,
    'false_positives': 0,
    'hit_rate': 0,
    'informedness': 0.0,
    'markedness': 0.0,
    'matthews_correlation_coefficient': 0,
    'miss_rate': 1.0,
    'negative_predictive_value': 1.0,
    'positive_predictive_value': 0,
    'precision': 0,
    'recall': 0,
    'sensitivity': 0,
    'specificity': 1.0,
    'true_negative_rate': 1.0,
    'true_negatives': 10198,
    'true_positive_rate': 0,
    'true_positives': 0},
  Children's Books: {   'accuracy': 0.9269464600902138,
    'f1_score': 0.10991636798088411,
    'fall_out': 0.02880446004542636,
    'false_discovery_rate': 0.8584615384615385,
    'false_negative_rate': 0.91015625,
    'false_negatives': 466,
    'false_omission_rate': 0.04719943279651573,
    'false_positive_rate': 0.02880446004542636,
    'false_positives': 279,
    'hit_rate': 0.08984375,
    'informedness': 0.06103928995457375,
    'markedness': 0.09433902874194589,
    'matthews_correlation_coefficient': 0.07588403869993006,
    'miss_rate': 0.91015625,
    'negative_predictive_value': 0.9528005672034843,
    'positive_predictive_value': 0.14153846153846153,
    'precision': 0.14153846153846153,
    'recall': 0.08984375,
    'sensitivity': 0.08984375,
    'specificity': 0.9711955399545736,
    'true_negative_rate': 0.9711955399545736,
    'true_negatives': 9407,
    'true_positive_rate': 0.08984375,
    'true_positives': 46},
  Engineering & Transportation: {   'accuracy': 0.963326142380859,
    'f1_score': 0.04591836734693878,
    'fall_out': 0.029946629768728972,
    'false_discovery_rate': 0.9711538461538461,
    'false_negative_rate': 0.8875,
    'false_negatives': 71,
    'false_omission_rate': 0.0071818733562614145,
    'false_positive_rate': 0.029946629768728972,
    'false_positives': 303,
    'hit_rate': 0.1125,
    'informedness': 0.08255337023127107,
    'markedness': 0.02166428048989233,
    'matthews_correlation_coefficient': 0.042290180516003875,
    'miss_rate': 0.8875,
    'negative_predictive_value': 0.9928181266437386,
    'positive_predictive_value': 0.028846153846153848,
    'precision': 0.028846153846153848,
    'recall': 0.1125,
    'sensitivity': 0.1125,
    'specificity': 0.970053370231271,
    'true_negative_rate': 0.970053370231271,
    'true_negatives': 9815,
    'true_positive_rate': 0.1125,
    'true_positives': 9},
  Christian Books & Bibles: {   'accuracy': 0.9656795450088252,
    'f1_score': 0.005681818181818181,
    'fall_out': 0.034229109454688156,
    'false_discovery_rate': 0.9971428571428571,
    'false_negative_rate': 0.5,
    'false_negatives': 1,
    'false_omission_rate': 0.00010154346060109454,
    'false_positive_rate': 0.034229109454688156,
    'false_positives': 349,
    'hit_rate': 0.5,
    'informedness': 0.46577089054531173,
    'markedness': 0.002755599396541797,
    'matthews_correlation_coefficient': 0.035825660983621235,
    'miss_rate': 0.5,
    'negative_predictive_value': 0.9998984565393989,
    'positive_predictive_value': 0.002857142857142857,
    'precision': 0.002857142857142857,
    'recall': 0.5,
    'sensitivity': 0.5,
    'specificity': 0.9657708905453118,
    'true_negative_rate': 0.9657708905453118,
    'true_negatives': 9847,
    'true_positive_rate': 0.5,
    'true_positives': 1},
  Sports & Outdoors: {   'accuracy': 0.963424200823691,
    'f1_score': 0,
    'fall_out': 0.03297244094488194,
    'false_discovery_rate': 1.0,
    'false_negative_rate': 1.0,
    'false_negatives': 38,
    'false_omission_rate': 0.0038527831288655,
    'false_positive_rate': 0.03297244094488194,
    'false_positives': 335,
    'hit_rate': 0.0,
    'informedness': -0.03297244094488194,
    'markedness': -0.0038527831288655,
    'matthews_correlation_coefficient': -0.011271009901067143,
    'miss_rate': 1.0,
    'negative_predictive_value': 0.9961472168711345,
    'positive_predictive_value': 0.0,
    'precision': 0.0,
    'recall': 0.0,
    'sensitivity': 0.0,
    'specificity': 0.9670275590551181,
    'true_negative_rate': 0.9670275590551181,
    'true_negatives': 9825,
    'true_positive_rate': 0.0,
    'true_positives': 0},
  Health, Fitness & Dieting: {   'accuracy': 0.9297901549323396,
    'f1_score': 0.0427807486631016,
    'fall_out': 0.03329248366013071,
    'false_discovery_rate': 0.9532163742690059,
    'false_negative_rate': 0.9605911330049262,
    'false_negatives': 390,
    'false_omission_rate': 0.03956980519480524,
    'false_positive_rate': 0.03329248366013071,
    'false_positives': 326,
    'hit_rate': 0.03940886699507389,
    'informedness': 0.006116383334943132,
    'markedness': 0.0072138205361889085,
    'matthews_correlation_coefficient': 0.0066424763235420695,
    'miss_rate': 0.9605911330049262,
    'negative_predictive_value': 0.9604301948051948,
    'positive_predictive_value': 0.04678362573099415,
    'precision': 0.04678362573099415,
    'recall': 0.03940886699507389,
    'sensitivity': 0.03940886699507389,
    'specificity': 0.9667075163398693,
    'true_negative_rate': 0.9667075163398693,
    'true_negatives': 9466,
    'true_positive_rate': 0.03940886699507389,
    'true_positives': 16},
  Medical Books: {   'accuracy': 0.9540105903118259,
    'f1_score': 0.07495069033530573,
    'fall_out': 0.0315180530620387,
    'false_discovery_rate': 0.9432835820895522,
    'false_negative_rate': 0.8895348837209303,
    'false_negatives': 153,
    'false_omission_rate': 0.0155125215451688,
    'false_positive_rate': 0.0315180530620387,
    'false_positives': 316,
    'hit_rate': 0.11046511627906977,
    'informedness': 0.07894706321703104,
    'markedness': 0.04120389636527899,
    'matthews_correlation_coefficient': 0.05703443355673547,
    'miss_rate': 0.8895348837209303,
    'negative_predictive_value': 0.9844874784548312,
    'positive_predictive_value': 0.056716417910447764,
    'precision': 0.056716417910447764,
    'recall': 0.11046511627906977,
    'sensitivity': 0.11046511627906977,
    'specificity': 0.9684819469379613,
    'true_negative_rate': 0.9684819469379613,
    'true_negatives': 9710,
    'true_positive_rate': 0.11046511627906977,
    'true_positives': 19},
  Science & Math: {   'accuracy': 0.9558737007256325,
    'f1_score': 0.030172413793103446,
    'fall_out': 0.036212525972098564,
    'false_discovery_rate': 0.9812332439678284,
    'false_negative_rate': 0.9230769230769231,
    'false_negatives': 84,
    'false_omission_rate': 0.008549618320610741,
    'false_positive_rate': 0.036212525972098564,
    'false_positives': 366,
    'hit_rate': 0.07692307692307693,
    'informedness': 0.04071055095097842,
    'markedness': 0.010217137711560742,
    'matthews_correlation_coefficient': 0.020394737198102416,
    'miss_rate': 0.9230769230769231,
    'negative_predictive_value': 0.9914503816793893,
    'positive_predictive_value': 0.01876675603217158,
    'precision': 0.01876675603217158,
    'recall': 0.07692307692307693,
    'sensitivity': 0.07692307692307693,
    'specificity': 0.9637874740279014,
    'true_negative_rate': 0.9637874740279014,
    'true_negatives': 9741,
    'true_positive_rate': 0.07692307692307693,
    'true_positives': 7},
  Travel: {   'accuracy': 0.9540105903118259,
    'f1_score': 0.016771488469601678,
    'fall_out': 0.030505433157212658,
    'false_discovery_rate': 0.9870967741935484,
    'false_negative_rate': 0.9760479041916168,
    'false_negatives': 163,
    'false_omission_rate': 0.016484627831715226,
    'false_positive_rate': 0.030505433157212658,
    'false_positives': 306,
    'hit_rate': 0.023952095808383235,
    'informedness': -0.006553337348829458,
    'markedness': -0.0035814020252635803,
    'matthews_correlation_coefficient': -0.004844598606007852,
    'miss_rate': 0.9760479041916168,
    'negative_predictive_value': 0.9835153721682848,
    'positive_predictive_value': 0.012903225806451613,
    'precision': 0.012903225806451613,
    'recall': 0.023952095808383235,
    'sensitivity': 0.023952095808383235,
    'specificity': 0.9694945668427873,
    'true_negative_rate': 0.9694945668427873,
    'true_negatives': 9725,
    'true_positive_rate': 0.023952095808383235,
    'true_positives': 4},
  Business & Money: {   'accuracy': 0.9681310060796234,
    'f1_score': 0,
    'fall_out': 0.030823598704230903,
    'false_discovery_rate': 1.0,
    'false_negative_rate': 1.0,
    'false_negatives': 11,
    'false_omission_rate': 0.0011129097531363819,
    'false_positive_rate': 0.030823598704230903,
    'false_positives': 314,
    'hit_rate': 0.0,
    'informedness': -0.030823598704230903,
    'markedness': -0.0011129097531363819,
    'matthews_correlation_coefficient': -0.00585695173487886,
    'miss_rate': 1.0,
    'negative_predictive_value': 0.9988870902468636,
    'positive_predictive_value': 0.0,
    'precision': 0.0,
    'recall': 0.0,
    'sensitivity': 0.0,
    'specificity': 0.9691764012957691,
    'true_negative_rate': 0.9691764012957691,
    'true_negatives': 9873,
    'true_positive_rate': 0.0,
    'true_positives': 0},
  Cookbooks, Food & Wine: {   'accuracy': 0.9595018631104139,
    'f1_score': 0.019002375296912115,
    'fall_out': 0.03492846571287622,
    'false_discovery_rate': 0.9888268156424581,
    'false_negative_rate': 0.9365079365079365,
    'false_negatives': 59,
    'false_omission_rate': 0.00599593495934958,
    'false_positive_rate': 0.03492846571287622,
    'false_positives': 354,
    'hit_rate': 0.06349206349206349,
    'informedness': 0.028563597779187155,
    'markedness': 0.005177249398192307,
    'matthews_correlation_coefficient': 0.012160627837924515,
    'miss_rate': 0.9365079365079365,
    'negative_predictive_value': 0.9940040650406504,
    'positive_predictive_value': 0.0111731843575419,
    'precision': 0.0111731843575419,
    'recall': 0.06349206349206349,
    'sensitivity': 0.06349206349206349,
    'specificity': 0.9650715342871238,
    'true_negative_rate': 0.9650715342871238,
    'true_negatives': 9781,
    'true_positive_rate': 0.06349206349206349,
    'true_positives': 4},
  Politics & Social Sciences: {   'accuracy': 0.928025102961365,
    'f1_score': 0.0516795865633075,
    'fall_out': 0.035834609494640124,
    'false_discovery_rate': 0.9460916442048517,
    'false_negative_rate': 0.9503722084367245,
    'false_negatives': 383,
    'false_omission_rate': 0.03897425460466064,
    'false_positive_rate': 0.035834609494640124,
    'false_positives': 351,
    'hit_rate': 0.04962779156327544,
    'informedness': 0.013793182068635224,
    'markedness': 0.014934101190487548,
    'matthews_correlation_coefficient': 0.01435230910870509,
    'miss_rate': 0.9503722084367245,
    'negative_predictive_value': 0.9610257453953394,
    'positive_predictive_value': 0.05390835579514825,
    'precision': 0.05390835579514825,
    'recall': 0.04962779156327544,
    'sensitivity': 0.04962779156327544,
    'specificity': 0.9641653905053599,
    'true_negative_rate': 0.9641653905053599,
    'true_negatives': 9444,
    'true_positive_rate': 0.04962779156327544,
    'true_positives': 20},
  Crafts, Hobbies & Home: {   'accuracy': 0.9681310060796234,
    'f1_score': 0,
    'fall_out': 0.0312990580847724,
    'false_discovery_rate': 1.0,
    'false_negative_rate': 1.0,
    'false_negatives': 6,
    'false_omission_rate': 0.000607348921955686,
    'false_positive_rate': 0.0312990580847724,
    'false_positives': 319,
    'hit_rate': 0.0,
    'informedness': -0.0312990580847724,
    'markedness': -0.000607348921955686,
    'matthews_correlation_coefficient': -0.004359982704783838,
    'miss_rate': 1.0,
    'negative_predictive_value': 0.9993926510780443,
    'positive_predictive_value': 0.0,
    'precision': 0.0,
    'recall': 0.0,
    'sensitivity': 0.0,
    'specificity': 0.9687009419152276,
    'true_negative_rate': 0.9687009419152276,
    'true_negatives': 9873,
    'true_positive_rate': 0.0,
    'true_positives': 0},
  Religion & Spirituality: {   'accuracy': 0.957834869582271,
    'f1_score': 0.009216589861751152,
    'fall_out': 0.03517091483896462,
    'false_discovery_rate': 0.994413407821229,
    'false_negative_rate': 0.9736842105263158,
    'false_negatives': 74,
    'false_omission_rate': 0.007520325203252076,
    'false_positive_rate': 0.03517091483896462,
    'false_positives': 356,
    'hit_rate': 0.02631578947368421,
    'informedness': -0.008855125365280436,
    'markedness': -0.0019337330244810769,
    'matthews_correlation_coefficient': -0.004138048858431092,
    'miss_rate': 0.9736842105263158,
    'negative_predictive_value': 0.9924796747967479,
    'positive_predictive_value': 0.00558659217877095,
    'precision': 0.00558659217877095,
    'recall': 0.02631578947368421,
    'sensitivity': 0.02631578947368421,
    'specificity': 0.9648290851610354,
    'true_negative_rate': 0.9648290851610354,
    'true_negatives': 9766,
    'true_positive_rate': 0.02631578947368421,
    'true_positives': 2},
  Literature & Fiction: {   'accuracy': 0.9111590507942734,
    'f1_score': 0.12884615384615386,
    'fall_out': 0.03088559722659945,
    'false_discovery_rate': 0.814404432132964,
    'false_negative_rate': 0.9013254786450663,
    'false_negatives': 612,
    'false_omission_rate': 0.06221408966148212,
    'false_positive_rate': 0.03088559722659945,
    'false_positives': 294,
    'hit_rate': 0.09867452135493372,
    'informedness': 0.06778892412833426,
    'markedness': 0.12338147820555401,
    'matthews_correlation_coefficient': 0.09145434743585469,
    'miss_rate': 0.9013254786450663,
    'negative_predictive_value': 0.9377859103385179,
    'positive_predictive_value': 0.18559556786703602,
    'precision': 0.18559556786703602,
    'recall': 0.09867452135493372,
    'sensitivity': 0.09867452135493372,
    'specificity': 0.9691144027734006,
    'true_negative_rate': 0.9691144027734006,
    'true_negatives': 9225,
    'true_positive_rate': 0.09867452135493372,
    'true_positives': 67},
  Humor & Entertainment: {   'accuracy': 0.9680329476367915,
    'f1_score': 0,
    'fall_out': 0.031492200529775305,
    'false_discovery_rate': 1.0,
    'false_negative_rate': 1.0,
    'false_negatives': 5,
    'false_omission_rate': 0.0005062265870203753,
    'false_positive_rate': 0.031492200529775305,
    'false_positives': 321,
    'hit_rate': 0.0,
    'informedness': -0.031492200529775305,
    'markedness': -0.0005062265870203753,
    'matthews_correlation_coefficient': -0.003992767109655738,
    'miss_rate': 1.0,
    'negative_predictive_value': 0.9994937734129796,
    'positive_predictive_value': 0.0,
    'precision': 0.0,
    'recall': 0.0,
    'sensitivity': 0.0,
    'specificity': 0.9685077994702247,
    'true_negative_rate': 0.9685077994702247,
    'true_negatives': 9872,
    'true_positive_rate': 0.0,
    'true_positives': 0},
  Law: {   'accuracy': 0.9264561678760541,
    'f1_score': 0.05778894472361809,
    'fall_out': 0.03343246846477288,
    'false_discovery_rate': 0.9340974212034384,
    'false_negative_rate': 0.9485458612975392,
    'false_negatives': 424,
    'false_omission_rate': 0.04305005584323285,
    'false_positive_rate': 0.03343246846477288,
    'false_positives': 326,
    'hit_rate': 0.05145413870246085,
    'informedness': 0.01802167023768808,
    'markedness': 0.022852522953328736,
    'matthews_correlation_coefficient': 0.020293857020391354,
    'miss_rate': 0.9485458612975392,
    'negative_predictive_value': 0.9569499441567672,
    'positive_predictive_value': 0.0659025787965616,
    'precision': 0.0659025787965616,
    'recall': 0.05145413870246085,
    'sensitivity': 0.05145413870246085,
    'specificity': 0.9665675315352271,
    'true_negative_rate': 0.9665675315352271,
    'true_negatives': 9425,
    'true_positive_rate': 0.05145413870246085,
    'true_positives': 23},
  Computers & Technology: {   'accuracy': 0.9531280643263385,
    'f1_score': 0.047808764940239036,
    'fall_out': 0.03508597554915016,
    'false_discovery_rate': 0.9671232876712329,
    'false_negative_rate': 0.9124087591240876,
    'false_negatives': 125,
    'false_omission_rate': 0.01271229533204521,
    'false_positive_rate': 0.03508597554915016,
    'false_positives': 353,
    'hit_rate': 0.08759124087591241,
    'informedness': 0.05250526532676236,
    'markedness': 0.02016441699672189,
    'matthews_correlation_coefficient': 0.03253825540148643,
    'miss_rate': 0.9124087591240876,
    'negative_predictive_value': 0.9872877046679548,
    'positive_predictive_value': 0.03287671232876712,
    'precision': 0.03287671232876712,
    'recall': 0.08759124087591241,
    'sensitivity': 0.08759124087591241,
    'specificity': 0.9649140244508498,
    'true_negative_rate': 0.9649140244508498,
    'true_negatives': 9708,
    'true_positive_rate': 0.08759124087591241,
    'true_positives': 12},
  Test Preparation: {   'accuracy': 0.9327319082172975,
    'f1_score': 0.15099009900990099,
    'fall_out': 0.027274598600247058,
    'false_discovery_rate': 0.8128834355828221,
    'false_negative_rate': 0.8734439834024896,
    'false_negatives': 421,
    'false_omission_rate': 0.04264586709886553,
    'false_positive_rate': 0.027274598600247058,
    'false_positives': 265,
    'hit_rate': 0.12655601659751037,
    'informedness': 0.09928141799726342,
    'markedness': 0.1444706973183123,
    'matthews_correlation_coefficient': 0.11976333198778118,
    'miss_rate': 0.8734439834024896,
    'negative_predictive_value': 0.9573541329011345,
    'positive_predictive_value': 0.18711656441717792,
    'precision': 0.18711656441717792,
    'recall': 0.12655601659751037,
    'sensitivity': 0.12655601659751037,
    'specificity': 0.9727254013997529,
    'true_negative_rate': 0.9727254013997529,
    'true_negatives': 9451,
    'true_positive_rate': 0.12655601659751037,
    'true_positives': 61},
  Arts & Photography: {   'accuracy': 0.941557168072171,
    'f1_score': 0.04487179487179488,
    'fall_out': 0.03171076550191876,
    'false_discovery_rate': 0.9573170731707317,
    'false_negative_rate': 0.9527027027027027,
    'false_negatives': 282,
    'false_omission_rate': 0.02857142857142858,
    'false_positive_rate': 0.03171076550191876,
    'false_positives': 314,
    'hit_rate': 0.0472972972972973,
    'informedness': 0.0155865317953785,
    'markedness': 0.014111498257839639,
    'matthews_correlation_coefficient': 0.01483068832779676,
    'miss_rate': 0.9527027027027027,
    'negative_predictive_value': 0.9714285714285714,
    'positive_predictive_value': 0.042682926829268296,
    'precision': 0.042682926829268296,
    'recall': 0.0472972972972973,
    'sensitivity': 0.0472972972972973,
    'specificity': 0.9682892344980812,
    'true_negative_rate': 0.9682892344980812,
    'true_negatives': 9588,
    'true_positive_rate': 0.0472972972972973,
    'true_positives': 14},
  Parenting & Relationships: {   'accuracy': 0.9494018434987253,
    'f1_score': 0.0851063829787234,
    'fall_out': 0.035068438405435054,
    'false_discovery_rate': 0.9359999999999999,
    'false_negative_rate': 0.873015873015873,
    'false_negatives': 165,
    'false_omission_rate': 0.016797312430011146,
    'false_positive_rate': 0.035068438405435054,
    'false_positives': 351,
    'hit_rate': 0.12698412698412698,
    'informedness': 0.09191568857869203,
    'markedness': 0.04720268756998891,
    'matthews_correlation_coefficient': 0.0658685625375291,
    'miss_rate': 0.873015873015873,
    'negative_predictive_value': 0.9832026875699889,
    'positive_predictive_value': 0.064,
    'precision': 0.064,
    'recall': 0.12698412698412698,
    'sensitivity': 0.12698412698412698,
    'specificity': 0.964931561594565,
    'true_negative_rate': 0.964931561594565,
    'true_negatives': 9658,
    'true_positive_rate': 0.12698412698412698,
    'true_positives': 24},
  Romance: {   'accuracy': 0.9173367326926848,
    'f1_score': 0.11542497376705142,
    'fall_out': 0.030138700594431134,
    'false_discovery_rate': 0.8401162790697674,
    'false_negative_rate': 0.909688013136289,
    'false_negatives': 554,
    'false_omission_rate': 0.05622082403085038,
    'false_positive_rate': 0.030138700594431134,
    'false_positives': 289,
    'hit_rate': 0.090311986863711,
    'informedness': 0.06017328626927987,
    'markedness': 0.10366289689938224,
    'matthews_correlation_coefficient': 0.07897934648140213,
    'miss_rate': 0.909688013136289,
    'negative_predictive_value': 0.9437791759691496,
    'positive_predictive_value': 0.15988372093023256,
    'precision': 0.15988372093023256,
    'recall': 0.090311986863711,
    'sensitivity': 0.090311986863711,
    'specificity': 0.9698612994055689,
    'true_negative_rate': 0.9698612994055689,
    'true_negatives': 9300,
    'true_positive_rate': 0.090311986863711,
    'true_positives': 55},
  History: {   'accuracy': 0.9323396744459698,
    'f1_score': 0.07999999999999999,
    'fall_out': 0.030978427563643773,
    'false_discovery_rate': 0.9099099099099099,
    'false_negative_rate': 0.9280575539568345,
    'false_negatives': 387,
    'false_omission_rate': 0.0392295995945261,
    'false_positive_rate': 0.030978427563643773,
    'false_positives': 303,
    'hit_rate': 0.07194244604316546,
    'informedness': 0.04096401847952169,
    'markedness': 0.05086049049556407,
    'matthews_correlation_coefficient': 0.04564482525476266,
    'miss_rate': 0.9280575539568345,
    'negative_predictive_value': 0.9607704004054739,
    'positive_predictive_value': 0.09009009009009009,
    'precision': 0.09009009009009009,
    'recall': 0.07194244604316546,
    'sensitivity': 0.07194244604316546,
    'specificity': 0.9690215724363562,
    'true_negative_rate': 0.9690215724363562,
    'true_negatives': 9478,
    'true_positive_rate': 0.07194244604316546,
    'true_positives': 30},
  Comics & Graphic Novels: {   'accuracy': 0.9580309864679349,
    'f1_score': 0.218978102189781,
    'fall_out': 0.028708612583775106,
    'false_discovery_rate': 0.8270893371757925,
    'false_negative_rate': 0.7014925373134329,
    'false_negatives': 141,
    'false_omission_rate': 0.014313267688559561,
    'false_positive_rate': 0.028708612583775106,
    'false_positives': 287,
    'hit_rate': 0.29850746268656714,
    'informedness': 0.2697988501027919,
    'markedness': 0.15859739513564786,
    'matthews_correlation_coefficient': 0.20685597607247405,
    'miss_rate': 0.7014925373134329,
    'negative_predictive_value': 0.9856867323114404,
    'positive_predictive_value': 0.1729106628242075,
    'precision': 0.1729106628242075,
    'recall': 0.29850746268656714,
    'sensitivity': 0.29850746268656714,
    'specificity': 0.9712913874162249,
    'true_negative_rate': 0.9712913874162249,
    'true_negatives': 9710,
    'true_positive_rate': 0.29850746268656714,
    'true_positives': 60},
  Reference: {   'accuracy': 0.9581290449107668,
    'f1_score': 0.027334851936218676,
    'fall_out': 0.034220156265453494,
    'false_discovery_rate': 0.9829545454545454,
    'false_negative_rate': 0.9310344827586207,
    'false_negatives': 81,
    'false_omission_rate': 0.00822669104204754,
    'false_positive_rate': 0.034220156265453494,
    'false_positives': 346,
    'hit_rate': 0.06896551724137931,
    'informedness': 0.03474536097592584,
    'markedness': 0.008818763503406934,
    'matthews_correlation_coefficient': 0.01750460286002505,
    'miss_rate': 0.9310344827586207,
    'negative_predictive_value': 0.9917733089579525,
    'positive_predictive_value': 0.017045454545454544,
    'precision': 0.017045454545454544,
    'recall': 0.06896551724137931,
    'sensitivity': 0.06896551724137931,
    'specificity': 0.9657798437345465,
    'true_negative_rate': 0.9657798437345465,
    'true_negatives': 9765,
    'true_positive_rate': 0.06896551724137931,
    'true_positives': 6},
  Teen & Young Adult: {   'accuracy': 0.9515591292410277,
    'f1_score': 0.04263565891472868,
    'fall_out': 0.034176962933439636,
    'false_discovery_rate': 0.9689265536723164,
    'false_negative_rate': 0.9320987654320988,
    'false_negatives': 151,
    'false_omission_rate': 0.015339292970337315,
    'false_positive_rate': 0.034176962933439636,
    'false_positives': 343,
    'hit_rate': 0.06790123456790123,
    'informedness': 0.03372427163446168,
    'markedness': 0.015734153357346292,
    'matthews_correlation_coefficient': 0.023035252587315484,
    'miss_rate': 0.9320987654320988,
    'negative_predictive_value': 0.9846607070296627,
    'positive_predictive_value': 0.031073446327683617,
    'precision': 0.031073446327683617,
    'recall': 0.06790123456790123,
    'sensitivity': 0.06790123456790123,
    'specificity': 0.9658230370665604,
    'true_negative_rate': 0.9658230370665604,
    'true_negatives': 9693,
    'true_positive_rate': 0.06790123456790123,
    'true_positives': 11},
  Self-Help: {   'accuracy': 0.8456560109825456,
    'f1_score': 0.11173814898419863,
    'fall_out': 0.025380130330398987,
    'false_discovery_rate': 0.691588785046729,
    'false_negative_rate': 0.9317711922811854,
    'false_negatives': 1352,
    'false_omission_rate': 0.1368836691303027,
    'false_positive_rate': 0.025380130330398987,
    'false_positives': 222,
    'hit_rate': 0.06822880771881461,
    'informedness': 0.04284867738841558,
    'markedness': 0.17152754582296836,
    'matthews_correlation_coefficient': 0.08573055741213308,
    'miss_rate': 0.9317711922811854,
    'negative_predictive_value': 0.8631163308696973,
    'positive_predictive_value': 0.308411214953271,
    'precision': 0.308411214953271,
    'recall': 0.06822880771881461,
    'sensitivity': 0.06822880771881461,
    'specificity': 0.974619869669601,
    'true_negative_rate': 0.974619869669601,
    'true_negatives': 8525,
    'true_positive_rate': 0.06822880771881461,
    'true_positives': 99},
  Calendars: {   'accuracy': 0.8434006667974112,
    'f1_score': 0.19627579265223954,
    'fall_out': 0.014421385860007074,
    'false_discovery_rate': 0.3867924528301887,
    'false_negative_rate': 0.8831635710005992,
    'false_negatives': 1474,
    'false_omission_rate': 0.14919028340080975,
    'false_positive_rate': 0.014421385860007074,
    'false_positives': 123,
    'hit_rate': 0.11683642899940083,
    'informedness': 0.10241504313939376,
    'markedness': 0.46401726376900143,
    'matthews_correlation_coefficient': 0.21799621117424445,
    'miss_rate': 0.8831635710005992,
    'negative_predictive_value': 0.8508097165991902,
    'positive_predictive_value': 0.6132075471698113,
    'precision': 0.6132075471698113,
    'recall': 0.11683642899940083,
    'sensitivity': 0.11683642899940083,
    'specificity': 0.9855786141399929,
    'true_negative_rate': 0.9855786141399929,
    'true_negatives': 8406,
    'true_positive_rate': 0.11683642899940083,
    'true_positives': 195},
  Science Fiction & Fantasy: {   'accuracy': 0.9561678760541282,
    'f1_score': 0.11485148514851486,
    'fall_out': 0.027994401119776025,
    'false_discovery_rate': 0.9061488673139159,
    'false_negative_rate': 0.8520408163265306,
    'false_negatives': 167,
    'false_omission_rate': 0.016887450702801066,
    'false_positive_rate': 0.027994401119776025,
    'false_positives': 280,
    'hit_rate': 0.14795918367346939,
    'informedness': 0.11996478255369336,
    'markedness': 0.07696368198328307,
    'matthews_correlation_coefficient': 0.09608814377255998,
    'miss_rate': 0.8520408163265306,
    'negative_predictive_value': 0.9831125492971989,
    'positive_predictive_value': 0.09385113268608414,
    'precision': 0.09385113268608414,
    'recall': 0.14795918367346939,
    'sensitivity': 0.14795918367346939,
    'specificity': 0.972005598880224,
    'true_negative_rate': 0.972005598880224,
    'true_negatives': 9722,
    'true_positive_rate': 0.14795918367346939,
    'true_positives': 29},
  Mystery, Thriller & Suspense: {   'accuracy': 0.9433222200431457,
    'f1_score': 0.12158054711246201,
    'fall_out': 0.030069859269008847,
    'false_discovery_rate': 0.8813056379821959,
    'false_negative_rate': 0.8753894080996885,
    'false_negatives': 281,
    'false_omission_rate': 0.028496095730656146,
    'false_positive_rate': 0.030069859269008847,
    'false_positives': 297,
    'hit_rate': 0.12461059190031153,
    'informedness': 0.09454073263130258,
    'markedness': 0.09019826628714811,
    'matthews_correlation_coefficient': 0.09234397748018172,
    'miss_rate': 0.8753894080996885,
    'negative_predictive_value': 0.9715039042693439,
    'positive_predictive_value': 0.11869436201780416,
    'precision': 0.11869436201780416,
    'recall': 0.12461059190031153,
    'sensitivity': 0.12461059190031153,
    'specificity': 0.9699301407309912,
    'true_negative_rate': 0.9699301407309912,
    'true_negatives': 9580,
    'true_positive_rate': 0.12461059190031153,
    'true_positives': 40},
  Biographies & Memoirs: {   'accuracy': 0.8951755246126691,
    'f1_score': 0.09329940627650551,
    'fall_out': 0.03210666666666662,
    'false_discovery_rate': 0.8455056179775281,
    'false_negative_rate': 0.9331713244228432,
    'false_negatives': 768,
    'false_omission_rate': 0.0780329201381833,
    'false_positive_rate': 0.03210666666666662,
    'false_positives': 301,
    'hit_rate': 0.06682867557715674,
    'informedness': 0.03472200891049004,
    'markedness': 0.0764614618842887,
    'matthews_correlation_coefficient': 0.05152567865497131,
    'miss_rate': 0.9331713244228432,
    'negative_predictive_value': 0.9219670798618167,
    'positive_predictive_value': 0.1544943820224719,
    'precision': 0.1544943820224719,
    'recall': 0.06682867557715674,
    'sensitivity': 0.06682867557715674,
    'specificity': 0.9678933333333334,
    'true_negative_rate': 0.9678933333333334,
    'true_negatives': 9074,
    'true_positive_rate': 0.06682867557715674,
    'true_positives': 55}}

Finished: experiment_run
Saved to: results/experiment_run_0
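Every per-class entry in the statistics above comes from four counts: `true_positives`, `false_positives`, `false_negatives` and `true_negatives`. The sketch below recomputes several of those metrics from the counts in plain Python, using the Children's Books entry as input. `binary_metrics` and `safe_div` are hypothetical helpers, not Ludwig functions. Returning 0.0 when a denominator is zero is an assumption, chosen because it matches how the zero-count classes above are reported.

```python
import math


def safe_div(num, den):
    # Report 0.0 instead of raising when a denominator is zero
    return num / den if den else 0.0


def binary_metrics(tp, fp, fn, tn):
    """One-vs-rest metrics for a single class from its confusion counts."""
    precision = safe_div(tp, tp + fp)       # positive predictive value
    recall = safe_div(tp, tp + fn)          # hit rate / sensitivity / TPR
    specificity = safe_div(tn, tn + fp)     # true negative rate
    npv = safe_div(tn, tn + fn)             # negative predictive value
    mcc_den = math.sqrt((tp + fp) * (tp + fn) * (tn + fp) * (tn + fn))
    return {
        "accuracy": safe_div(tp + tn, tp + fp + fn + tn),
        "precision": precision,
        "recall": recall,
        "specificity": specificity,
        "f1_score": safe_div(2 * precision * recall, precision + recall),
        "informedness": recall + specificity - 1,
        "markedness": precision + npv - 1,
        "matthews_correlation_coefficient": safe_div(tp * tn - fp * fn, mcc_den),
    }


# Counts taken from the "Children's Books" entry above
m = binary_metrics(tp=46, fp=279, fn=466, tn=9407)
for name, value in m.items():
    print(f"{name}: {value:.6f}")
```

The output also shows why accuracy is a weak summary here. Each class has roughly 10,000 negatives, so true negatives keep accuracy above 0.9 even when F1 is close to 0.1.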

16.5 Sklearn Algorithm Cheatsheet

Sklearn model selection

URL: https://scikit-learn.org/stable/tutorial/machine_learning_map/index.html

(Figure: scikit-learn algorithm cheat-sheet)
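To show how the cheat-sheet's classification branch reads, here is a small sketch that walks the path from sample count to suggested estimators. The thresholds (more than 50 samples, fewer than 100K samples) and the estimator names are my reading of the published chart. `suggest_classifiers` is a hypothetical function, not part of scikit-learn, so check the linked page for the current version of the chart.

```python
def suggest_classifiers(n_samples, is_text_data):
    """Walk the classification branch of the scikit-learn cheat-sheet.

    Returns estimator names in the order the chart suggests trying them.
    Assumption: thresholds and names are a reading of the published chart.
    """
    if n_samples <= 50:
        return ["get more data"]
    if n_samples >= 100_000:
        # Large data: a scalable linear model first, then kernel approximation
        return ["SGDClassifier", "kernel approximation"]
    # Smaller data: LinearSVC first; if it is "not working", the chart
    # branches on whether the data is text
    suggestions = ["LinearSVC"]
    if is_text_data:
        suggestions.append("Naive Bayes")
    else:
        suggestions += ["KNeighborsClassifier", "SVC", "Ensemble Classifiers"]
    return suggestions


print(suggest_classifiers(26048, is_text_data=False))
```

For tabular data of the Adult dataset's size (about 26K training rows), this path runs LinearSVC → KNeighborsClassifier, the model family trained in the SHAP example.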

Model Explainability with SHAP

https://github.com/slundberg/shap: A unified approach to explain the output of any machine learning model.

Install SHAP

!pip install -q shap
import sklearn.model_selection  # train_test_split
import sklearn.neighbors  # KNeighborsClassifier
import shap

shap.initjs()

Load Census Data

Adult dataset

  • Task: predict whether income exceeds $50K/yr based on census data. Also known as the “Census Income” dataset.
X,y = shap.datasets.adult()
X_display,y_display = shap.datasets.adult(display=True)
X_train, X_valid, y_train, y_valid = sklearn.model_selection.train_test_split(X, y, test_size=0.2, random_state=7)
X_train.shape, y_train.shape
((26048, 12), (26048,))
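The 26,048 training rows follow from how `train_test_split` sizes the split when only a float `test_size` is given. The test set gets `ceil(test_size * n_samples)` rows and the training set gets the rest. The sketch below reproduces that arithmetic. It assumes the data returned by `shap.datasets.adult()` has 32,561 rows, the size of the UCI `adult.data` training file. `split_sizes` is a hypothetical helper, not a scikit-learn function.

```python
import math


def split_sizes(n_samples, test_size):
    """Mirror train_test_split's sizing when test_size is a float fraction."""
    n_test = math.ceil(test_size * n_samples)  # test set size is rounded up
    n_train = n_samples - n_test               # training set gets the remainder
    return n_train, n_test


n_train, n_test = split_sizes(32561, 0.2)
print(n_train, n_test)  # 26048 6513
```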

Train a k-Nearest Neighbors Classifier

knn = sklearn.neighbors.KNeighborsClassifier()
knn.fit(X_train, y_train)
KNeighborsClassifier(algorithm='auto', leaf_size=30, metric='minkowski',
           metric_params=None, n_jobs=None, n_neighbors=5, p=2,
           weights='uniform')
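The repr shows the defaults: `n_neighbors=5`, `metric='minkowski'` with `p=2` (Euclidean distance), and `weights='uniform'`. With those settings, a prediction is a majority vote among the five closest training rows, and `predict_proba` is the fraction of those neighbors that fall in each class. Below is a minimal pure-Python sketch of that idea. It uses a brute-force search with no special tie handling, and `knn_predict_proba` is a hypothetical name, not scikit-learn's implementation.

```python
import math
from collections import Counter


def knn_predict_proba(X_train, y_train, x, k=5):
    """Class probabilities for x from its k nearest training points.

    Uniform weights: each of the k neighbors contributes one vote.
    """
    # Euclidean distance (Minkowski with p=2) from x to every training row
    distances = [(math.dist(row, x), label) for row, label in zip(X_train, y_train)]
    distances.sort(key=lambda pair: pair[0])
    neighbor_labels = [label for _, label in distances[:k]]
    counts = Counter(neighbor_labels)
    return {c: counts[c] / k for c in sorted(set(y_train))}


X_toy = [[0, 0], [0, 1], [1, 0], [5, 5], [5, 6], [6, 5], [6, 6]]
y_toy = [0, 0, 0, 1, 1, 1, 1]
print(knn_predict_proba(X_toy, y_toy, [0.5, 0.5]))  # {0: 0.6, 1: 0.4}
```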

Explain predictions

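# Model output to explain: predicted probability of the positive class (income > $50K)
f = lambda x: knn.predict_proba(x)[:, 1]
# Background data: a single row holding the median of each training feature
med = X_train.median().values.reshape((1, X_train.shape[1]))
# Kernel SHAP is model-agnostic: it only needs f and the background data
explainer = shap.KernelExplainer(f, med)
# Explain the first row of X, using 1000 perturbed samples
shap_values_single = explainer.shap_values(X.iloc[0, :], nsamples=1000)

# Plot (initjs() loads the JavaScript needed to render the force plot)
shap.initjs()
shap.force_plot(explainer.expected_value, shap_values_single, X_display.iloc[0, :])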
(Figure: SHAP force plot for the first sample; the interactive visualization is not rendered in this export.)
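Kernel SHAP estimates Shapley values by sampling feature coalitions (here `nsamples=1000`). With a single background row such as `med`, a "missing" feature is filled in from that row. For a handful of features, you can compute the same quantity exactly by enumerating every coalition and weighting each marginal contribution by |S|!(n − |S| − 1)!/n!. The sketch below does that in pure Python. `exact_shapley` and the toy models are illustrative assumptions, not part of the SHAP API. The values satisfy the additivity property that the force plot visualizes: they sum to `f(x) - f(background)`, and with one background row, `f(background)` plays the role of `explainer.expected_value`.

```python
from itertools import combinations
from math import factorial


def exact_shapley(f, x, background):
    """Exact Shapley values of f at x; absent features come from background."""
    n = len(x)

    def value(coalition):
        # Features in the coalition take their value from x, the rest from background
        z = [x[j] if j in coalition else background[j] for j in range(n)]
        return f(z)

    phi = []
    for i in range(n):
        others = [j for j in range(n) if j != i]
        total = 0.0
        for size in range(n):
            weight = factorial(size) * factorial(n - size - 1) / factorial(n)
            for subset in combinations(others, size):
                s = set(subset)
                total += weight * (value(s | {i}) - value(s))  # marginal contribution of i
        phi.append(total)
    return phi


def linear(z):
    return 2 * z[0] + 3 * z[1] - z[2] + 1


phi = exact_shapley(linear, x=[1, 2, 3], background=[0, 0, 0])
print([round(v, 6) for v in phi])  # [2.0, 6.0, -3.0]
print(round(sum(phi), 6), linear([1, 2, 3]) - linear([0, 0, 0]))  # 5.0 5
```

The enumeration cost grows as 2^n, which is why the notebook relies on KernelExplainer's sampling for the 12 Adult features instead.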

16.6 Recommendations

Install Surprise

!pip install -q scikit-surprise
  Building wheel for scikit-surprise (setup.py) ... done
from surprise import SVD
from surprise import Dataset
from surprise.model_selection import cross_validate

# Load the movielens-100k dataset (download it if needed).
data = Dataset.load_builtin('ml-100k')

# Use the SVD algorithm (matrix factorization popularized by Simon Funk during the Netflix Prize).
algo = SVD()

# Run 5-fold cross-validation and print results.
cross_validate(algo, data, measures=['RMSE', 'MAE'], cv=5, verbose=True)
Evaluating RMSE, MAE of algorithm SVD on 5 split(s).

                  Fold 1  Fold 2  Fold 3  Fold 4  Fold 5  Mean    Std     
RMSE (testset)    0.9460  0.9371  0.9344  0.9300  0.9354  0.9366  0.0053  
MAE (testset)     0.7441  0.7390  0.7338  0.7312  0.7397  0.7376  0.0045  
Fit time          5.30    5.22    5.25    5.23    5.23    5.24    0.03    
Test time         0.16    0.26    0.16    0.15    0.16    0.18    0.04    

{'fit_time': (5.302802085876465,
  5.2162816524505615,
  5.2515764236450195,
  5.2256152629852295,
  5.226689577102661),
 'test_mae': array([0.74405382, 0.73902602, 0.73379062, 0.73123877, 0.73968219]),
 'test_rmse': array([0.94601002, 0.93705768, 0.93435584, 0.93001856, 0.93540059]),
 'test_time': (0.16068744659423828,
  0.26168084144592285,
  0.1584162712097168,
  0.1538381576538086,
  0.16183090209960938)}
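Surprise's `SVD` is a biased matrix-factorization model. It predicts a rating as μ + b_u + b_i + q_i · p_u, that is, the global mean plus a user bias, an item bias, and the dot product of user and item latent factors. It fits these parameters by stochastic gradient descent on the observed ratings. The RMSE and MAE in the table are the root-mean-square and mean absolute differences between predicted and true ratings on each held-out fold. The following is a minimal pure-Python sketch of that training loop on made-up toy ratings. The learning rate, regularization, factor count, and epoch count are arbitrary illustrative choices, not Surprise's defaults, and `train_mf`, `rmse` and `mae` are hypothetical names.

```python
import math
import random


def train_mf(ratings, n_factors=2, lr=0.05, reg=0.02, n_epochs=200, seed=0):
    """Fit a biased matrix-factorization model with SGD.

    ratings: list of (user, item, rating) triples.
    Returns a predict(user, item) function for users/items seen in training.
    """
    rng = random.Random(seed)
    mu = sum(r for _, _, r in ratings) / len(ratings)  # global mean rating
    users = sorted({u for u, _, _ in ratings})  # sorted for reproducible init
    items = sorted({i for _, i, _ in ratings})
    b_u = {u: 0.0 for u in users}  # user biases
    b_i = {i: 0.0 for i in items}  # item biases
    # Small random initial latent factors
    p = {u: [rng.gauss(0, 0.1) for _ in range(n_factors)] for u in users}
    q = {i: [rng.gauss(0, 0.1) for _ in range(n_factors)] for i in items}

    def predict(u, i):
        dot = sum(pf * qf for pf, qf in zip(p[u], q[i]))
        return mu + b_u[u] + b_i[i] + dot

    order = list(ratings)
    for _ in range(n_epochs):
        rng.shuffle(order)  # visit ratings in a new random order each epoch
        for u, i, r in order:
            err = r - predict(u, i)
            b_u[u] += lr * (err - reg * b_u[u])
            b_i[i] += lr * (err - reg * b_i[i])
            for k in range(n_factors):
                pu, qi = p[u][k], q[i][k]
                p[u][k] += lr * (err * qi - reg * pu)
                q[i][k] += lr * (err * pu - reg * qi)
    return predict


def rmse(pairs):
    """Root-mean-square error over (true, estimate) pairs."""
    return math.sqrt(sum((t - e) ** 2 for t, e in pairs) / len(pairs))


def mae(pairs):
    """Mean absolute error over (true, estimate) pairs."""
    return sum(abs(t - e) for t, e in pairs) / len(pairs)


toy = [("alice", "m1", 5), ("alice", "m2", 3), ("bob", "m1", 4),
       ("bob", "m3", 2), ("carol", "m2", 4), ("carol", "m3", 1)]
predict = train_mf(toy)
pairs = [(r, predict(u, i)) for u, i, r in toy]
print(f"train RMSE={rmse(pairs):.3f}  MAE={mae(pairs):.3f}")
```

These are training-set errors on six ratings. The table above reports errors on held-out folds, which is the number that matters for generalization.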

Hand-coded Similarity Engine

"""Data Science Algorithms"""


def tanimoto(list1, list2):
    """tanimoto coefficient
    In [2]: list2=['39229', '31995', '32015']
    In [3]: list1=['31936', '35989', '27489', '39229', '15468', '31993', '26478']
    In [4]: tanimoto(list1,list2)
    Out[4]: 0.1111111111111111
    Uses intersection of two sets to determine numerical score
    """

    intersection = set(list1).intersection(set(list2))
    return float(len(intersection))/(len(list1) + len(list2) - len(intersection))
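One way to turn the Tanimoto score into a simple recommender is to find the other user whose item list is most similar to the target user's, then suggest the items that user has and the target does not. This is a sketch under that assumption, not the course's engine. The user data and the `recommend` helper are hypothetical, and the block repeats a set-based `tanimoto` so it runs on its own.

```python
def tanimoto(list1, list2):
    """Set-based Tanimoto coefficient (0.0 when both inputs are empty)."""
    set1, set2 = set(list1), set(list2)
    shared = len(set1 & set2)
    union = len(set1) + len(set2) - shared
    return shared / union if union else 0.0


def recommend(target, others):
    """Recommend items from the most similar other user.

    target: list of item ids; others: dict of user -> list of item ids.
    Returns (best_user, score, sorted list of items new to the target).
    """
    best_user = max(others, key=lambda user: tanimoto(target, others[user]))
    score = tanimoto(target, others[best_user])
    new_items = sorted(set(others[best_user]) - set(target))
    return best_user, score, new_items


history = {
    "u1": ["39229", "31995", "32015"],
    "u2": ["31936", "35989", "27489", "12345"],
    "u3": ["55555"],
}
me = ["31936", "35989", "27489", "39229"]
print(recommend(me, history))  # ('u2', 0.6, ['12345'])
```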

Collaborative Filtering Recommendation Exploration

KNN Exploration of MovieLens with Surprise

import io  # used to open u.item, which is ISO-8859-1 (Latin-1) encoded
from surprise import KNNBaseline
from surprise import Dataset
from surprise import get_dataset_dir

Helper Function to Convert IDs to Names

def read_item_names():
    """Read the u.item file from MovieLens 100-k dataset and return two
    mappings to convert raw ids into movie names and movie names into raw ids.
    """

    file_name = get_dataset_dir() + '/ml-100k/ml-100k/u.item'
    rid_to_name = {}
    name_to_rid = {}
    with io.open(file_name, 'r', encoding='ISO-8859-1') as f:
        for line in f:
            line = line.split('|')
            rid_to_name[line[0]] = line[1]
            name_to_rid[line[1]] = line[0]

    return rid_to_name, name_to_rid

Train a KNN-based Model

# First, train the algorithm to compute the similarities between items
data = Dataset.load_builtin('ml-100k')
trainset = data.build_full_trainset()
sim_options = {'name': 'pearson_baseline', 'user_based': False}
algo = KNNBaseline(sim_options=sim_options)
algo.fit(trainset)


Estimating biases using als...
Computing the pearson_baseline similarity matrix...
Done computing similarity matrix.

<surprise.prediction_algorithms.knns.KNNBaseline at 0x7f596007c1d0>
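With `'name': 'pearson_baseline'` and `'user_based': False`, Surprise compares two items with a Pearson-style correlation of their ratings over the users who rated both. Each rating is centered on its baseline estimate b_ui (built from the biases estimated by ALS in the log above) rather than on a mean. The result is then shrunk toward zero when few users are shared: sim = (n − 1)/(n − 1 + shrinkage) · ρ, where n is the number of common users. The sketch below computes this for one item pair in pure Python. The baseline values are supplied by hand, the default `shrinkage=100` reflects my understanding of Surprise's documented default, and `pearson_baseline_sim` is a hypothetical name.

```python
import math


def pearson_baseline_sim(ratings_i, ratings_j, baseline_i, baseline_j, shrinkage=100):
    """Shrunk, baseline-centered Pearson similarity between items i and j.

    ratings_*: dict user -> rating; baseline_*: dict user -> baseline estimate b_ui.
    """
    common = set(ratings_i) & set(ratings_j)  # users who rated both items
    if len(common) < 2:
        return 0.0  # shrink factor (n - 1)/(...) is 0 for n <= 1 anyway
    di = {u: ratings_i[u] - baseline_i[u] for u in common}  # residuals for item i
    dj = {u: ratings_j[u] - baseline_j[u] for u in common}  # residuals for item j
    num = sum(di[u] * dj[u] for u in common)
    den = math.sqrt(sum(di[u] ** 2 for u in common) * sum(dj[u] ** 2 for u in common))
    if den == 0:
        return 0.0
    rho = num / den
    n = len(common)
    return (n - 1) / (n - 1 + shrinkage) * rho


ri = {"a": 5, "b": 3, "c": 4}
bi = {"a": 4, "b": 4, "c": 4}
rj = {"a": 4, "b": 2, "c": 3, "d": 5}
bj = {"a": 3, "b": 3, "c": 3, "d": 3}
print(pearson_baseline_sim(ri, rj, bi, bj))
```

The residuals of the two items agree perfectly here (ρ = 1), yet the similarity is only 2/102 because just three users rated both items. The shrinkage term keeps sparsely co-rated items from looking like close neighbors.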

Recommendations

# Read the mappings raw id <-> movie name
rid_to_name, name_to_rid = read_item_names()

# Retrieve inner id of the movie Toy Story
toy_story_raw_id = name_to_rid['Toy Story (1995)']
toy_story_inner_id = algo.trainset.to_inner_iid(toy_story_raw_id)

# Retrieve inner ids of the nearest neighbors of Toy Story.
toy_story_neighbors = algo.get_neighbors(toy_story_inner_id, k=10)

# Convert inner ids of the neighbors into names.
toy_story_neighbors = (algo.trainset.to_raw_iid(inner_id)
                       for inner_id in toy_story_neighbors)
toy_story_neighbors = (rid_to_name[rid]
                       for rid in toy_story_neighbors)

for movie in toy_story_neighbors:
    print(movie)

Beauty and the Beast (1991)
Raiders of the Lost Ark (1981)
That Thing You Do! (1996)
Lion King, The (1994)
Craft, The (1996)
Liar Liar (1997)
Aladdin (1992)
Cool Hand Luke (1967)
Winnie the Pooh and the Blustery Day (1968)
Indiana Jones and the Last Crusade (1989)
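As I understand it, `get_neighbors(iid, k)` works on the similarity matrix computed during `fit`. It ranks the other items by their similarity to the given inner id and returns the inner ids of the k highest. The pure-Python sketch below shows that ranking over one row of a similarity matrix. The matrix is a made-up list of lists, and `top_k_neighbors` is a hypothetical name, not Surprise's code.

```python
import heapq


def top_k_neighbors(sim, inner_id, k):
    """Inner ids of the k items most similar to inner_id, most similar first."""
    candidates = (j for j in range(len(sim)) if j != inner_id)  # skip the item itself
    return heapq.nlargest(k, candidates, key=lambda j: sim[inner_id][j])


sim = [
    [1.0, 0.2, 0.9, 0.5],
    [0.2, 1.0, 0.1, 0.3],
    [0.9, 0.1, 1.0, 0.4],
    [0.5, 0.3, 0.4, 1.0],
]
print(top_k_neighbors(sim, 0, 2))  # [2, 3]
```

The notebook then maps these inner ids back to titles with `algo.trainset.to_raw_iid` and `rid_to_name`.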