Fine tuning issues #65

Open
VRichardJP opened this issue Dec 27, 2024 · 1 comment

@VRichardJP

Hi,

I am trying to fine-tune one of the pretrained HTS-AT models for binary classification on a custom dataset. I have already managed to do the exact same thing with a pretrained BEATs model, but somehow I can't make it work with HTS-AT.

Here is a summary of what I do:

  • I create the HTSAT_Swin_Transformer model with the same config as in your ESC-50 fine-tuning example, the only differences being num_classes=1 and loss_type = "clip_bce" since I do binary classification
  • I load one of the pretrained checkpoints (e.g. HTSAT_AudioSet_Saved_1.ckpt) and update all the model weights except sed_model.tscam_conv.weight and sed_model.tscam_conv.bias (I have verified all weights are correctly loaded)
  • I freeze all the parameters except the tscam_conv ones (4.6K trainable params left)
  • I feed the model batches of raw audio frames (sampled at 32000 Hz and zero-padded to fit the longest audio clip in the batch) and compute the loss against the 0-1 targets with nn.BCELoss (a rough sketch of this setup is shown below)
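
In code, the whole setup looks roughly like this (simplified sketch: the constructor arguments just mirror the ESC-50 example config, the checkpoint key prefix is the sed_model. one mentioned above, and the clipwise_output key is an assumption about the model's output dict):

    import torch
    from model.htsat import HTSAT_Swin_Transformer
    import config  # repo config, edited so num_classes=1 and loss_type="clip_bce"

    model = HTSAT_Swin_Transformer(
        spec_size=256, patch_size=4, patch_stride=(4, 4),
        num_classes=1, embed_dim=96,
        depths=[2, 2, 6, 2], num_heads=[4, 8, 16, 32],
        window_size=8, config=config,
    )

    # Load the AudioSet checkpoint, skipping the tscam_conv head whose
    # shape no longer matches (527 AudioSet classes -> 1 class).
    ckpt = torch.load("HTSAT_AudioSet_Saved_1.ckpt", map_location="cpu")
    state = {
        k.replace("sed_model.", ""): v
        for k, v in ckpt["state_dict"].items()
        if not k.startswith("sed_model.tscam_conv")
    }
    model.load_state_dict(state, strict=False)

    # Freeze everything except the new classification head.
    for name, p in model.named_parameters():
        p.requires_grad = name.startswith("tscam_conv")

    # Training step: raw 32 kHz waveforms, zero-padded to the longest clip
    # in the batch; targets are 0/1 floats.
    criterion = torch.nn.BCELoss()
    # loss = criterion(model(batch)["clipwise_output"].squeeze(-1), targets)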

I follow the exact same process with BEATs, the only differences being the layer names and the input sample rate (16000 Hz). Yet I can't get the HTS-AT model to learn anything. For example, here is the val_loss after a few epochs over a few tries (blue is the BEATs fine-tuning, for reference):

[figure: validation loss curves over a few epochs for several HTS-AT fine-tuning attempts, with the BEATs run in blue for reference]

I have tried different learning rates, pretrained weights and optimizers, but it does not seem to have any effect.

Since my dataset is composed of roughly 10% positives, a dummy model outputting a constant value of 0.10 would have an approximate val_loss of 0.27, which is what all my attempts seem to converge toward. Basically, the model is not learning anything from the input here.
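
(For reference, this constant-prediction baseline can be computed as below; the exact value depends on the positive ratio of the validation split.)

    import math

    def constant_bce_baseline(pos_rate: float, pred: float) -> float:
        # BCE loss of a model that always outputs `pred` on a dataset
        # where a fraction `pos_rate` of the labels are 1.
        return -(pos_rate * math.log(pred) + (1 - pos_rate) * math.log(1 - pred))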

The input data looks "normal". For example, here is what the sound of an ambulance looks like after HTS-AT preprocessing:

[figure: linear spectrogram (top) and log-mel spectrogram (bottom) of the ambulance clip, produced by the debug code below]

    # (debug plotting inserted in HTSAT_Swin_Transformer.forward;
    #  requires `import matplotlib.pyplot as plt` and `import librosa.display`)
    def forward(
        self, x: torch.Tensor, mixup_lambda=None, infer_mode=False
    ):  # out_feat_keys: List[str] = None):
        x = self.spectrogram_extractor(x)  # (batch_size, 1, time_steps, freq_bins)

        # plot the linear spectrogram of the first clip in the batch
        fig, axs = plt.subplots(2)
        img = librosa.display.specshow(
            x[0][0].detach().cpu().numpy().T, x_axis="time", y_axis="log", ax=axs[0]
        )
        fig.colorbar(img, ax=axs[0], format="%+2.f dB")
        axs[0].set(title="spectrogram")

        x = self.logmel_extractor(x)  # (batch_size, 1, time_steps, mel_bins)

        # plot the corresponding log-mel spectrogram
        img = librosa.display.specshow(
            x[0][0].detach().cpu().numpy().T, x_axis="time", y_axis="mel", ax=axs[1]
        )
        fig.colorbar(img, ax=axs[1], format="%+2.f dB")
        axs[1].set(title="logmel")
        plt.show()

        # ...

Am I missing a key detail?

@RetroCirce
Owner

RetroCirce commented Jan 14, 2025

Hi, thank you for this question and for using HTS-AT for fine-tuning.

In my experience, freezing all the parameters except the TSCAM layers in HTS-AT is not a typical choice: the TSCAM layers function like a mapping, so they can only reuse the knowledge HTS-AT learned from AudioSet. If AudioSet differs a lot from your specific dataset, such fine-tuning might not work. Usually we fine-tune the whole model so that it adapts to the more specific dataset or class categories you want in your scenario.

Something similar might also apply to BEATs, but since BEATs might have more layers being retrained/fine-tuned, it might lead to better optimization. It is all about how many parameters you want to train from scratch.

Have you tried fine-tuning the whole model? (i.e., replace the last few TSCAM layers but train the whole model)
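
In code that could look something like this (a minimal sketch, assuming the model and layer names from your setup above; the learning rates are only illustrative):

    import torch

    # Train the whole model, not just the head.
    for p in model.parameters():
        p.requires_grad = True

    head_params = [p for n, p in model.named_parameters() if n.startswith("tscam_conv")]
    backbone_params = [p for n, p in model.named_parameters() if not n.startswith("tscam_conv")]

    # Smaller learning rate for the pretrained backbone, larger for the
    # freshly initialized tscam_conv head.
    optimizer = torch.optim.AdamW([
        {"params": backbone_params, "lr": 1e-5},
        {"params": head_params, "lr": 1e-4},
    ])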
