This guide trains a neural network model to classify Spotify artists, such as Taylor Swift and Beyoncé, from audio features of their tracks. It is adapted from the TensorFlow basic-classification Colab tutorial.
# Core ML framework used to build and train the classifier.
import tensorflow as tf
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
# Mount Google Drive so the per-artist CSV files are readable
# under /content/drive (Colab-only; prompts for authorization).
from google.colab import drive
drive.mount('/content/drive')
Mounted at /content/drive
# All per-artist track CSVs live in the same Drive folder; keep the
# prefix in one place so it only needs to change once.
_DATA_DIR = "/content/drive/MyDrive/python_scratch/spotify_data"

df_taylor = pd.read_csv(_DATA_DIR + "/taylor.csv")
df_beyonce = pd.read_csv(_DATA_DIR + "/beyonce.csv")
df_beatles = pd.read_csv(_DATA_DIR + "/beatles.csv")
df_nirvana = pd.read_csv(_DATA_DIR + "/nirvana.csv")
df_rolling = pd.read_csv(_DATA_DIR + "/rolling_stones.csv")

# Human-readable class names, indexed by the integer labels (0..4)
# assigned when the label array is built below.
label_names = ["Taylor Swift", "Beyonce", "The Beatles", "Nirvana", "The Rolling Stones"]

# Peek at the raw data (displayed by the notebook).
df_taylor.head()
Unnamed: 0 | artist_name | artist_id | album_id | album_type | album_images | album_release_date | album_release_year | album_release_date_precision | danceability | energy | key | loudness | mode | speechiness | acousticness | instrumentalness | liveness | valence | tempo | track_id | analysis_url | time_signature | artists | available_markets | disc_number | duration_ms | explicit | track_href | is_local | track_name | track_preview_url | track_number | type | track_uri | external_urls.spotify | album_name | key_name | mode_name | key_mode | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1 | Taylor Swift | 06HL4z0CvFAxyc27GXpf02 | 6DEjYFkNZh67HP7R9PSZvv | album | list(height = c(640, 300, 64), url = c("https:... | 2017-11-10 | 2017 | day | 0.613 | 0.764 | 2 | -6.509 | 1 | 0.1360 | 0.05270 | 0.000000 | 0.1970 | 0.4170 | 160.015 | 2yLa0QULdQr0qAIvVwN6B5 | https://api.spotify.com/v1/audio-analysis/2yLa... | 4 | list(href = "https://api.spotify.com/v1/artist... | c("AD", "AE", "AR", "AT", "AU", "BE", "BG", "B... | 1 | 208186 | False | https://api.spotify.com/v1/tracks/2yLa0QULdQr0... | False | ...Ready For It? | NaN | 1 | track | spotify:track:2yLa0QULdQr0qAIvVwN6B5 | https://open.spotify.com/track/2yLa0QULdQr0qAI... | reputation | D | major | D major |
1 | 2 | Taylor Swift | 06HL4z0CvFAxyc27GXpf02 | 6DEjYFkNZh67HP7R9PSZvv | album | list(height = c(640, 300, 64), url = c("https:... | 2017-11-10 | 2017 | day | 0.649 | 0.589 | 2 | -6.237 | 1 | 0.0558 | 0.00845 | 0.000000 | 0.1080 | 0.1510 | 159.073 | 2x0WlnmfG39ZuDmstl9xfX | https://api.spotify.com/v1/audio-analysis/2x0W... | 4 | list(href = c("https://api.spotify.com/v1/arti... | c("AD", "AE", "AR", "AT", "AU", "BE", "BG", "B... | 1 | 244826 | False | https://api.spotify.com/v1/tracks/2x0WlnmfG39Z... | False | End Game | NaN | 2 | track | spotify:track:2x0WlnmfG39ZuDmstl9xfX | https://open.spotify.com/track/2x0WlnmfG39ZuDm... | reputation | D | major | D major |
2 | 3 | Taylor Swift | 06HL4z0CvFAxyc27GXpf02 | 6DEjYFkNZh67HP7R9PSZvv | album | list(height = c(640, 300, 64), url = c("https:... | 2017-11-10 | 2017 | day | 0.696 | 0.602 | 0 | -6.156 | 0 | 0.1590 | 0.06790 | 0.000021 | 0.0696 | 0.3050 | 82.989 | 4svZDCRz4cJoneBpjpx8DJ | https://api.spotify.com/v1/audio-analysis/4svZ... | 4 | list(href = "https://api.spotify.com/v1/artist... | c("AD", "AE", "AR", "AT", "AU", "BE", "BG", "B... | 1 | 238253 | False | https://api.spotify.com/v1/tracks/4svZDCRz4cJo... | False | I Did Something Bad | NaN | 3 | track | spotify:track:4svZDCRz4cJoneBpjpx8DJ | https://open.spotify.com/track/4svZDCRz4cJoneB... | reputation | C | minor | C minor |
3 | 4 | Taylor Swift | 06HL4z0CvFAxyc27GXpf02 | 6DEjYFkNZh67HP7R9PSZvv | album | list(height = c(640, 300, 64), url = c("https:... | 2017-11-10 | 2017 | day | 0.615 | 0.534 | 9 | -6.719 | 0 | 0.0386 | 0.10600 | 0.000018 | 0.0607 | 0.1930 | 135.917 | 1R0a2iXumgCiFb7HEZ7gUE | https://api.spotify.com/v1/audio-analysis/1R0a... | 4 | list(href = "https://api.spotify.com/v1/artist... | c("AD", "AE", "AR", "AT", "AU", "BE", "BG", "B... | 1 | 236413 | False | https://api.spotify.com/v1/tracks/1R0a2iXumgCi... | False | Don’t Blame Me | NaN | 4 | track | spotify:track:1R0a2iXumgCiFb7HEZ7gUE | https://open.spotify.com/track/1R0a2iXumgCiFb7... | reputation | A | minor | A minor |
4 | 5 | Taylor Swift | 06HL4z0CvFAxyc27GXpf02 | 6DEjYFkNZh67HP7R9PSZvv | album | list(height = c(640, 300, 64), url = c("https:... | 2017-11-10 | 2017 | day | 0.750 | 0.404 | 9 | -10.178 | 0 | 0.0682 | 0.21600 | 0.000357 | 0.0911 | 0.0499 | 95.045 | 6NFyWDv5CjfwuzoCkw47Xf | https://api.spotify.com/v1/audio-analysis/6NFy... | 4 | list(href = "https://api.spotify.com/v1/artist... | c("AD", "AE", "AR", "AT", "AU", "BE", "BG", "B... | 1 | 232253 | False | https://api.spotify.com/v1/tracks/6NFyWDv5Cjfw... | False | Delicate | NaN | 5 | track | spotify:track:6NFyWDv5CjfwuzoCkw47Xf | https://open.spotify.com/track/6NFyWDv5Cjfwuzo... | reputation | A | minor | A minor |
# Keep only the numeric audio-feature columns for the model inputs.
# pandas label slicing is end-inclusive, so this spans the 10 columns
# danceability .. valence.
_audio_cols = slice("danceability", "valence")
feat_taylor = df_taylor.loc[:, _audio_cols]
feat_beyonce = df_beyonce.loc[:, _audio_cols]
feat_beatles = df_beatles.loc[:, _audio_cols]
feat_nirvana = df_nirvana.loc[:, _audio_cols]
feat_rolling = df_rolling.loc[:, _audio_cols]

# Stack every artist's tracks into one (n_tracks, 10) feature matrix.
features = np.concatenate(
    [feat_taylor, feat_beyonce, feat_beatles, feat_nirvana, feat_rolling]
)

# Per-artist track counts (displayed by the notebook).
len(feat_taylor), len(feat_beyonce), len(feat_beatles), len(feat_nirvana), len(feat_rolling)
(547, 438, 139, 213, 743)
# Integer class labels, one per track, in the same row order as
# `features` and aligned with `label_names`:
# 0=Taylor Swift, 1=Beyonce, 2=The Beatles, 3=Nirvana, 4=The Rolling Stones.
# np.full(n, k) states the intent directly (the original np.zeros(n)+k
# was a roundabout way to make a constant array), and enumerate keeps
# the label value tied to the artist's position instead of repeating it.
labels = np.concatenate([
    np.full(len(feat), float(class_idx))
    for class_idx, feat in enumerate(
        [feat_taylor, feat_beyonce, feat_beatles, feat_nirvana, feat_rolling]
    )
])

# Sanity check: feature rows and labels must line up (displayed).
features.shape, labels.shape
((2080, 10), (2080,))
from sklearn.model_selection import train_test_split

# Hold out 25% of the tracks (the train_test_split default) for validation.
# stratify=labels keeps the class mix identical in both splits — important
# here because the classes are imbalanced (139 Beatles tracks vs 743
# Rolling Stones tracks) — and random_state makes the split reproducible
# across notebook re-runs.
train_features, test_features, train_labels, test_labels = train_test_split(
    features, labels, random_state=42, stratify=labels
)
To see what the data looks like, plot the first 25 training feature vectors, each labelled with its artist:
# Preview the first 25 training examples: each 10-value feature vector
# is drawn as a line, labelled with its artist's name.
plt.figure(figsize=(10,10))
for idx, (vec, lab) in enumerate(zip(train_features, train_labels)):
    if idx >= 25:
        break
    plt.subplot(5,5,idx+1)
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.plot(vec)
    plt.xlabel(label_names[int(lab)])
plt.show()
Building the neural network requires configuring the layers of the model, then compiling the model.
The basic building block of a neural network is the layer. Layers extract representations from the data fed into them. Hopefully, these representations are meaningful for the problem at hand.
Most of deep learning consists of chaining together simple layers. Most layers, such as tf.keras.layers.Dense, have parameters that are learned during training.
# Every example is a flat vector of 10 audio features; take the
# per-example shape straight off the feature matrix.
input_shape = features.shape[1:]
print(input_shape)
(10,)
# Standardize each input feature to zero mean / unit variance.  The
# Normalization layer must be adapt()ed to the training data first —
# without adapt it keeps its defaults (mean 0, variance 1) and is a
# no-op, which is the bug in the original cell.  Adapting only on the
# training split avoids leaking test statistics into the model.
_normalizer = tf.keras.layers.experimental.preprocessing.Normalization()
_normalizer.adapt(np.asarray(train_features))

model = tf.keras.Sequential(
    [
        tf.keras.layers.Input(shape=input_shape),   # (10,) audio features
        _normalizer,                                # per-feature standardization
        tf.keras.layers.Dense(128, activation='relu'),  # hidden layer
        tf.keras.layers.Dropout(0.3),               # regularization
        tf.keras.layers.Dense(len(label_names)),    # logits, one per artist
    ]
)
The network consists of a sequence of two tf.keras.layers.Dense layers. These are densely connected, or fully connected, neural layers. The first Dense layer has 128 nodes (or neurons). The second (and last) layer returns a logits array with length of 5. Each node contains a score that indicates how strongly the current track belongs to one of the 5 artist classes.
Before the model is ready for training, it needs a few more settings. These are added during the model's compile step:
# Configure training: Adam optimizer, cross-entropy over integer class
# labels (the model outputs raw logits, hence from_logits=True), and
# accuracy as the reported metric.
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
model.compile(
    optimizer='adam',
    loss=loss_fn,
    metrics=['accuracy'],
)
model.summary()
Model: "sequential_5" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= normalization_3 (Normalizati (None, 10) 21 _________________________________________________________________ dense_6 (Dense) (None, 128) 1408 _________________________________________________________________ dropout_1 (Dropout) (None, 128) 0 _________________________________________________________________ dense_7 (Dense) (None, 5) 645 ================================================================= Total params: 2,074 Trainable params: 2,053 Non-trainable params: 21 _________________________________________________________________
Training the neural network model requires the following steps:
1. Feed the training data to the model — here, the train_features and train_labels arrays.
2. The model learns to associate features with labels.
3. Ask the model to make predictions about a test set — here, the test_features array — and verify that the predictions match the labels from the test_labels array.
To start training, call the model.fit method—so called because it "fits" the model to the training data:
# Train for 200 epochs in mini-batches of 16, evaluating on the held-out
# test split after every epoch; `history` records per-epoch metrics.
history = model.fit(
    train_features,
    train_labels,
    batch_size=16,
    epochs=200,
    validation_data=(test_features, test_labels),
    verbose=1,
)
Epoch 1/200 98/98 [==============================] - 1s 4ms/step - loss: 1.8549 - accuracy: 0.2702 - val_loss: 1.3879 - val_accuracy: 0.3865 Epoch 2/200 98/98 [==============================] - 0s 2ms/step - loss: 1.5215 - accuracy: 0.3620 - val_loss: 1.3430 - val_accuracy: 0.4365 Epoch 3/200 98/98 [==============================] - 0s 2ms/step - loss: 1.4507 - accuracy: 0.3788 - val_loss: 1.2912 - val_accuracy: 0.4558 Epoch 4/200 98/98 [==============================] - 0s 2ms/step - loss: 1.4419 - accuracy: 0.3768 - val_loss: 1.2512 - val_accuracy: 0.5173 Epoch 5/200 98/98 [==============================] - 0s 2ms/step - loss: 1.3494 - accuracy: 0.4305 - val_loss: 1.2390 - val_accuracy: 0.4942 Epoch 6/200 98/98 [==============================] - 0s 2ms/step - loss: 1.2862 - accuracy: 0.4691 - val_loss: 1.1962 - val_accuracy: 0.5404 Epoch 7/200 98/98 [==============================] - 0s 2ms/step - loss: 1.2596 - accuracy: 0.4845 - val_loss: 1.1650 - val_accuracy: 0.5712 Epoch 8/200 98/98 [==============================] - 0s 2ms/step - loss: 1.2541 - accuracy: 0.5016 - val_loss: 1.1309 - val_accuracy: 0.5692 Epoch 9/200 98/98 [==============================] - 0s 2ms/step - loss: 1.1691 - accuracy: 0.5548 - val_loss: 1.1227 - val_accuracy: 0.5885 Epoch 10/200 98/98 [==============================] - 0s 2ms/step - loss: 1.1581 - accuracy: 0.5551 - val_loss: 1.0900 - val_accuracy: 0.5865 Epoch 11/200 98/98 [==============================] - 0s 2ms/step - loss: 1.1401 - accuracy: 0.5384 - val_loss: 1.0774 - val_accuracy: 0.5731 Epoch 12/200 98/98 [==============================] - 0s 2ms/step - loss: 1.1348 - accuracy: 0.5648 - val_loss: 1.0497 - val_accuracy: 0.5846 Epoch 13/200 98/98 [==============================] - 0s 2ms/step - loss: 1.0976 - accuracy: 0.5949 - val_loss: 1.0385 - val_accuracy: 0.6019 Epoch 14/200 98/98 [==============================] - 0s 2ms/step - loss: 1.0713 - accuracy: 0.5965 - val_loss: 1.0268 - val_accuracy: 0.6135 Epoch 15/200 98/98 
[==============================] - 0s 2ms/step - loss: 1.0490 - accuracy: 0.6078 - val_loss: 1.0132 - val_accuracy: 0.6135 Epoch 16/200 98/98 [==============================] - 0s 2ms/step - loss: 1.0446 - accuracy: 0.5956 - val_loss: 0.9863 - val_accuracy: 0.6385 Epoch 17/200 98/98 [==============================] - 0s 2ms/step - loss: 1.0179 - accuracy: 0.6062 - val_loss: 0.9833 - val_accuracy: 0.6269 Epoch 18/200 98/98 [==============================] - 0s 2ms/step - loss: 1.0083 - accuracy: 0.6237 - val_loss: 0.9585 - val_accuracy: 0.6442 Epoch 19/200 98/98 [==============================] - 0s 2ms/step - loss: 0.9950 - accuracy: 0.6365 - val_loss: 0.9590 - val_accuracy: 0.6481 Epoch 20/200 98/98 [==============================] - 0s 2ms/step - loss: 0.9715 - accuracy: 0.6376 - val_loss: 0.9356 - val_accuracy: 0.6519 Epoch 21/200 98/98 [==============================] - 0s 2ms/step - loss: 0.9699 - accuracy: 0.6300 - val_loss: 0.9325 - val_accuracy: 0.6615 Epoch 22/200 98/98 [==============================] - 0s 2ms/step - loss: 0.9629 - accuracy: 0.6495 - val_loss: 0.9320 - val_accuracy: 0.6731 Epoch 23/200 98/98 [==============================] - 0s 2ms/step - loss: 0.9356 - accuracy: 0.6668 - val_loss: 0.9111 - val_accuracy: 0.6808 Epoch 24/200 98/98 [==============================] - 0s 2ms/step - loss: 0.9326 - accuracy: 0.6652 - val_loss: 0.9037 - val_accuracy: 0.6827 Epoch 25/200 98/98 [==============================] - 0s 2ms/step - loss: 0.9298 - accuracy: 0.6510 - val_loss: 0.8909 - val_accuracy: 0.6750 Epoch 26/200 98/98 [==============================] - 0s 2ms/step - loss: 0.9096 - accuracy: 0.6685 - val_loss: 0.8938 - val_accuracy: 0.6962 Epoch 27/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8954 - accuracy: 0.7010 - val_loss: 0.8903 - val_accuracy: 0.6846 Epoch 28/200 98/98 [==============================] - 0s 2ms/step - loss: 0.9177 - accuracy: 0.6729 - val_loss: 0.8740 - val_accuracy: 0.6827 Epoch 29/200 98/98 
[==============================] - 0s 2ms/step - loss: 0.8575 - accuracy: 0.6971 - val_loss: 0.8717 - val_accuracy: 0.7154 Epoch 30/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8754 - accuracy: 0.6878 - val_loss: 0.8647 - val_accuracy: 0.6962 Epoch 31/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8645 - accuracy: 0.6887 - val_loss: 0.8544 - val_accuracy: 0.7038 Epoch 32/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8635 - accuracy: 0.6909 - val_loss: 0.8595 - val_accuracy: 0.6981 Epoch 33/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8344 - accuracy: 0.7049 - val_loss: 0.8475 - val_accuracy: 0.7173 Epoch 34/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7965 - accuracy: 0.7388 - val_loss: 0.8382 - val_accuracy: 0.6885 Epoch 35/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8558 - accuracy: 0.6974 - val_loss: 0.8311 - val_accuracy: 0.7135 Epoch 36/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8133 - accuracy: 0.7032 - val_loss: 0.8285 - val_accuracy: 0.7231 Epoch 37/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8482 - accuracy: 0.7072 - val_loss: 0.8231 - val_accuracy: 0.7115 Epoch 38/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8476 - accuracy: 0.6985 - val_loss: 0.8210 - val_accuracy: 0.7250 Epoch 39/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8194 - accuracy: 0.6926 - val_loss: 0.8365 - val_accuracy: 0.7115 Epoch 40/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7889 - accuracy: 0.7238 - val_loss: 0.8133 - val_accuracy: 0.7173 Epoch 41/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8137 - accuracy: 0.6914 - val_loss: 0.8236 - val_accuracy: 0.6962 Epoch 42/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7941 - accuracy: 0.7224 - val_loss: 0.8047 - val_accuracy: 0.7346 Epoch 43/200 98/98 
[==============================] - 0s 2ms/step - loss: 0.8500 - accuracy: 0.6994 - val_loss: 0.7987 - val_accuracy: 0.7077 Epoch 44/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7814 - accuracy: 0.7332 - val_loss: 0.8004 - val_accuracy: 0.7154 Epoch 45/200 98/98 [==============================] - 0s 2ms/step - loss: 0.8015 - accuracy: 0.7039 - val_loss: 0.7938 - val_accuracy: 0.7173 Epoch 46/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7738 - accuracy: 0.7291 - val_loss: 0.7998 - val_accuracy: 0.7058 Epoch 47/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7917 - accuracy: 0.7129 - val_loss: 0.7897 - val_accuracy: 0.7212 Epoch 48/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7550 - accuracy: 0.7412 - val_loss: 0.7828 - val_accuracy: 0.7288 Epoch 49/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7551 - accuracy: 0.7445 - val_loss: 0.7936 - val_accuracy: 0.7288 Epoch 50/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7342 - accuracy: 0.7449 - val_loss: 0.7988 - val_accuracy: 0.7231 Epoch 51/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7770 - accuracy: 0.7305 - val_loss: 0.7902 - val_accuracy: 0.7115 Epoch 52/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7622 - accuracy: 0.7286 - val_loss: 0.7815 - val_accuracy: 0.7115 Epoch 53/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7479 - accuracy: 0.7433 - val_loss: 0.7718 - val_accuracy: 0.7308 Epoch 54/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7601 - accuracy: 0.7326 - val_loss: 0.7758 - val_accuracy: 0.7135 Epoch 55/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7482 - accuracy: 0.7304 - val_loss: 0.7772 - val_accuracy: 0.7115 Epoch 56/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7244 - accuracy: 0.7408 - val_loss: 0.7645 - val_accuracy: 0.7154 Epoch 57/200 98/98 
[==============================] - 0s 2ms/step - loss: 0.7281 - accuracy: 0.7209 - val_loss: 0.7828 - val_accuracy: 0.7077 Epoch 58/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7702 - accuracy: 0.7325 - val_loss: 0.7668 - val_accuracy: 0.7173 Epoch 59/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7191 - accuracy: 0.7443 - val_loss: 0.7591 - val_accuracy: 0.7173 Epoch 60/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6780 - accuracy: 0.7536 - val_loss: 0.7579 - val_accuracy: 0.7288 Epoch 61/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7097 - accuracy: 0.7394 - val_loss: 0.7622 - val_accuracy: 0.7327 Epoch 62/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7462 - accuracy: 0.7265 - val_loss: 0.7587 - val_accuracy: 0.7327 Epoch 63/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7447 - accuracy: 0.7361 - val_loss: 0.7499 - val_accuracy: 0.7308 Epoch 64/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7401 - accuracy: 0.7401 - val_loss: 0.7597 - val_accuracy: 0.7346 Epoch 65/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7384 - accuracy: 0.7327 - val_loss: 0.7453 - val_accuracy: 0.7308 Epoch 66/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7053 - accuracy: 0.7589 - val_loss: 0.7621 - val_accuracy: 0.7269 Epoch 67/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7289 - accuracy: 0.7550 - val_loss: 0.7519 - val_accuracy: 0.7231 Epoch 68/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7545 - accuracy: 0.7282 - val_loss: 0.7545 - val_accuracy: 0.7327 Epoch 69/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7253 - accuracy: 0.7409 - val_loss: 0.7425 - val_accuracy: 0.7365 Epoch 70/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7190 - accuracy: 0.7445 - val_loss: 0.7377 - val_accuracy: 0.7404 Epoch 71/200 98/98 
[==============================] - 0s 2ms/step - loss: 0.7114 - accuracy: 0.7551 - val_loss: 0.7470 - val_accuracy: 0.7404 Epoch 72/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6811 - accuracy: 0.7530 - val_loss: 0.7458 - val_accuracy: 0.7269 Epoch 73/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6945 - accuracy: 0.7365 - val_loss: 0.7501 - val_accuracy: 0.7346 Epoch 74/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7087 - accuracy: 0.7471 - val_loss: 0.7618 - val_accuracy: 0.7250 Epoch 75/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7285 - accuracy: 0.7410 - val_loss: 0.7451 - val_accuracy: 0.7327 Epoch 76/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6644 - accuracy: 0.7516 - val_loss: 0.7394 - val_accuracy: 0.7250 Epoch 77/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7315 - accuracy: 0.7356 - val_loss: 0.7448 - val_accuracy: 0.7346 Epoch 78/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6932 - accuracy: 0.7423 - val_loss: 0.7341 - val_accuracy: 0.7288 Epoch 79/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6685 - accuracy: 0.7558 - val_loss: 0.7339 - val_accuracy: 0.7385 Epoch 80/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7121 - accuracy: 0.7505 - val_loss: 0.7357 - val_accuracy: 0.7365 Epoch 81/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6903 - accuracy: 0.7528 - val_loss: 0.7373 - val_accuracy: 0.7385 Epoch 82/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6952 - accuracy: 0.7643 - val_loss: 0.7221 - val_accuracy: 0.7250 Epoch 83/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6815 - accuracy: 0.7461 - val_loss: 0.7264 - val_accuracy: 0.7327 Epoch 84/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7003 - accuracy: 0.7509 - val_loss: 0.7430 - val_accuracy: 0.7077 Epoch 85/200 98/98 
[==============================] - 0s 2ms/step - loss: 0.7003 - accuracy: 0.7389 - val_loss: 0.7274 - val_accuracy: 0.7500 Epoch 86/200 98/98 [==============================] - 0s 2ms/step - loss: 0.7097 - accuracy: 0.7540 - val_loss: 0.7205 - val_accuracy: 0.7481 Epoch 87/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6744 - accuracy: 0.7599 - val_loss: 0.7231 - val_accuracy: 0.7385 Epoch 88/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6992 - accuracy: 0.7500 - val_loss: 0.7310 - val_accuracy: 0.7288 Epoch 89/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6424 - accuracy: 0.7602 - val_loss: 0.7191 - val_accuracy: 0.7442 Epoch 90/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6649 - accuracy: 0.7673 - val_loss: 0.7339 - val_accuracy: 0.7558 Epoch 91/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6421 - accuracy: 0.7653 - val_loss: 0.7179 - val_accuracy: 0.7462 Epoch 92/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6525 - accuracy: 0.7672 - val_loss: 0.7321 - val_accuracy: 0.7442 Epoch 93/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6564 - accuracy: 0.7612 - val_loss: 0.7238 - val_accuracy: 0.7481 Epoch 94/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6786 - accuracy: 0.7438 - val_loss: 0.7255 - val_accuracy: 0.7308 Epoch 95/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6830 - accuracy: 0.7570 - val_loss: 0.7078 - val_accuracy: 0.7404 Epoch 96/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6501 - accuracy: 0.7445 - val_loss: 0.7138 - val_accuracy: 0.7404 Epoch 97/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6757 - accuracy: 0.7597 - val_loss: 0.7114 - val_accuracy: 0.7442 Epoch 98/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6689 - accuracy: 0.7580 - val_loss: 0.7183 - val_accuracy: 0.7462 Epoch 99/200 98/98 
[==============================] - 0s 2ms/step - loss: 0.6659 - accuracy: 0.7605 - val_loss: 0.7210 - val_accuracy: 0.7404 Epoch 100/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6742 - accuracy: 0.7608 - val_loss: 0.7261 - val_accuracy: 0.7577 Epoch 101/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6950 - accuracy: 0.7496 - val_loss: 0.7269 - val_accuracy: 0.7519 Epoch 102/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6784 - accuracy: 0.7484 - val_loss: 0.7170 - val_accuracy: 0.7519 Epoch 103/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6616 - accuracy: 0.7572 - val_loss: 0.6997 - val_accuracy: 0.7442 Epoch 104/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6590 - accuracy: 0.7606 - val_loss: 0.7043 - val_accuracy: 0.7385 Epoch 105/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6587 - accuracy: 0.7594 - val_loss: 0.7072 - val_accuracy: 0.7423 Epoch 106/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6885 - accuracy: 0.7425 - val_loss: 0.7005 - val_accuracy: 0.7442 Epoch 107/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6454 - accuracy: 0.7743 - val_loss: 0.7091 - val_accuracy: 0.7635 Epoch 108/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6602 - accuracy: 0.7705 - val_loss: 0.7016 - val_accuracy: 0.7442 Epoch 109/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6513 - accuracy: 0.7529 - val_loss: 0.7029 - val_accuracy: 0.7481 Epoch 110/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6325 - accuracy: 0.7784 - val_loss: 0.7026 - val_accuracy: 0.7577 Epoch 111/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6657 - accuracy: 0.7690 - val_loss: 0.7026 - val_accuracy: 0.7481 Epoch 112/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6701 - accuracy: 0.7554 - val_loss: 0.7053 - val_accuracy: 0.7538 Epoch 113/200 
98/98 [==============================] - 0s 2ms/step - loss: 0.6511 - accuracy: 0.7808 - val_loss: 0.7025 - val_accuracy: 0.7558 Epoch 114/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6750 - accuracy: 0.7618 - val_loss: 0.6922 - val_accuracy: 0.7346 Epoch 115/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6360 - accuracy: 0.7779 - val_loss: 0.7014 - val_accuracy: 0.7538 Epoch 116/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6520 - accuracy: 0.7605 - val_loss: 0.7014 - val_accuracy: 0.7673 Epoch 117/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6078 - accuracy: 0.7874 - val_loss: 0.7044 - val_accuracy: 0.7365 Epoch 118/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6259 - accuracy: 0.7764 - val_loss: 0.6874 - val_accuracy: 0.7615 Epoch 119/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6577 - accuracy: 0.7605 - val_loss: 0.6984 - val_accuracy: 0.7558 Epoch 120/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6419 - accuracy: 0.7645 - val_loss: 0.6886 - val_accuracy: 0.7615 Epoch 121/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6369 - accuracy: 0.7669 - val_loss: 0.7017 - val_accuracy: 0.7558 Epoch 122/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6486 - accuracy: 0.7620 - val_loss: 0.7003 - val_accuracy: 0.7577 Epoch 123/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6188 - accuracy: 0.7654 - val_loss: 0.6918 - val_accuracy: 0.7538 Epoch 124/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6407 - accuracy: 0.7716 - val_loss: 0.6984 - val_accuracy: 0.7635 Epoch 125/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6219 - accuracy: 0.7885 - val_loss: 0.6920 - val_accuracy: 0.7596 Epoch 126/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6555 - accuracy: 0.7531 - val_loss: 0.6906 - val_accuracy: 0.7654 Epoch 
127/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6242 - accuracy: 0.7669 - val_loss: 0.6969 - val_accuracy: 0.7577 Epoch 128/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6361 - accuracy: 0.7687 - val_loss: 0.6910 - val_accuracy: 0.7654 Epoch 129/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6572 - accuracy: 0.7571 - val_loss: 0.6835 - val_accuracy: 0.7558 Epoch 130/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6313 - accuracy: 0.7610 - val_loss: 0.6935 - val_accuracy: 0.7481 Epoch 131/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6238 - accuracy: 0.7836 - val_loss: 0.7052 - val_accuracy: 0.7500 Epoch 132/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6299 - accuracy: 0.7602 - val_loss: 0.6986 - val_accuracy: 0.7500 Epoch 133/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6566 - accuracy: 0.7672 - val_loss: 0.6949 - val_accuracy: 0.7596 Epoch 134/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6286 - accuracy: 0.7665 - val_loss: 0.6938 - val_accuracy: 0.7577 Epoch 135/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6259 - accuracy: 0.7738 - val_loss: 0.6906 - val_accuracy: 0.7519 Epoch 136/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6141 - accuracy: 0.7708 - val_loss: 0.6853 - val_accuracy: 0.7654 Epoch 137/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5885 - accuracy: 0.7849 - val_loss: 0.6952 - val_accuracy: 0.7615 Epoch 138/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6171 - accuracy: 0.7741 - val_loss: 0.6927 - val_accuracy: 0.7596 Epoch 139/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5820 - accuracy: 0.7901 - val_loss: 0.6879 - val_accuracy: 0.7673 Epoch 140/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5959 - accuracy: 0.7781 - val_loss: 0.6947 - val_accuracy: 0.7442 
Epoch 141/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6180 - accuracy: 0.7858 - val_loss: 0.6857 - val_accuracy: 0.7654 Epoch 142/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6247 - accuracy: 0.7778 - val_loss: 0.6838 - val_accuracy: 0.7538 Epoch 143/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6107 - accuracy: 0.7822 - val_loss: 0.6845 - val_accuracy: 0.7635 Epoch 144/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6013 - accuracy: 0.7928 - val_loss: 0.6772 - val_accuracy: 0.7654 Epoch 145/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6646 - accuracy: 0.7628 - val_loss: 0.6861 - val_accuracy: 0.7519 Epoch 146/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6176 - accuracy: 0.7761 - val_loss: 0.6809 - val_accuracy: 0.7654 Epoch 147/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6158 - accuracy: 0.7695 - val_loss: 0.6832 - val_accuracy: 0.7673 Epoch 148/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6139 - accuracy: 0.7677 - val_loss: 0.6761 - val_accuracy: 0.7673 Epoch 149/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6270 - accuracy: 0.7792 - val_loss: 0.6920 - val_accuracy: 0.7442 Epoch 150/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6103 - accuracy: 0.7842 - val_loss: 0.6851 - val_accuracy: 0.7519 Epoch 151/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6312 - accuracy: 0.7598 - val_loss: 0.6732 - val_accuracy: 0.7635 Epoch 152/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6420 - accuracy: 0.7716 - val_loss: 0.6790 - val_accuracy: 0.7615 Epoch 153/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6009 - accuracy: 0.7968 - val_loss: 0.6898 - val_accuracy: 0.7558 Epoch 154/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6372 - accuracy: 0.7653 - val_loss: 0.6769 - val_accuracy: 
0.7635 Epoch 155/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6253 - accuracy: 0.7699 - val_loss: 0.7003 - val_accuracy: 0.7654 Epoch 156/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6425 - accuracy: 0.7749 - val_loss: 0.6783 - val_accuracy: 0.7673 Epoch 157/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5940 - accuracy: 0.7870 - val_loss: 0.6701 - val_accuracy: 0.7712 Epoch 158/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5924 - accuracy: 0.7883 - val_loss: 0.6937 - val_accuracy: 0.7462 Epoch 159/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6223 - accuracy: 0.7900 - val_loss: 0.6705 - val_accuracy: 0.7577 Epoch 160/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6250 - accuracy: 0.7762 - val_loss: 0.6768 - val_accuracy: 0.7673 Epoch 161/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5967 - accuracy: 0.7922 - val_loss: 0.6675 - val_accuracy: 0.7731 Epoch 162/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5739 - accuracy: 0.7949 - val_loss: 0.6778 - val_accuracy: 0.7731 Epoch 163/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6516 - accuracy: 0.7442 - val_loss: 0.6702 - val_accuracy: 0.7712 Epoch 164/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6020 - accuracy: 0.7777 - val_loss: 0.6709 - val_accuracy: 0.7654 Epoch 165/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5989 - accuracy: 0.7820 - val_loss: 0.6815 - val_accuracy: 0.7577 Epoch 166/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6047 - accuracy: 0.7844 - val_loss: 0.6641 - val_accuracy: 0.7615 Epoch 167/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6180 - accuracy: 0.7730 - val_loss: 0.6686 - val_accuracy: 0.7712 Epoch 168/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5752 - accuracy: 0.7907 - val_loss: 0.6683 - 
val_accuracy: 0.7750 Epoch 169/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5827 - accuracy: 0.7849 - val_loss: 0.6798 - val_accuracy: 0.7673 Epoch 170/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6310 - accuracy: 0.7768 - val_loss: 0.6624 - val_accuracy: 0.7673 Epoch 171/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5964 - accuracy: 0.7885 - val_loss: 0.6647 - val_accuracy: 0.7654 Epoch 172/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6048 - accuracy: 0.7876 - val_loss: 0.6865 - val_accuracy: 0.7673 Epoch 173/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5942 - accuracy: 0.7906 - val_loss: 0.6681 - val_accuracy: 0.7615 Epoch 174/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5657 - accuracy: 0.7816 - val_loss: 0.6697 - val_accuracy: 0.7673 Epoch 175/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6114 - accuracy: 0.7781 - val_loss: 0.6653 - val_accuracy: 0.7635 Epoch 176/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5779 - accuracy: 0.8125 - val_loss: 0.6808 - val_accuracy: 0.7673 Epoch 177/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5815 - accuracy: 0.7939 - val_loss: 0.6725 - val_accuracy: 0.7692 Epoch 178/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5786 - accuracy: 0.7942 - val_loss: 0.6723 - val_accuracy: 0.7654 Epoch 179/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6160 - accuracy: 0.7792 - val_loss: 0.6617 - val_accuracy: 0.7519 Epoch 180/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5538 - accuracy: 0.8034 - val_loss: 0.6654 - val_accuracy: 0.7673 Epoch 181/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6256 - accuracy: 0.7702 - val_loss: 0.6623 - val_accuracy: 0.7673 Epoch 182/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5722 - accuracy: 0.7956 - val_loss: 0.6775 
- val_accuracy: 0.7673 Epoch 183/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6127 - accuracy: 0.7825 - val_loss: 0.6735 - val_accuracy: 0.7673 Epoch 184/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5505 - accuracy: 0.8025 - val_loss: 0.6634 - val_accuracy: 0.7692 Epoch 185/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5836 - accuracy: 0.7843 - val_loss: 0.6727 - val_accuracy: 0.7654 Epoch 186/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6168 - accuracy: 0.7723 - val_loss: 0.6640 - val_accuracy: 0.7558 Epoch 187/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5749 - accuracy: 0.7995 - val_loss: 0.6630 - val_accuracy: 0.7635 Epoch 188/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5411 - accuracy: 0.8099 - val_loss: 0.6582 - val_accuracy: 0.7769 Epoch 189/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5902 - accuracy: 0.7868 - val_loss: 0.6970 - val_accuracy: 0.7596 Epoch 190/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5631 - accuracy: 0.7806 - val_loss: 0.6687 - val_accuracy: 0.7769 Epoch 191/200 98/98 [==============================] - 0s 2ms/step - loss: 0.6292 - accuracy: 0.7633 - val_loss: 0.6710 - val_accuracy: 0.7692 Epoch 192/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5737 - accuracy: 0.8003 - val_loss: 0.6691 - val_accuracy: 0.7673 Epoch 193/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5517 - accuracy: 0.8069 - val_loss: 0.6680 - val_accuracy: 0.7558 Epoch 194/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5639 - accuracy: 0.7845 - val_loss: 0.6668 - val_accuracy: 0.7673 Epoch 195/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5997 - accuracy: 0.8004 - val_loss: 0.6574 - val_accuracy: 0.7673 Epoch 196/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5739 - accuracy: 0.8073 - val_loss: 
0.6604 - val_accuracy: 0.7712 Epoch 197/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5848 - accuracy: 0.7926 - val_loss: 0.6608 - val_accuracy: 0.7596 Epoch 198/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5596 - accuracy: 0.7939 - val_loss: 0.6577 - val_accuracy: 0.7769 Epoch 199/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5643 - accuracy: 0.7978 - val_loss: 0.6603 - val_accuracy: 0.7788 Epoch 200/200 98/98 [==============================] - 0s 2ms/step - loss: 0.5555 - accuracy: 0.8046 - val_loss: 0.6593 - val_accuracy: 0.7769
As the model trains, the loss and accuracy metrics are displayed. This model reaches an accuracy of about 80% on the training data.
Next, compare how the model performs on the test dataset:
# List the metrics recorded during training (the History object's keys).
history.history.keys()
dict_keys(['loss', 'accuracy', 'val_loss', 'val_accuracy'])
# Plot all four training curves (loss/accuracy for both the training
# and validation splits) against the epoch number, then show a legend
# naming each curve in the same order they were drawn.
metrics = history.history
curve_names = ['loss', 'accuracy', 'val_loss', 'val_accuracy']
for curve in curve_names:
    plt.plot(history.epoch, metrics[curve])
plt.legend(curve_names)
plt.show()
It turns out that the accuracy on the test dataset is a little less than the accuracy on the training dataset. This gap between training accuracy and test accuracy represents overfitting. Overfitting happens when a machine learning model performs worse on new, previously unseen inputs than it does on the training data. An overfitted model "memorizes" the noise and details in the training dataset to a point where it negatively impacts the performance of the model on the new data.
One way to deal with this is by introducing some regularization with a Dropout layer. To do this, you need to go back and re-initialize the Sequential
class with one layer being tf.keras.layers.Dropout(0.2)
, where 0.2
is the fraction of units to randomly drop between layers. The higher the rate, the stronger the regularization.
# Stack a Softmax layer on top of the trained model so its raw logits
# are converted to class probabilities, then score the whole test set.
softmax_head = tf.keras.layers.Softmax()
probability_model = tf.keras.Sequential([model, softmax_head])
predictions = probability_model.predict(test_features)
Here, the model has predicted the artist label for each track in the testing set. Let's take a look at the first prediction:
# Probability vector for the first test example (one entry per artist class).
predictions[0]
array([1.1181573e-02, 9.8401791e-01, 9.3440511e-05, 9.6941723e-05, 4.6101473e-03], dtype=float32)
A prediction is an array of 5 numbers. They represent the model's "confidence" that the track corresponds to each of the 5 different artists. You can see which label has the highest confidence value:
# Index of the highest-probability class for the first test example.
np.argmax(predictions[0])
1
So, the model is most confident that this is artist label_names[1]
. Examining the test label shows that this classification is correct:
# Ground-truth label for the first test example, for comparison.
test_labels[0]
1.0
# Collapse each row of class probabilities to its argmax to obtain one
# predicted label per test example (same result as a Python-level loop,
# but done in a single vectorized call).
pred_labels = np.argmax(predictions, axis=1)
pred_labels.shape, test_labels.shape
((520,), (520,))
import seaborn as sns

# tf.math.confusion_matrix(labels, predictions) places the TRUE labels on
# the rows and the PREDICTED labels on the columns, so in the heatmap the
# y axis is the true artist and the x axis is the prediction.
mat = tf.math.confusion_matrix(test_labels, pred_labels)
sns.heatmap(mat, square=True, annot=True, cbar=True, fmt='d',
            xticklabels=label_names,
            yticklabels=label_names)
# BUG FIX: the axis labels were swapped in the original — columns of the
# confusion matrix are predictions, rows are the true labels.
plt.xlabel('Predicted')
plt.ylabel('True');
Finally, use the trained model to make a prediction about a single example.
# Pick one random test example, predict its artist, and print a colored
# report comparing the prediction with the true label. `stats` tracks
# hits/misses across repeated runs of this cell (it resets when re-run).
stats = []
# Grab a random element from the test dataset.
example = np.random.randint(len(test_features))
tester_feature = test_features[example]
# `tf.keras` models are optimized to make predictions on a *batch*,
# or collection, of examples at once. Accordingly, even though
# you're scoring a single example, you need to give it a batch axis:
tester_feature = np.expand_dims(tester_feature, 0)
# Now predict the class probabilities for this one example:
predictions_single = probability_model.predict(tester_feature)
y_pred = np.argmax(predictions_single[0])
y_true = int(test_labels[example])
# Compare predicted and true labels and report the result in color.
from termcolor import colored
if y_true != y_pred:
    stats.append(0)
    msg = f"Yikes!, {['Oh, well', 'perhaps next time?', 'ugh...'][np.random.randint(3)]}"
    color = 'red'
else:
    stats.append(1)
    # Plain string: the original used an f-string with no placeholders.
    msg = "Knew it."
    color = 'green'
print(colored("=" * 80, color))
print(colored("| " + msg, color))
print(colored("-" * 80, color))
print(colored("| TRUE: " + label_names[y_true], color))
print(colored("| PRED: " + label_names[y_pred], color))
print(colored("=" * 80, color))
print(f"Current accuracy: {stats.count(1) / len(stats) * 100}")
print(f"Actual accuracy is: {metrics['accuracy'][-1]*100}")
================================================================================ | Knew it. -------------------------------------------------------------------------------- | TRUE: The Rolling Stones | PRED: The Rolling Stones ================================================================================ Current accuracy: 76.92307692307693 Actual accuracy is: 79.67948913574219