- Uses the CNN/DailyMail dataset for the text-summarization task (a data-loading sketch follows this list)
- Modifies BERT's attention-mask mechanism from bidirectional to unidirectional attention (GPT-style)
- Changes the pretraining task from MLM to CLM (causal language modeling)
- Adjusts the positional encoding and the input processing
- Inherits from `BertPreTrainedModel`, keeping BERT's base architecture
- Sets `is_decoder=True` to enable the causal attention mechanism
- Adds an `lm_head` for the language-modeling task
- Implements a `get_causal_attention_mask` method that generates an upper-triangular mask (see the decoder sketch after this list)
- Ensures each token can only attend to the tokens before it, which gives the model its autoregressive behavior
- Switches from BERT's MLM to GPT-style CLM (causal language modeling)
- Implements the standard language-model loss computation (see the loss sketch below)
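For reference, a minimal sketch of loading CNN/DailyMail; using the Hugging Face `datasets` library is an assumption here, and the repository may load the data differently:

```python
from datasets import load_dataset

# CNN/DailyMail summarization data; "3.0.0" is the standard config name on the Hub.
dataset = load_dataset("cnn_dailymail", "3.0.0")
example = dataset["train"][0]
print(example["article"][:200])   # source article
print(example["highlights"])      # reference summary
```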
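The decoder conversion described above could look roughly like the following with the `transformers` API; the class name `BertCausalDecoder` and the exact details are illustrative, not the repository's actual code:

```python
import torch
from torch import nn
from transformers import BertModel, BertPreTrainedModel

class BertCausalDecoder(BertPreTrainedModel):
    """Illustrative sketch: BERT backbone with causal attention and an LM head."""

    def __init__(self, config):
        config.is_decoder = True                     # BertModel then masks future positions
        super().__init__(config)
        self.bert = BertModel(config, add_pooling_layer=False)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()                             # standard transformers weight init

    @staticmethod
    def get_causal_attention_mask(seq_len, device=None):
        # Upper-triangular additive mask: positions above the diagonal get -inf,
        # so each token can only attend to itself and earlier tokens.
        mask = torch.triu(torch.ones(seq_len, seq_len, device=device), diagonal=1)
        return mask.masked_fill(mask.bool(), float("-inf"))

    def forward(self, input_ids, attention_mask=None):
        # With is_decoder=True, BertModel builds the causal mask internally;
        # get_causal_attention_mask above just shows what that mask looks like.
        hidden = self.bert(input_ids, attention_mask=attention_mask).last_hidden_state
        return self.lm_head(hidden)                  # (batch, seq_len, vocab_size) logits
```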
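The CLM loss is the usual next-token cross entropy with logits and labels shifted by one position; a sketch, where the function name and padding handling are assumptions:

```python
import torch.nn.functional as F

def causal_lm_loss(logits, input_ids, pad_token_id=0):
    # Predict token t+1 from tokens up to t: drop the last logit and the first label.
    shift_logits = logits[:, :-1, :].contiguous()
    shift_labels = input_ids[:, 1:].contiguous()
    return F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_labels.view(-1),
        ignore_index=pad_token_id,   # do not score padding positions
    )
```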
- `bert_decoder_3.py`: uses the transformers library's BERT to implement a decoder
- `bert_decoder.py`: uses PyTorch to first build a BERT from scratch, then converts it into a decoder
- The first version, `bert_decoder_3.py`, is finished.
- Evaluation metric: ROUGE (see the sketch below)
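The repository does not state which ROUGE implementation it uses; the `rouge_score` package is one common choice, e.g.:

```python
from rouge_score import rouge_scorer

scorer = rouge_scorer.RougeScorer(["rouge1", "rouge2", "rougeL"], use_stemmer=True)
scores = scorer.score(
    "the cat sat on the mat",           # reference summary
    "a cat was sitting on the mat",     # generated summary
)
print(scores["rougeL"].fmeasure)
```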
- Before running, manually copy the `vocab.txt` file from the bert-base-uncased directory into this project directory.
- pretrain: `cd pretrain && python bert_decoder_pretrain.py`
- sft: `cd sft && python bert_decoder_sft.py`
- If you want to create your own corpus, copy any text from the internet into a .txt file, then run `clean_corpus.py` to clean it (a rough cleaning sketch follows).
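The contents of `clean_corpus.py` are not shown here; as a rough idea only, a typical cleaning pass normalizes whitespace and drops fragments that are too short:

```python
import re
import sys

def clean_corpus(in_path, out_path, min_chars=20):
    """Illustrative cleaning only; the real clean_corpus.py may differ."""
    with open(in_path, encoding="utf-8") as fin, \
         open(out_path, "w", encoding="utf-8") as fout:
        for line in fin:
            line = re.sub(r"\s+", " ", line).strip()   # collapse whitespace
            if len(line) >= min_chars:                 # skip empty / very short lines
                fout.write(line + "\n")

if __name__ == "__main__":
    clean_corpus(sys.argv[1], sys.argv[2])
```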