- The best result is produced by the second model below (resnet152+lstm+hidden_size512+lr_1e3), which has the highest Bleu_4_C5 and CIDEr_C5 scores
resnet152+lstm+hidden_size1024+lr_1e3: Bleu_4_C5 = 0.235378 CIDEr_C5 = 0.748013
resnet152+lstm+hidden_size512+lr_1e3: Bleu_4_C5 = 0.242659 CIDEr_C5 = 0.772517
resnet152+gru+hidden_size512+lr_1e3: Bleu_4_C5 = 0.235254 CIDEr_C5 = 0.750898
resnet152+gru+hidden_size1024+lr_1e3: Bleu_4_C5 = 0.234776 CIDEr_C5 = 0.749187
resnet50+lstm+hidden_size1024+lr_1e3: Bleu_4_C5 = 0.237044 CIDEr_C5 = 0.749605
resnet152+lstm+hidden_size512+lr_1e-2: Bleu_4_C5 = 0.222249 CIDEr_C5 = 0.695748
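The comparison above can be double-checked by parsing the score lines and ranking them. This is only a minimal sketch and is not part of the repository: the `RESULTS` string is copied from the list above, while `parse_results` and `best_by` are hypothetical helper names introduced for illustration.

```python
import re

# Scores copied verbatim from the results list above.
RESULTS = """\
resnet152+lstm+hidden_size1024+lr_1e3: Bleu_4_C5 = 0.235378 CIDEr_C5 = 0.748013
resnet152+lstm+hidden_size512+lr_1e3: Bleu_4_C5 = 0.242659 CIDEr_C5 = 0.772517
resnet152+gru+hidden_size512+lr_1e3: Bleu_4_C5 = 0.235254 CIDEr_C5 = 0.750898
resnet152+gru+hidden_size1024+lr_1e3: Bleu_4_C5 = 0.234776 CIDEr_C5 = 0.749187
resnet50+lstm+hidden_size1024+lr_1e3: Bleu_4_C5 = 0.237044 CIDEr_C5 = 0.749605
resnet152+lstm+hidden_size512+lr_1e-2: Bleu_4_C5 = 0.222249 CIDEr_C5 = 0.695748
"""

# Matches "<name>: Bleu_4_C5 = <float> CIDEr_C5 = <float>".
LINE_RE = re.compile(
    r"^(?P<name>\S+): Bleu_4_C5 = (?P<bleu>[\d.]+) CIDEr_C5 = (?P<cider>[\d.]+)$"
)


def parse_results(text):
    """Turn the score lines into a list of dicts (hypothetical helper)."""
    rows = []
    for line in text.splitlines():
        match = LINE_RE.match(line.strip())
        if match:
            rows.append({
                "name": match["name"],
                "Bleu_4_C5": float(match["bleu"]),
                "CIDEr_C5": float(match["cider"]),
            })
    return rows


def best_by(rows, metric):
    """Return the row with the highest value for the given metric."""
    return max(rows, key=lambda row: row[metric])


rows = parse_results(RESULTS)
for metric in ("Bleu_4_C5", "CIDEr_C5"):
    best = best_by(rows, metric)
    print(f"best {metric}: {best['name']} ({best[metric]:.6f})")
```

Both metrics select the same configuration, which is why a single "best" model can be named without trading one score off against the other.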
Python 3.6+
PyTorch 1.3.1
A few other common Python libraries
These can be installed on the DSMLP server with:
python -m pip install --user torch
python -m pip install --user matplotlib
python -m pip install --user numpy
python -m pip install --user pycocotools
python -m pip install --user nltk
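After installing, it can be useful to confirm that the environment is complete before starting a long training run. The snippet below is a small sketch, not something shipped with this repository: `missing_packages` and `python_ok` are hypothetical helper names, and the list of import names is assumed to match the pip package names used above.

```python
import importlib.util
import sys

# Import names assumed to match the packages installed by the pip commands above.
REQUIRED = ["torch", "matplotlib", "numpy", "pycocotools", "nltk"]


def missing_packages(names):
    """Return the top-level module names that cannot be found (nothing is imported)."""
    return [name for name in names if importlib.util.find_spec(name) is None]


def python_ok(min_version=(3, 6)):
    """Check that the running interpreter is at least the required version."""
    return sys.version_info >= min_version


print("Python 3.6+:", python_ok())
missing = missing_packages(REQUIRED)
print("missing packages:", ", ".join(missing) if missing else "none")
```

Because `find_spec` only locates a module without importing it, the check stays fast even for large packages such as torch.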
Folder of pictures used in the inference demo.
This folder contains a notebook for plotting the validation and training loss.
Some models may not include the trained model files because of GitHub's storage limit.
However, every model includes its results and eval_score, and we provide a Dropbox link for the model files that exceed the limit.
- We vary the CNN (resnet152 and resnet50), the RNN (LSTM and GRU), the hidden size (512 and 1024), and the learning rate (1e-2 and 1e-3)
- Each folder may contain up to three subfolders (models, result_json, and eval_score) for:
- storing trained models
- storing results
- storing scores
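To make the size of this search space concrete, the four varied settings can be enumerated as a grid. This is an illustrative sketch only: the `GRID` dictionary and the `all_configs` helper are hypothetical names, and the repository may not generate its configurations this way.

```python
from itertools import product

# Hyperparameter values named in the text above.
GRID = {
    "cnn": ["resnet152", "resnet50"],
    "rnn": ["lstm", "gru"],
    "hidden_size": [512, 1024],
    "lr": [1e-2, 1e-3],
}


def all_configs(grid):
    """Return every combination of the grid values as a list of dicts."""
    keys = list(grid)
    return [dict(zip(keys, values)) for values in product(*(grid[k] for k in keys))]


configs = all_configs(GRID)
print(len(configs), "possible combinations")
print("first combination:", configs[0])
```

The full grid has 16 combinations; only six of them appear in the reported results.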
Notebook for running the trained model on the whole test set.
Notebook for running the trained model on the whole validation set.
Notebook for checking the dataset and dataloader; it shows some of the pictures and their captions.
Demo for inference; it runs on the 7 example pictures and generates their captions.
Demo for training the 'resnet152+lstm+hidden_size512+lr_1e-3' network.
The dictionary that encodes the words.
Main file for training.
LSTM+CNN model construction.
GRU+CNN model construction.
Some utility functions.
Loads the training and validation sets.
Generates captions by running a forward pass through the network.
Builds the dictionary.
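A word dictionary for captioning typically maps each sufficiently frequent word to an integer index, with a few special tokens reserved. The sketch below illustrates that idea under stated assumptions: the `Vocabulary` class, the `build_vocab` and `encode` functions, the special tokens, the `threshold` parameter, and the whitespace tokenization are all hypothetical and may differ from the repository's actual code.

```python
from collections import Counter


class Vocabulary:
    """Bidirectional word <-> index mapping (illustrative, not the repo's class)."""

    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        # Assign the next free index to a word seen for the first time.
        if word not in self.word2idx:
            self.word2idx[word] = len(self.idx2word)
            self.idx2word.append(word)

    def __call__(self, word):
        # Unknown words fall back to the <unk> index.
        return self.word2idx.get(word, self.word2idx["<unk>"])

    def __len__(self):
        return len(self.idx2word)


def build_vocab(captions, threshold=2):
    """Keep words that occur at least `threshold` times (assumed cutoff)."""
    counter = Counter()
    for caption in captions:
        counter.update(caption.lower().split())  # naive whitespace tokenization
    vocab = Vocabulary()
    for token in ("<pad>", "<start>", "<end>", "<unk>"):
        vocab.add_word(token)
    for word, count in counter.items():
        if count >= threshold:
            vocab.add_word(word)
    return vocab


def encode(vocab, caption):
    """Convert a caption into a list of indices wrapped in <start> and <end>."""
    tokens = caption.lower().split()
    return [vocab("<start>")] + [vocab(t) for t in tokens] + [vocab("<end>")]


captions = ["a dog runs", "a dog sleeps", "a cat sleeps"]
vocab = build_vocab(captions, threshold=2)
print(len(vocab))                        # 4 special tokens + "a", "dog", "sleeps" = 7
print(encode(vocab, "a bird sleeps"))    # "bird" maps to <unk>: [1, 4, 3, 6, 2]
```

Dropping rare words keeps the output layer of the decoder small, at the cost of mapping those words to `<unk>`.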