The baseline model for this project is taken from "Keyword Transformer: A Self-Attention Model for Keyword Spotting".
There are two versions of the Google Speech Commands dataset, V1 and V2. To download and extract V2, run:
wget https://storage.googleapis.com/download.tensorflow.org/data/speech_commands_v0.02.tar.gz
mkdir data2
mv ./speech_commands_v0.02.tar.gz ./data2
cd ./data2
tar -xf ./speech_commands_v0.02.tar.gz
cd ../
Set up a new virtual environment:
pip install virtualenv
virtualenv --system-site-packages -p python3 ./venv3
source ./venv3/bin/activate
To install dependencies, run:
pip install -r requirements.txt
The Keyword-Transformer model is defined here. It takes the mel-scale spectrogram as input, which has shape 98 x 40 using the default settings, corresponding to 98 time windows with 40 frequency coefficients each.
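The 98 in the input shape follows from how many analysis windows fit into one clip. The sketch below works through that arithmetic. The clip length, window length and stride (1 s, 30 ms, 10 ms) are assumptions made for illustration and are not stated in this README; `num_frames` is a hypothetical helper, not a function from the repository.

```python
def num_frames(clip_ms: int, window_ms: int, stride_ms: int) -> int:
    """Number of full analysis windows that fit in a clip (no padding)."""
    if clip_ms < window_ms:
        return 0
    return (clip_ms - window_ms) // stride_ms + 1


# Assumed values, not taken from this README: 1 s clips, 30 ms windows, 10 ms stride.
frames = num_frames(clip_ms=1000, window_ms=30, stride_ms=10)
coefficients = 40  # frequency coefficients per time window, as stated above
print((frames, coefficients))
```

Under these assumed settings the result matches the 98 x 40 shape quoted above; different window or stride settings would change the first dimension.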
There are four variants of the Keyword-Transformer model:
- Time-domain attention: each time window is treated as a patch, and self-attention is computed between time windows
- Frequency-domain attention: each frequency is treated as a patch, and self-attention is computed between frequencies
- Combination of both: The signal is fed into both a time- and a frequency-domain transformer and the outputs are combined
- Patch-wise attention: Similar to the vision transformer, it extracts rectangular patches from the spectrogram, so attention happens both in the time and frequency domain simultaneously.
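The variants differ mainly in how the spectrogram is cut into tokens before self-attention. The following is a minimal pure-Python sketch of that tokenization, not the repository's implementation: `extract_patches` and `tokens_for` are hypothetical names, and the requirement that the patch size tiles the spectrogram exactly is an assumption of this sketch. The 'both' variant is not shown separately, since it would feed the 'time' and 'freq' token sequences into two separate transformers.

```python
from typing import List, Tuple

Spectrogram = List[List[float]]  # indexed as [time][frequency]


def extract_patches(spec: Spectrogram, patch_size: Tuple[int, int]) -> List[List[float]]:
    """Split a [time][freq] spectrogram into flattened, non-overlapping patches."""
    n_time, n_freq = len(spec), len(spec[0])
    p_time, p_freq = patch_size
    # Assumption of this sketch: the patch size must tile the spectrogram exactly.
    if n_time % p_time or n_freq % p_freq:
        raise ValueError("patch size must evenly divide the spectrogram shape")
    patches = []
    for t0 in range(0, n_time, p_time):
        for f0 in range(0, n_freq, p_freq):
            patches.append([spec[t][f]
                            for t in range(t0, t0 + p_time)
                            for f in range(f0, f0 + p_freq)])
    return patches


def tokens_for(spec: Spectrogram, attention_type: str,
               patch_size: Tuple[int, int] = (1, 40)) -> List[List[float]]:
    """Return the token sequence that self-attention would operate on."""
    n_time, n_freq = len(spec), len(spec[0])
    if attention_type == "time":   # one token per time window
        return extract_patches(spec, (1, n_freq))
    if attention_type == "freq":   # one token per frequency coefficient
        return extract_patches(spec, (n_time, 1))
    if attention_type == "patch":  # rectangular patches, as in a vision transformer
        return extract_patches(spec, patch_size)
    raise ValueError("attention_type must be 'time', 'freq' or 'patch'")


spec = [[0.0] * 40 for _ in range(98)]  # dummy 98 x 40 spectrogram
for kind in ("time", "freq"):
    tokens = tokens_for(spec, kind)
    print(kind, len(tokens), len(tokens[0]))
```

With a 98 x 40 input, time-domain attention sees 98 tokens of length 40 and frequency-domain attention sees 40 tokens of length 98. A patch size of 1 x 40 reproduces the time-domain tokenization.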
To train KWT-3 from scratch on Speech Commands V2, run:
sh train.sh
Please note that the train directory (given by the argument --train_dir) must not exist before the script is started.
The model-specific arguments for KWT are:
--num_layers 12 \ #number of sequential transformer encoders
--heads 3 \ #number of attention heads
--d_model 192 \ #embedding dimension
--mlp_dim 768 \ #mlp-dimension
--dropout1 0. \ #dropout in mlp/multi-head attention blocks
--attention_type 'time' \ #attention type: 'time', 'freq', 'both' or 'patch'
--patch_size '1,40' \ #spectrogram patch_size, if patch attention is used
--prenorm False \ # if False, use postnorm
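To get a feel for how these arguments determine model size, the sketch below estimates the number of weights in the encoder stack. It assumes a standard transformer encoder layer (four attention projections with biases, a two-layer MLP, two layer norms) and ignores the input embedding, positional embeddings and classification head; the exact layer layout in this repository may differ. `encoder_parameter_count` is a hypothetical helper.

```python
def encoder_parameter_count(num_layers: int, heads: int, d_model: int, mlp_dim: int) -> int:
    """Rough weight count for a stack of standard transformer encoder layers."""
    if d_model % heads:
        raise ValueError("d_model must be divisible by the number of heads")
    # Query, key, value and output projections, each d_model x d_model plus a bias.
    attention = 4 * (d_model * d_model + d_model)
    # Two dense layers: d_model -> mlp_dim -> d_model, each with a bias.
    mlp = (d_model * mlp_dim + mlp_dim) + (mlp_dim * d_model + d_model)
    # Two layer norms per layer, each with a scale and an offset vector.
    layer_norms = 2 * 2 * d_model
    return num_layers * (attention + mlp + layer_norms)


# Values from the argument list above.
head_dim = 192 // 3
total = encoder_parameter_count(num_layers=12, heads=3, d_model=192, mlp_dim=768)
print(head_dim, total)
```

With the values above, each head works on 64 dimensions and the encoder stack comes to about 5.3 million weights under these assumptions.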
We employ hard distillation from a convolutional model (Att-MH-RNN), similar to the approach in DeiT.
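In hard distillation, the student is trained against the teacher's predicted class rather than its full output distribution. The sketch below shows one way to write such a loss in pure Python. The separate distillation output and the equal 0.5 weighting follow the DeiT formulation and are assumptions here; the loss actually used by `distill.sh` may be set up differently. All function names are hypothetical.

```python
import math
from typing import Sequence


def cross_entropy(logits: Sequence[float], target: int) -> float:
    """Softmax cross-entropy for a single example, computed stably."""
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


def hard_distillation_loss(class_logits: Sequence[float],
                           distill_logits: Sequence[float],
                           teacher_logits: Sequence[float],
                           label: int) -> float:
    """Average of the ground-truth loss and the loss against the teacher's hard label."""
    # The teacher's "hard" decision is its most likely class (argmax).
    teacher_label = max(range(len(teacher_logits)), key=lambda i: teacher_logits[i])
    # Assumed weighting: 0.5 / 0.5, as in the DeiT formulation.
    return (0.5 * cross_entropy(class_logits, label)
            + 0.5 * cross_entropy(distill_logits, teacher_label))


loss = hard_distillation_loss(
    class_logits=[2.0, 0.5, -1.0],
    distill_logits=[0.1, 1.5, 0.0],
    teacher_logits=[0.2, 3.0, 0.1],  # teacher predicts class 1
    label=0,
)
print(loss)
```

Note that the teacher's label can disagree with the ground truth, as in this example; the two terms then pull the two student outputs toward different classes.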
To train KWT-3 with hard distillation from a pre-trained model, run:
sh distill.sh
To perform inference on Google Speech Commands v2 with 12 labels, run:
sh eval.sh
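The 12-label task groups the dataset's words into ten target commands plus two extra classes for silence and for all remaining words. The sketch below illustrates that mapping. The label strings `_silence_` and `_unknown_` and the label order are assumptions for illustration, and `to_12_label` is a hypothetical helper rather than a function from this repository.

```python
from typing import Optional

# The ten target commands of the standard 12-label Speech Commands task.
COMMANDS = ["yes", "no", "up", "down", "left", "right", "on", "off", "stop", "go"]
SILENCE = "_silence_"  # assumed label string
UNKNOWN = "_unknown_"  # assumed label string
LABELS = [SILENCE, UNKNOWN] + COMMANDS  # assumed ordering


def to_12_label(word: Optional[str]) -> str:
    """Map a dataset word (or None for a silent clip) to one of the 12 labels."""
    if word is None:
        return SILENCE
    return word if word in COMMANDS else UNKNOWN


for word in ("yes", "bed", None):
    label = to_12_label(word)
    print(word, "->", label, LABELS.index(label))
```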
Specific experiments can be found on the following branches:
- mish
- nonorm_swish
- norm_change
This repository has been forked from https://github.com/ARM-software/keyword-transformer