S36 - Model Source and Checkpoint
Model Checkpoints
The model checkpoints are available at the following links:
Using the Pre-trained Models
If you only want to use the pre-trained models, you can refer to the following repository: https://github.com/behzadhaki/InfillingTransformerRealTime.
Step 1. Loading the Model
The detailed scripts can be found in the py_utils/utils.py file. The following code snippet shows how to load the model:
# Import model
import torch
from BaseGrooveTransformers.models.transformer import GrooveTransformerEncoder
import os

# MAGENTA MAPPING: reduced Roland drum mapping (9 voices)
ROLAND_REDUCED_MAPPING = {
    "KICK": [36],
    "SNARE": [38, 37, 40],
    "HH_CLOSED": [42, 22, 44],
    "HH_OPEN": [46, 26],
    "TOM_3_LO": [43, 58],
    "TOM_2_MID": [47, 45],
    "TOM_1_HI": [50, 48],
    "CRASH": [49, 52, 55, 57],
    "RIDE": [51, 53, 59]
}


def initialize_model(params):
    # Build the transformer encoder from the hyperparameters stored in the checkpoint
    model_params = params
    groove_transformer = GrooveTransformerEncoder(model_params['d_model'], model_params['embedding_size_src'],
                                                  model_params['embedding_size_tgt'], model_params['n_heads'],
                                                  model_params['dim_feedforward'], model_params['dropout'],
                                                  model_params['num_encoder_layers'],
                                                  model_params['max_len'], model_params['device'])
    return groove_transformer


def load_model(model_path, model_name):
    # Load the hyperparameters stored alongside the checkpoint
    params = torch.load(os.path.join(model_path, model_name + ".params"), map_location='cpu')['model']
    print(params)
    model = initialize_model(params)
    # Load the trained weights and switch to evaluation mode
    model.load_state_dict(torch.load(os.path.join(model_path, model_name + ".pt"), map_location='cpu'))
    model.eval()
    return model
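With these helpers in place, a checkpoint can be loaded in a couple of lines. The directory and checkpoint name below are placeholders, not actual file names from the release; substitute the ones you downloaded:

# Hypothetical example: "./model" and "infilling_model" are placeholder names;
# replace them with the folder and checkpoint name you downloaded
model_path = "./model"
model_name = "infilling_model"

trained_model = load_model(model_path, model_name)
print(trained_model)  # inspect the loaded architecture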
Step 2. Extract MSO from audio
Detailed scripts for extracting the MSO (multiband synthesized onsets) features from audio can be found here. A snippet of the code is shown below:
from py_utils import mso_utils
from py_utils.mso_utils import *
import os
import librosa
import librosa.display
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.axes_grid1 import make_axes_locatable

# Default MSO analysis settings
mso_settings = {
    "sr": 44100,
    "n_fft": 1024,
    "win_length": 1024,
    "hop_length": 512,
    "n_bins_per_octave": 16,
    "n_octaves": 9,
    "f_min": 40,
    "mean_filter_size": 22,
    "c_freq": [55, 90, 138, 175, 350, 6000, 8500, 12500]
}


def mso(y, grid_lines, **kwargs):
    """
    Multiband synthesized onsets.
    """
    sr = kwargs.get('sr', 44100)
    n_fft = kwargs.get('n_fft', 1024)
    win_length = kwargs.get('win_length', 1024)
    hop_length = kwargs.get('hop_length', 512)
    n_bins_per_octave = kwargs.get('n_bins_per_octave', 16)
    n_octaves = kwargs.get('n_octaves', 9)
    f_min = kwargs.get('f_min', 40)
    mean_filter_size = kwargs.get('mean_filter_size', 22)
    c_freq = kwargs.get('c_freq', [55, 90, 138, 175, 350, 6000, 8500, 12500])

    # If the audio starts right at grid line 0, but the grid lines are relative to a -0.5 micro-timing
    # offset of the first grid line, set this flag to True
    reorder_to_start_before_gridline_0 = kwargs.get('reorder_to_start_before_gridline_0', True)

    if reorder_to_start_before_gridline_0 is True:
        half_grid_res_in_samples = int((grid_lines[1] - grid_lines[0]) * sr / 2.0)
        y = np.roll(y, half_grid_res_in_samples)  # grab the last 32nd-note segment and put it at the beginning

    # onset strength spectrogram
    spec, f_cq = get_onset_strength_spec(y, n_fft=n_fft, win_length=win_length,
                                         hop_length=hop_length, n_bins_per_octave=n_bins_per_octave,
                                         n_octaves=n_octaves, f_min=f_min, sr=sr,
                                         mean_filter_size=mean_filter_size)

    # multiband onset detection and strength
    mb_onset_strength = reduce_f_bands_in_spec(c_freq, f_cq, spec)
    mb_onset_detect = detect_onset(mb_onset_strength)

    # map to grid
    strength_grid, onsets_grid = map_onsets_to_grid(grid_lines, mb_onset_strength, mb_onset_detect, n_fft=n_fft,
                                                    hop_length=hop_length, sr=sr)

    # concatenate into a single array
    mso = np.concatenate((strength_grid, onsets_grid), axis=1)

    return mso, spec, f_cq, strength_grid, onsets_grid, c_freq


if __name__ == "__main__":
    filepath = "./tmp/internal_buffer_30_04_2022 14.19.57.wav"
    bpm = 120
    sr = 44100

    data, _ = mso_utils.read_input_wav(filepath, delete_after_read=False, sr=sr)
    grid_lines = mso_utils.create_grid_lines(bpm, sr, 33, start_before_0=True)
    data_mso, spec, f_cq, strength_grid, onsets_grid, c_freq = mso(data, grid_lines, sr=sr,
                                                                   reorder_to_start_before_gridline_0=True)
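The MSO matrix returned above is a NumPy array; before it can be passed to the transformer in the next step it has to become a batched float tensor. A minimal sketch, assuming the model expects a single-example batch of shape 1 x timesteps x features:

import torch

# Hypothetical glue code: add a batch dimension so the MSO features can be fed to the model
input_tensor = torch.tensor(data_mso, dtype=torch.float32).unsqueeze(0)
print(input_tensor.shape)  # (1, n_timesteps, n_features)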
Step 3. Infilling
The prediction function below runs the trained model on the input tensor and applies a per-voice threshold and a per-voice maximum number of allowed hits to the predicted hit probabilities:
def get_prediction(trained_model, input_tensor, voice_thresholds, voice_max_count_allowed, sampling_mode=0):
    trained_model.eval()
    with torch.no_grad():
        # Each output has shape N x 32 x (embedding_size_src / 3)
        _h, v, o = trained_model.forward(input_tensor)

        _h = torch.sigmoid(_h)
        h = torch.zeros_like(_h)

        # For each voice, keep at most `max_count` candidate hits, then binarize with that voice's threshold
        for ix, (thres, max_count) in enumerate(zip(voice_thresholds, voice_max_count_allowed)):
            max_indices = torch.topk(_h[:, :, ix], max_count).indices[0]
            h[:, max_indices, ix] = _h[:, max_indices, ix]
            h[:, :, ix] = torch.where(h[:, :, ix] > thres, 1, 0)

        return h, v, o
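As a usage sketch, the per-voice thresholds and maximum hit counts below are illustrative placeholders (one entry per voice in ROLAND_REDUCED_MAPPING), not values taken from the original repository:

# Illustrative settings: one threshold and one maximum hit count per voice (9 voices)
voice_thresholds = [0.5] * 9
voice_max_count_allowed = [32] * 9

h, v, o = get_prediction(trained_model, input_tensor, voice_thresholds, voice_max_count_allowed)
# h: binarized hits, v: velocities, o: micro-timing offsets (per step, per voice)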
Comprehensive Source Code
The original source code for the model is available at: https://github.com/pelinski/TransformerGrooveInfilling.
That repository also provides instructions for training the models, running inference, and evaluating the models.