S2 - Model Source and Checkpoint
Source Code
The source code for the model is available at https://github.com/behzadhaki/GrooveTransformer.
This repository holds all the training data, model source definitions, and trained checkpoints.
Note that multiple models are hosted in this repository, but the one used in this chapter is located in the `model/Base` directory.
A detailed guide on how to use the models is available in the repository's documentation.
Below is a quick guide to loading the pretrained model checkpoints and using them to generate drum loops.
Downloading the Checkpoints
The checkpoints are available in the `model/Base/monotonic_groove_transformer_v1/latest` directory.
Each model also has a `.json` file that contains the model's configuration. This file is NOT required to load the model, but it is useful for understanding the model's architecture and hyperparameters. Note also that the model names differ from those used in the paper; the JSON configurations let you match each checkpoint to its corresponding model. Alternatively, use the following table to download the checkpoints directly.
| Model Name | Checkpoint Name | |
|---|---|---|
| Model 1 | rosie-durian-248 | Download |
| Model 2 | hopeful-gorge-252 | Download |
| Model 3 | solar-shadow-247 | Download |
| Model 4 | misunderstood-bush-246 | Download |
Loading the Checkpoints
Each model contains a `.pth` file that can be loaded using the following code:
```python
from helpers.BasicMonotonicGrooveTransformer.modelLoadersSamplers import load_mgt_model

model_path = "SomePathToTheModel.pth"
MonotonicGrooveTransformer = load_mgt_model(model_path)
```
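If you want to inspect a checkpoint before loading it through the helper, note that a `.pth` file is standard PyTorch serialization. Below is a minimal, self-contained sketch that creates a dummy checkpoint (the real file names and keys are repository-specific) and opens it with `torch.load`:

```python
import os
import tempfile

import torch

# Create a dummy checkpoint to stand in for a real .pth file
# (the actual checkpoints live in the repository; this path is a placeholder).
dummy_state = {"layer.weight": torch.zeros(4, 4), "layer.bias": torch.zeros(4)}
path = os.path.join(tempfile.mkdtemp(), "dummy_model.pth")
torch.save(dummy_state, path)

# Any .pth checkpoint can be opened with torch.load to inspect its contents,
# e.g. to list parameter names and shapes before constructing the model.
checkpoint = torch.load(path, map_location="cpu")
print(sorted(checkpoint.keys()))  # ['layer.bias', 'layer.weight']
```

Inspecting the keys this way is a quick sanity check that a downloaded checkpoint matches the architecture described in its accompanying `.json` configuration.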
Generating Drum Loops
To generate drum loops, you can use the following code:
```python
import torch

from data import load_gmd_hvo_sequences  # optional: load real grooves from the Groove MIDI Dataset
from helpers import predict_using_mgt

# Sampling threshold for each voice
# (kick, snare, closed hat, open hat, low tom, mid tom, high tom, crash, ride)
voice_thresholds = [0.5] * 9
voice_max_count_allowed = [32] * 9  # maximum number of hits allowed per voice

# A random input groove in HVO format: (batch, 32 time steps, 27 features)
input_groove_hvo = torch.rand((1, 32, 27))

# Pass the groove to the model and sample a drum pattern
output_hvo = predict_using_mgt(MonotonicGrooveTransformer, input_groove_hvo,
                               voice_thresholds, voice_max_count_allowed,
                               return_concatenated=True)
output_hvo.shape
```
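With `return_concatenated=True`, the output packs hits, velocities, and micro-timing offsets for the nine drum voices along the last axis (9 + 9 + 9 = 27). Below is a minimal sketch of unpacking such a tensor and applying per-voice thresholds, using a random tensor as a stand-in for the model output; the slicing order (hits, then velocities, then offsets) is the usual HVO convention, but verify it against the repository's documentation:

```python
import torch

torch.manual_seed(0)

# Stand-in for a model output in concatenated HVO format:
# (batch, 32 time steps, 9 hits + 9 velocities + 9 offsets = 27 features)
output_hvo = torch.rand((1, 32, 27))

# Unpack along the last axis (order assumed: hits, velocities, offsets)
hits = output_hvo[..., 0:9]         # hit probabilities per voice
velocities = output_hvo[..., 9:18]  # velocities in [0, 1]
offsets = output_hvo[..., 18:27]    # micro-timing offsets

# Per-voice thresholding, analogous to the voice_thresholds argument above:
# a hit is kept only if its probability exceeds that voice's threshold.
voice_thresholds = torch.tensor([0.5] * 9)
binary_hits = (hits > voice_thresholds).float()  # broadcasts over the voice axis
```

Raising a voice's threshold makes that instrument sparser in the generated loop, which is why the helper exposes one threshold per voice rather than a single global value.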