S36 - Model Source and Checkpoint


Table of contents

  1. Model Checkpoints
  2. Using the Pre-trained Models
    1. Step 1. Loading the Model
    2. Step 2. Extracting MSO from Audio
    3. Step 3. Infilling
  3. Comprehensive Source Code

Model Checkpoints

The model checkpoints can be downloaded via the link below:

Download All


Using the Pre-trained Models

If you only want to use the pre-trained models, you can refer to the following repository: https://github.com/behzadhaki/InfillingTransformerRealTime.

Step 1. Loading the Model

The detailed scripts can be found in the py_utils/utils.py file. The following code snippet shows how to load the model:

# Imports
import os

import torch
from BaseGrooveTransformers.models.transformer import GrooveTransformerEncoder


# Magenta's reduced Roland drum mapping: 9 voices and their MIDI note numbers
ROLAND_REDUCED_MAPPING = {
    "KICK": [36],
    "SNARE": [38, 37, 40],
    "HH_CLOSED": [42, 22, 44],
    "HH_OPEN": [46, 26],
    "TOM_3_LO": [43, 58],
    "TOM_2_MID": [47, 45],
    "TOM_1_HI": [50, 48],
    "CRASH": [49, 52, 55, 57],
    "RIDE": [51, 53, 59]
}

def initialize_model(model_params):
    groove_transformer = GrooveTransformerEncoder(
        model_params['d_model'], model_params['embedding_size_src'],
        model_params['embedding_size_tgt'], model_params['n_heads'],
        model_params['dim_feedforward'], model_params['dropout'],
        model_params['num_encoder_layers'],
        model_params['max_len'], model_params['device'])
    return groove_transformer

def load_model(model_path, model_name):
    # load the hyperparameters stored alongside the weights
    params = torch.load(os.path.join(model_path, model_name + ".params"),
                        map_location='cpu')['model']
    model = initialize_model(params)
    # load the trained weights
    model.load_state_dict(torch.load(os.path.join(model_path, model_name + ".pt"),
                                     map_location='cpu'))
    model.eval()
    return model
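
For example, assuming the downloaded checkpoint is stored as a pair of files named <model_name>.params and <model_name>.pt in a local directory (the path and name below are placeholders, not repository defaults):

model_path = "./trained_models"   # placeholder directory holding the checkpoint files
model_name = "InfillingModel"     # placeholder base name, i.e. InfillingModel.params / InfillingModel.pt
model = load_model(model_path, model_name)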

Step 2. Extracting MSO from Audio

The detailed scripts for extracting MSO (multiband synthesized onsets) from audio can be found in py_utils/mso_utils.py. A snippet of the code is shown below:

import numpy as np

from py_utils import mso_utils
from py_utils.mso_utils import (detect_onset, get_onset_strength_spec,
                                map_onsets_to_grid, reduce_f_bands_in_spec)

mso_settings = {
    "sr": 44100,
    "n_fft": 1024,
    "win_length": 1024,
    "hop_length": 512,
    "n_bins_per_octave": 16,
    "n_octaves": 9,
    "f_min": 40,
    "mean_filter_size": 22,
    "c_freq": [55, 90, 138, 175, 350, 6000, 8500, 12500]
}
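
# Note: these settings mirror the keyword defaults of mso() below, so they
# can equivalently be passed as mso(y, grid_lines, **mso_settings).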

def mso(y, grid_lines, **kwargs):
    """
    Compute the multiband synthesized onsets (MSO) representation of an audio
    signal and map it onto the given grid lines.
    """
    sr = kwargs.get('sr', 44100)
    n_fft = kwargs.get('n_fft', 1024)
    win_length = kwargs.get('win_length', 1024)
    hop_length = kwargs.get('hop_length', 512)
    n_bins_per_octave = kwargs.get('n_bins_per_octave', 16)
    n_octaves = kwargs.get('n_octaves', 9)
    f_min = kwargs.get('f_min', 40)
    mean_filter_size = kwargs.get('mean_filter_size', 22)
    c_freq = kwargs.get('c_freq', [55, 90, 138, 175, 350, 6000, 8500, 12500])

    # If the audio starts exactly at grid line 0, but the grid lines are defined
    # relative to a -0.5 microtiming offset of the first grid line, set this to True.
    reorder_to_start_before_gridline_0 = kwargs.get('reorder_to_start_before_gridline_0', True)
    if reorder_to_start_before_gridline_0:
        half_grid_res_in_samples = int((grid_lines[1] - grid_lines[0]) * sr / 2.0)
        # move the trailing half-grid-resolution segment of audio to the beginning
        y = np.roll(y, half_grid_res_in_samples)

    # onset strength spectrogram
    spec, f_cq = get_onset_strength_spec(y, n_fft=n_fft, win_length=win_length,
                                         hop_length=hop_length, n_bins_per_octave=n_bins_per_octave,
                                         n_octaves=n_octaves, f_min=f_min, sr=sr,
                                         mean_filter_size=mean_filter_size)

    # multiband onset detection and strength
    mb_onset_strength = reduce_f_bands_in_spec(c_freq, f_cq, spec)
    mb_onset_detect = detect_onset(mb_onset_strength)

    # map to grid
    strength_grid, onsets_grid = map_onsets_to_grid(grid_lines, mb_onset_strength, mb_onset_detect, n_fft=n_fft,
                                                    hop_length=hop_length, sr=sr)

    # concatenate in one single array
    mso = np.concatenate((strength_grid, onsets_grid), axis=1)

    return mso, spec, f_cq, strength_grid, onsets_grid, c_freq

if __name__ == "__main__":

    filepath = "./tmp/internal_buffer_30_04_2022 14.19.57.wav"

    bpm = 120
    sr = 44100

    data, _ = mso_utils.read_input_wav(filepath, delete_after_read=False, sr=sr)

    # 33 grid lines bound 32 grid steps
    grid_lines = mso_utils.create_grid_lines(bpm, sr, 33, start_before_0=True)

    data_mso, spec, f_cq, strength_grid, onsets_grid, c_freq = mso(
        data, grid_lines, sr=sr, reorder_to_start_before_gridline_0=True)
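
The returned data_mso matrix concatenates strength_grid and onsets_grid along the feature axis, giving 16 features per time step for the 8 centre frequencies above. To feed it to the model loaded in Step 1, it needs to be a float tensor with a leading batch dimension; a minimal sketch (the dtype and batching here are assumptions, and the exact preprocessing in the repository may differ):

import torch

# add a batch dimension: (time_steps, 16) -> (1, time_steps, 16)
input_tensor = torch.from_numpy(data_mso).float().unsqueeze(0)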

Step 3. Infilling

The get_prediction function runs the trained model on an input tensor and post-processes the predicted hits: for each voice, only the voice_max_count_allowed highest-scoring time steps are kept, and those are then binarized against the corresponding entry of voice_thresholds. Velocities and offsets are returned as predicted:

def get_prediction(trained_model, input_tensor, voice_thresholds, voice_max_count_allowed, sampling_mode=0):
    trained_model.eval()
    with torch.no_grad():
        # each output has shape N x 32 x (embedding_size_tgt / 3): hit logits, velocities, offsets
        _h, v, o = trained_model.forward(input_tensor)

        _h = torch.sigmoid(_h)
        h = torch.zeros_like(_h)

        for ix, (thres, max_count) in enumerate(zip(voice_thresholds, voice_max_count_allowed)):
            # keep only the max_count most likely hits for this voice, then binarize
            max_indices = torch.topk(_h[:, :, ix], max_count).indices[0]
            h[:, max_indices, ix] = _h[:, max_indices, ix]
            h[:, :, ix] = torch.where(h[:, :, ix] > thres, 1, 0)

    return h, v, o
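
As an illustrative call, using nine voices to match ROLAND_REDUCED_MAPPING from Step 1 (the threshold and count values below are example settings, not repository defaults):

voice_thresholds = [0.5] * 9          # per-voice probability threshold (example value)
voice_max_count_allowed = [32] * 9    # per-voice cap on hits across the 32 steps (example value)
h, v, o = get_prediction(model, input_tensor, voice_thresholds, voice_max_count_allowed)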


Comprehensive Source Code

The original source code for the model is available at: https://github.com/pelinski/TransformerGrooveInfilling.

The repository includes instructions for training the models, running inference, and evaluating the results.