S2 - Model Source and Checkpoint


Table of contents

  1. Source Code
  2. Downloading the Checkpoints
  3. Loading the Checkpoints
  4. Generating Drum Loops

Source Code


The source code for the model is available at https://github.com/behzadhaki/GrooveTransformer.

This repository holds all the training data, model source definitions, and trained checkpoints.

Note that multiple models are hosted there; the one used in this chapter is located in the model/Base directory.

A very detailed guide on how to use the models is available in the repository’s documentation.

Below is a quick guide to loading the pretrained model checkpoints and using them to generate drum loops.


Downloading the Checkpoints


The checkpoints are available in the model/Base/monotonic_groove_transformer_v1/latest directory.

Each model also has a .json file that contains the model's configuration. This file is NOT required to load the model, but it is useful for understanding the model's architecture and hyperparameters. Moreover, the model names differ from those used in the paper; the json configurations let you match each checkpoint to its corresponding model. Alternatively, use the following table to download the checkpoints directly.

Model Name    Checkpoint Name            Link
Model 1       rosie-durian-248           Download
Model 2       hopeful-gorge-252          Download
Model 3       solar-shadow-247           Download
Model 4       misunderstood-bush-246     Download
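As noted above, the .json configuration file next to each checkpoint describes the model's hyperparameters. As a small illustration (the exact file name and keys are assumptions, not guaranteed by the repository), you could inspect it like this:

```python
import json


def load_model_config(config_path):
    """Load the .json configuration shipped next to a checkpoint
    and return it as a plain dictionary."""
    with open(config_path) as f:
        return json.load(f)


# Hypothetical usage; the file name depends on the checkpoint you downloaded:
# config = load_model_config("rosie-durian-248.json")
# for key, value in config.items():
#     print(f"{key}: {value}")
```

This also makes it easy to compare the downloaded checkpoints against the models described in the paper.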

Loading the Checkpoints


Each checkpoint is a .pth file that can be loaded using the following code:

from helpers.BasicMonotonicGrooveTransformer.modelLoadersSamplers import load_mgt_model

model_path = "SomePathToTheModel.pth"  # path to the downloaded .pth checkpoint
MonotonicGrooveTransformer = load_mgt_model(model_path)
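If you just want to peek inside a .pth file before loading it through the repository's helper, a generic PyTorch sketch like the one below can help. This is an assumption-laden illustration: it presumes the checkpoint is a plain state dict of tensors, which may not match how these particular files are saved.

```python
import torch


def summarize_checkpoint(path):
    """Print the parameter names and shapes stored in a .pth file,
    assuming it holds a plain state dict of tensors."""
    state = torch.load(path, map_location="cpu")
    for name, tensor in state.items():
        print(name, tuple(tensor.shape))
    return state
```

For actually instantiating the model, use load_mgt_model as shown above; this sketch is only for inspection.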

Generating Drum Loops


To generate drum loops, you can use the following code:


import torch

from data import load_gmd_hvo_sequences  # utility for loading GMD grooves (optional)

# Pass a groove to the model and sample a drum pattern
from helpers import predict_using_mgt

# Per-voice sampling thresholds
# (kick, snare, closed hat, open hat, low tom, mid tom, high tom, crash, ride)
voice_thresholds = [0.5] * 9
# Maximum number of hits allowed per voice
voice_max_count_allowed = [32] * 9

# A random input groove in HVO format: (batch, 32 timesteps, 27 features)
input_groove_hvo = torch.rand((1, 32, 27))

output_hvo = predict_using_mgt(MonotonicGrooveTransformer, input_groove_hvo,
                               voice_thresholds, voice_max_count_allowed,
                               return_concatenated=True)

output_hvo.shape  # inspect the shape of the generated sequence
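With return_concatenated=True, the last dimension of the output stacks the hits, velocities, and offsets for the 9 drum voices (9 + 9 + 9 = 27). As a sketch of how to split such a tensor back into its three parts (using a random placeholder in place of real model output):

```python
import torch

n_voices = 9

# Placeholder standing in for the model's concatenated HVO output:
# (batch, 32 timesteps, 3 * 9 = 27 features)
output_hvo = torch.rand((1, 32, 3 * n_voices))

# Slice the last dimension into hits, velocities, and offsets
hits = output_hvo[:, :, :n_voices]
velocities = output_hvo[:, :, n_voices:2 * n_voices]
offsets = output_hvo[:, :, 2 * n_voices:]

print(hits.shape, velocities.shape, offsets.shape)
```

Each slice then has shape (1, 32, 9): one channel per drum voice per timestep.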