In this Colab you will train a model to recognize musical instruments from the NSynth
dataset.
import numpy as np
from IPython import display as ipd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
import tensorflow_datasets as tfds
# Set seed for experiment reproducibility
seed = 42
tf.random.set_seed(seed)
np.random.seed(seed)
# default nsynth samplerate
sr = 16000  # Hz
clip_audio_at = sr // 1 # maximum 1 second of audio per clip
# mfcc parameters
window_length = 256  # STFT frame length in samples
hop_length = 128     # STFT hop in samples (50% overlap)
num_mels = 20        # number of mel filter-bank bands
min_hz = 60.0        # lower edge of the mel filter bank
max_hz = 6400.0      # upper edge of the mel filter bank
AUTOTUNE = tf.data.experimental.AUTOTUNE
batch_size = 64
max_x = 32  # cap on the time axis after the Resizing layer
# train, validation, and test sizes to fetch from the nsynth database
p = [8000, 2000, 200]
Nsynth is a large-scale and high-quality dataset of annotated musical notes: https://magenta.tensorflow.org/datasets/nsynth
# The 11 NSynth instrument-family names, indexed by the integer
# instrument/family label stored in each dataset record.
label_names = [
    'bass', 'brass', 'flute',
    'guitar', 'keyboard', 'mallet',
    'organ', 'reed', 'string', 'synth_lead', 'vocal'
]
First, look at the dataset splits and their sizes, then select how many examples to use here.
# Load every NSynth split and report the size of each one.
nsynth = tfds.load("nsynth", try_gcs=True)
print(f"Nsynth keys: {', '.join(nsynth.keys())}")
for split_name, split_ds in nsynth.items():
    print(f"{split_name} dataset length: {len(split_ds)}")
Nsynth keys: test, train, valid test dataset length: 4096 train dataset length: 289205 valid dataset length: 12678
# Load one slice of the train split, large enough (sum(p) = 10200 examples)
# to carve out our own train/validation/test subsets below.
train_files = tfds.load("nsynth",split=f"train[:{sum(p)}]",shuffle_files=True, try_gcs=True)
# The official valid/test splits could be used instead of slicing train:
# val_files = tfds.load("nsynth",split=f"valid[:{p[1]}]",shuffle_files=True, try_gcs=True)
# test_files = tfds.load("nsynth",split=f"test[:{p[2]}]", shuffle_files=True, try_gcs=True)
print('Training set size', len(train_files))
# print('Validation set size', len(val_files))
# print('Test set size', len(test_files))
Training set size 10200
assert isinstance(train_files, tf.data.Dataset)
Next we define helpers to compute mel-spectrogram features in TensorFlow and to map those features, together with labels, over our NSynth dataset.
# Mel filter bank: projects the (window_length//2 + 1) linear STFT bins
# onto num_mels mel bands between min_hz and max_hz at sample rate sr.
mel_weights = tf.signal.linear_to_mel_weight_matrix(num_mel_bins=num_mels,
                                                    num_spectrogram_bins=window_length//2+1,
                                                    sample_rate=sr,
                                                    lower_edge_hertz=min_hz,
                                                    upper_edge_hertz=max_hz)
def mfcc(waveform):
    """Compute a fixed-length mel amplitude spectrogram for `waveform`.

    Despite the name, no log compression or DCT is applied, so the result
    is a (frames, num_mels) mel spectrogram rather than true MFCCs.
    Waveforms shorter than `clip_audio_at` samples are zero-padded so every
    clip produces the same number of STFT frames.
    """
    # Amount of padding needed to reach exactly clip_audio_at samples.
    pad = tf.zeros([clip_audio_at] - tf.shape(waveform), dtype=tf.float32)
    samples = tf.cast(waveform, tf.float32)
    padded = tf.concat([samples, pad], 0)
    # Magnitude STFT, then project onto the precomputed mel filter bank.
    magnitude = tf.math.abs(tf.signal.stft(padded, window_length, hop_length))
    return tf.tensordot(magnitude, mel_weights, 1)
def mfcc_and_label(entry):
    """Map one NSynth record to a (feature, label) pair.

    The feature is the mel spectrogram of the first `clip_audio_at`
    samples with a trailing channel axis of size 1; the label is the
    integer instrument-family id.
    """
    features = mfcc(entry['audio'][:clip_audio_at])
    features = tf.expand_dims(features, -1)
    return features, entry['instrument']['family']
# Apply the feature/label mapping over the loaded split in parallel.
mfcc_train_ds = train_files.map(mfcc_and_label, num_parallel_calls=AUTOTUNE)
# Same mapping for the official test/valid splits, if loaded above:
# mfcc_test_ds = test_files.map( mfcc_and_label, num_parallel_calls=AUTOTUNE)
# mfcc_val_ds = val_files.map( mfcc_and_label, num_parallel_calls=AUTOTUNE)
# Visualize a d*d grid of example features from the training set.
d = 3
scatter_plot = True  # True: scatter of first two mel bins; False: heatmap image
fig, axes = plt.subplots(nrows=d, ncols=d, figsize=(10, 10))
# FIX: loop variable renamed from `mfcc` to `spec` so it no longer shadows
# the module-level mfcc() function defined above.
for i, (spec, label) in enumerate(mfcc_train_ds.take(d * d)):
    r = i // d
    c = i % d
    ax = axes[r][c]
    ax.set_title(label_names[label.numpy()])
    # Log scale for display. NOTE(review): may yield -inf where the
    # spectrogram is exactly zero — display-only, so left as-is.
    log_spec = np.log(np.squeeze(spec.numpy()))
    if scatter_plot:
        ax.set_yticks([])
        ax.set_xticks([])
        ax.scatter(log_spec[:, 0], log_spec[:, 1])
    else:
        ax.imshow(log_spec.T, origin='lower', aspect='auto')
        ax.axis('off')
plt.show()
plt.close()
# create a normalizer that learns per-feature mean and variance
normalizer = tf.keras.layers.experimental.preprocessing.Normalization()
# adapt your normalization with the train dataset,
# but only using the mfccs, not the labels, hence the .map()
normalizer.adapt(mfcc_train_ds.map(lambda x, _: x))
We'll make a new dataset that is used for training. This one is different from the Spectrogram and Label datasets, because it is now "batched", or split into groups, and it will be cached in memory for better performance.
Batch the training and validation sets for model training, and add dataset `cache()` and `prefetch()` operations to reduce read latency while training the model.
# Carve disjoint train/validation/test subsets out of the mapped pipeline.
# NOTE(review): all three derive from the same shuffle_files=True source;
# if the file order is reshuffled between iterations, skip()/take() could
# draw overlapping examples across the subsets. cache() freezes each subset
# after its first pass, but overlap on that first pass is not ruled out
# here — confirm, or use the official valid/test splits instead.
train_ds = mfcc_train_ds.take(p[0])
print("Train set size:",len(train_ds))
train_ds = train_ds.batch(batch_size).cache().prefetch(AUTOTUNE)
valid_ds = mfcc_train_ds.skip(p[0]).take(p[1])
print("Validation set size:",len(valid_ds))
valid_ds = valid_ds.batch(batch_size).cache().prefetch(AUTOTUNE)
test_ds = mfcc_train_ds.skip(p[0]+p[1]).take(p[2])
print("Test set size:",len(test_ds))
test_ds = test_ds.batch(batch_size).cache().prefetch(AUTOTUNE)
# Alternative: batch the official test/valid splits loaded earlier:
# test_ds = mfcc_test_ds.batch(batch_size).cache().prefetch(AUTOTUNE)
# val_ds = mfcc_val_ds.batch(batch_size).cache().prefetch(AUTOTUNE)
Train set size: 8000 Validation set size: 2000 Test set size: 200
# Grab one example to determine the model's input shape, and cap the
# time axis used by the Resizing layer at max_x frames.
# FIX: loop variable renamed from `mfcc` to `spec` so it does not shadow
# the mfcc() function defined above.
for spec, _ in mfcc_train_ds.take(1):
    input_shape = spec.shape
    print('Input shape:', input_shape)
# Equivalent to the original if/else: take the smaller of the two.
resize_x = min(input_shape[0], max_x)
Input shape: (124, 20, 1)
# CNN classifier: resize + normalize the mel features, two conv blocks,
# then a regularized dense head producing one logit per instrument family.
layers = tf.keras.layers
model = tf.keras.models.Sequential([
    layers.Input(shape=input_shape),
    # Downsample the time axis to at most max_x frames.
    layers.experimental.preprocessing.Resizing(resize_x, num_mels),
    # Feature-wise normalization adapted on the training set above.
    normalizer,
    layers.Conv2D(resize_x, 3),
    layers.LeakyReLU(alpha=0.3),
    layers.Conv2D(resize_x * 2, 3),
    layers.LeakyReLU(alpha=0.3),
    layers.MaxPooling2D(pool_size=(3, 3)),
    layers.Dropout(0.35),
    layers.Flatten(),
    layers.Dense(128,
                 kernel_regularizer=tf.keras.regularizers.L2(l2=0.001)),
    layers.LeakyReLU(alpha=0.4),
    layers.Dropout(0.5),
    # Raw logits (no softmax) — paired with from_logits=True loss below.
    layers.Dense(len(label_names)),
])

model.summary()
Model: "sequential_26" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= resizing_28 (Resizing) (None, 32, 20, 1) 0 _________________________________________________________________ normalization_2 (Normalizati (None, 32, 20, 1) 3 _________________________________________________________________ conv2d_49 (Conv2D) (None, 30, 18, 32) 320 _________________________________________________________________ leaky_re_lu_20 (LeakyReLU) (None, 30, 18, 32) 0 _________________________________________________________________ conv2d_50 (Conv2D) (None, 28, 16, 64) 18496 _________________________________________________________________ leaky_re_lu_21 (LeakyReLU) (None, 28, 16, 64) 0 _________________________________________________________________ max_pooling2d_23 (MaxPooling (None, 9, 5, 64) 0 _________________________________________________________________ dropout_60 (Dropout) (None, 9, 5, 64) 0 _________________________________________________________________ flatten_26 (Flatten) (None, 2880) 0 _________________________________________________________________ dense_57 (Dense) (None, 128) 368768 _________________________________________________________________ leaky_re_lu_22 (LeakyReLU) (None, 128) 0 _________________________________________________________________ dropout_61 (Dropout) (None, 128) 0 _________________________________________________________________ dense_58 (Dense) (None, 11) 1419 ================================================================= Total params: 389,006 Trainable params: 389,003 Non-trainable params: 3 _________________________________________________________________
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    # Model outputs logits (no softmax layer); labels are integer ids.
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'],
)

history = model.fit(
    train_ds,
    validation_data=valid_ds,
    epochs=200,
    verbose=1,
    # FIX: monitor validation loss (not training loss) so early stopping
    # actually guards against overfitting, and keep the best weights seen.
    callbacks=tf.keras.callbacks.EarlyStopping(
        monitor='val_loss', verbose=1, patience=5,
        restore_best_weights=True)
)
Epoch 1/200 125/125 [==============================] - 1s 7ms/step - loss: 2.3393 - accuracy: 0.2428 - val_loss: 1.9487 - val_accuracy: 0.3820 Epoch 2/200 125/125 [==============================] - 1s 5ms/step - loss: 1.9586 - accuracy: 0.3681 - val_loss: 1.7435 - val_accuracy: 0.4425 Epoch 3/200 125/125 [==============================] - 1s 5ms/step - loss: 1.8010 - accuracy: 0.4219 - val_loss: 1.6495 - val_accuracy: 0.4670 Epoch 4/200 125/125 [==============================] - 1s 5ms/step - loss: 1.7133 - accuracy: 0.4419 - val_loss: 1.5810 - val_accuracy: 0.4940 Epoch 5/200 125/125 [==============================] - 1s 5ms/step - loss: 1.6605 - accuracy: 0.4529 - val_loss: 1.5242 - val_accuracy: 0.5065 Epoch 6/200 125/125 [==============================] - 1s 5ms/step - loss: 1.6189 - accuracy: 0.4734 - val_loss: 1.4918 - val_accuracy: 0.5200 Epoch 7/200 125/125 [==============================] - 1s 5ms/step - loss: 1.5809 - accuracy: 0.4907 - val_loss: 1.4868 - val_accuracy: 0.5275 Epoch 8/200 125/125 [==============================] - 1s 5ms/step - loss: 1.5705 - accuracy: 0.4905 - val_loss: 1.4452 - val_accuracy: 0.5495 Epoch 9/200 125/125 [==============================] - 1s 6ms/step - loss: 1.5384 - accuracy: 0.5103 - val_loss: 1.4332 - val_accuracy: 0.5590 Epoch 10/200 125/125 [==============================] - 1s 6ms/step - loss: 1.5260 - accuracy: 0.5135 - val_loss: 1.4112 - val_accuracy: 0.5725 Epoch 11/200 125/125 [==============================] - 1s 6ms/step - loss: 1.5141 - accuracy: 0.5222 - val_loss: 1.3874 - val_accuracy: 0.5730 Epoch 12/200 125/125 [==============================] - 1s 5ms/step - loss: 1.4833 - accuracy: 0.5307 - val_loss: 1.3725 - val_accuracy: 0.5805 Epoch 13/200 125/125 [==============================] - 1s 6ms/step - loss: 1.4935 - accuracy: 0.5282 - val_loss: 1.3696 - val_accuracy: 0.5810 Epoch 14/200 125/125 [==============================] - 1s 6ms/step - loss: 1.4652 - accuracy: 0.5376 - val_loss: 1.3499 - val_accuracy: 
0.6035 Epoch 15/200 125/125 [==============================] - 1s 5ms/step - loss: 1.4517 - accuracy: 0.5438 - val_loss: 1.3255 - val_accuracy: 0.6065 Epoch 16/200 125/125 [==============================] - 1s 5ms/step - loss: 1.4357 - accuracy: 0.5591 - val_loss: 1.3251 - val_accuracy: 0.6265 Epoch 17/200 125/125 [==============================] - 1s 5ms/step - loss: 1.4509 - accuracy: 0.5547 - val_loss: 1.3136 - val_accuracy: 0.6240 Epoch 18/200 125/125 [==============================] - 1s 6ms/step - loss: 1.4120 - accuracy: 0.5622 - val_loss: 1.3402 - val_accuracy: 0.6140 Epoch 19/200 125/125 [==============================] - 1s 6ms/step - loss: 1.4096 - accuracy: 0.5591 - val_loss: 1.3185 - val_accuracy: 0.6180 Epoch 20/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3964 - accuracy: 0.5711 - val_loss: 1.3041 - val_accuracy: 0.6375 Epoch 21/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3879 - accuracy: 0.5739 - val_loss: 1.3164 - val_accuracy: 0.6305 Epoch 22/200 125/125 [==============================] - 1s 5ms/step - loss: 1.4136 - accuracy: 0.5671 - val_loss: 1.3006 - val_accuracy: 0.6255 Epoch 23/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3825 - accuracy: 0.5794 - val_loss: 1.2873 - val_accuracy: 0.6360 Epoch 24/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3715 - accuracy: 0.5880 - val_loss: 1.2802 - val_accuracy: 0.6290 Epoch 25/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3595 - accuracy: 0.5879 - val_loss: 1.2698 - val_accuracy: 0.6345 Epoch 26/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3572 - accuracy: 0.5930 - val_loss: 1.2739 - val_accuracy: 0.6400 Epoch 27/200 125/125 [==============================] - 1s 5ms/step - loss: 1.3379 - accuracy: 0.5928 - val_loss: 1.2673 - val_accuracy: 0.6395 Epoch 28/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3588 - accuracy: 0.5900 - val_loss: 1.2450 
- val_accuracy: 0.6470 Epoch 29/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3543 - accuracy: 0.5997 - val_loss: 1.2459 - val_accuracy: 0.6455 Epoch 30/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3237 - accuracy: 0.6054 - val_loss: 1.2382 - val_accuracy: 0.6500 Epoch 31/200 125/125 [==============================] - 1s 5ms/step - loss: 1.3272 - accuracy: 0.5968 - val_loss: 1.2342 - val_accuracy: 0.6530 Epoch 32/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3026 - accuracy: 0.6080 - val_loss: 1.2239 - val_accuracy: 0.6555 Epoch 33/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3167 - accuracy: 0.6003 - val_loss: 1.2215 - val_accuracy: 0.6700 Epoch 34/200 125/125 [==============================] - 1s 5ms/step - loss: 1.3084 - accuracy: 0.6075 - val_loss: 1.2168 - val_accuracy: 0.6665 Epoch 35/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3010 - accuracy: 0.6129 - val_loss: 1.2069 - val_accuracy: 0.6640 Epoch 36/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3062 - accuracy: 0.6153 - val_loss: 1.2183 - val_accuracy: 0.6645 Epoch 37/200 125/125 [==============================] - 1s 5ms/step - loss: 1.2900 - accuracy: 0.6205 - val_loss: 1.1960 - val_accuracy: 0.6755 Epoch 38/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2868 - accuracy: 0.6206 - val_loss: 1.1977 - val_accuracy: 0.6755 Epoch 39/200 125/125 [==============================] - 1s 6ms/step - loss: 1.3043 - accuracy: 0.6149 - val_loss: 1.2033 - val_accuracy: 0.6690 Epoch 40/200 125/125 [==============================] - 1s 5ms/step - loss: 1.2759 - accuracy: 0.6146 - val_loss: 1.1890 - val_accuracy: 0.6695 Epoch 41/200 125/125 [==============================] - 1s 5ms/step - loss: 1.2585 - accuracy: 0.6292 - val_loss: 1.1897 - val_accuracy: 0.6715 Epoch 42/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2839 - accuracy: 0.6103 - 
val_loss: 1.1732 - val_accuracy: 0.6750 Epoch 43/200 125/125 [==============================] - 1s 5ms/step - loss: 1.2562 - accuracy: 0.6383 - val_loss: 1.1971 - val_accuracy: 0.6660 Epoch 44/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2737 - accuracy: 0.6344 - val_loss: 1.1898 - val_accuracy: 0.6745 Epoch 45/200 125/125 [==============================] - 1s 5ms/step - loss: 1.2640 - accuracy: 0.6417 - val_loss: 1.1979 - val_accuracy: 0.6775 Epoch 46/200 125/125 [==============================] - 1s 5ms/step - loss: 1.2705 - accuracy: 0.6297 - val_loss: 1.1859 - val_accuracy: 0.6725 Epoch 47/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2489 - accuracy: 0.6349 - val_loss: 1.1773 - val_accuracy: 0.6800 Epoch 48/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2386 - accuracy: 0.6351 - val_loss: 1.1487 - val_accuracy: 0.6850 Epoch 49/200 125/125 [==============================] - 1s 5ms/step - loss: 1.2479 - accuracy: 0.6307 - val_loss: 1.1873 - val_accuracy: 0.6825 Epoch 50/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2536 - accuracy: 0.6367 - val_loss: 1.1820 - val_accuracy: 0.6860 Epoch 51/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2583 - accuracy: 0.6424 - val_loss: 1.1775 - val_accuracy: 0.6780 Epoch 52/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2462 - accuracy: 0.6441 - val_loss: 1.1548 - val_accuracy: 0.6845 Epoch 53/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2175 - accuracy: 0.6468 - val_loss: 1.1528 - val_accuracy: 0.7020 Epoch 54/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2057 - accuracy: 0.6524 - val_loss: 1.1577 - val_accuracy: 0.6885 Epoch 55/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2202 - accuracy: 0.6442 - val_loss: 1.1334 - val_accuracy: 0.7045 Epoch 56/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2048 - 
accuracy: 0.6504 - val_loss: 1.1544 - val_accuracy: 0.6960 Epoch 57/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2113 - accuracy: 0.6483 - val_loss: 1.1609 - val_accuracy: 0.6915 Epoch 58/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2324 - accuracy: 0.6509 - val_loss: 1.1734 - val_accuracy: 0.6870 Epoch 59/200 125/125 [==============================] - 1s 6ms/step - loss: 1.2416 - accuracy: 0.6346 - val_loss: 1.1662 - val_accuracy: 0.6910 Epoch 00059: early stopping
# Plot the training curves: loss and accuracy for train and validation.
metrics = history.history
curve_names = ['loss', 'accuracy', 'val_loss', 'val_accuracy']
for curve in curve_names:
    plt.plot(history.epoch, metrics[curve])
plt.legend(curve_names)
plt.show()
# Run the model on the test set and take the argmax class per example.
predictions = model.predict(test_ds)
# FIX(idiom): vectorized argmax over the class axis replaces the Python
# loop (whose comprehension variable `p` also shadowed the sizes list).
pred_labels = np.argmax(predictions, axis=1)
pred_labels.shape
(200,)
# Collect the ground-truth labels batch by batch from the test pipeline.
true_labels = np.concatenate(
    [batch_labels for _, batch_labels in test_ds.as_numpy_iterator()])
true_labels.shape
(200,)
mat = tf.math.confusion_matrix(true_labels, pred_labels)
fig = plt.figure(figsize=(10, 10))
sns.heatmap(mat, square=True, annot=True, cbar=True, fmt='d',
            xticklabels=label_names,
            yticklabels=label_names)
# FIX: tf.math.confusion_matrix puts TRUE labels on rows (y axis) and
# PREDICTED labels on columns (x axis); the original labels were swapped.
plt.xlabel('Predicted')
plt.ylabel('True');
# FIX: report accuracy on the held-out test set; the original printed the
# final training-epoch accuracy while calling it the "actual" accuracy.
test_accuracy = np.mean(pred_labels == true_labels)
print(f"Actual accuracy is: {test_accuracy * 100}")
Actual accuracy is: 64.31249976158142
# Pick one random example from the test region of the pipeline, predict
# its instrument family, and report whether the prediction was correct.
example = mfcc_train_ds.skip(p[0] + p[1]).skip(np.random.randint(p[2]-1)).take(1)
for spec, label in example:
    beast = label_names[int(label.numpy())]  # ground-truth family name

# NOTE(review): `example` is an unbatched dataset; Keras accepted it in
# the original run, but batching it explicitly would be safer — confirm.
predictions_single = model.predict(example)
cloud = label_names[np.argmax(predictions_single[0])]  # predicted family

if beast != cloud:
    just_silly = "Wrong!"
    somehow = "However,"
else:
    just_silly = "Correct!"
    somehow = "Moreover,"

print(f"""
I've seen a {cloud}.
{somehow} it is a {beast}
Therefore,
I am
{just_silly}
"""
)

plt.bar(label_names, tf.nn.softmax(predictions_single[0]))
# FIX: the original referenced an undefined name `predicted_label`
# (NameError); use the predicted family name `cloud` instead.
plt.title(f'Predictions for "{cloud}"')
I've seen a bass. Moreover, it is a bass Therefore, I am Correct!
Text(0.5, 1.0, 'Predictions for "bass"')