|
# Fine-tuning BART on CNN-Dailymail summarization task |
|
|
|
### 1) Follow the instructions [here](https://github.com/abisee/cnn-dailymail) to download the CNN-DailyMail data and process it into data files with non-tokenized, cased samples.
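
The remaining steps assume the processed data lands in a `cnn_dm/` directory as plain-text files named `train.source`, `train.target`, `val.source`, `val.target`, and `test.source` (these names are inferred from the commands below; rename your outputs if your processing script produces different ones), with one article per line in `*.source` and the matching reference summary on the same line of `*.target`. A quick way to eyeball the format:

```bash
# Each line of *.source should be a full untokenized, cased article and the
# same line of *.target its reference summary.
head -n 1 cnn_dm/train.source
head -n 1 cnn_dm/train.target
```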
|
|
|
### 2) BPE preprocess: |
|
```bash
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/encoder.json'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/vocab.bpe'
wget -N 'https://dl.fbaipublicfiles.com/fairseq/gpt2_bpe/dict.txt'

for SPLIT in train val
do
  for LANG in source target
  do
    python -m examples.roberta.multiprocessing_bpe_encoder \
      --encoder-json encoder.json \
      --vocab-bpe vocab.bpe \
      --inputs "cnn_dm/$SPLIT.$LANG" \
      --outputs "cnn_dm/$SPLIT.bpe.$LANG" \
      --workers 60 \
      --keep-empty;
  done
done
``` |
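
A sanity check worth running before binarization (a suggestion, not part of the original recipe): since each line is encoded independently, the BPE output should stay line-aligned with the raw files.

```bash
# For every split/side, the .bpe file should have exactly as many lines as the
# raw file it was produced from.
for SPLIT in train val
do
  for LANG in source target
  do
    wc -l "cnn_dm/$SPLIT.$LANG" "cnn_dm/$SPLIT.bpe.$LANG"
  done
done
```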
|
|
|
### 3) Binarize dataset: |
|
```bash
fairseq-preprocess \
  --source-lang "source" \
  --target-lang "target" \
  --trainpref "cnn_dm/train.bpe" \
  --validpref "cnn_dm/val.bpe" \
  --destdir "cnn_dm-bin/" \
  --workers 60 \
  --srcdict dict.txt \
  --tgtdict dict.txt;
``` |
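
If binarization succeeds, `cnn_dm-bin/` should contain indexed datasets for both splits plus a copy of the GPT-2 dictionary for each side. The exact file names sketched below reflect how fairseq typically lays out the output and may differ slightly across releases:

```bash
# Roughly what to expect in the destination directory:
#   dict.source.txt  dict.target.txt  preprocess.log
#   train.source-target.source.{bin,idx}  train.source-target.target.{bin,idx}
#   valid.source-target.source.{bin,idx}  valid.source-target.target.{bin,idx}
ls cnn_dm-bin/
```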
|
|
|
### 4) Fine-tuning on CNN-DM summarization task: |
|
Example fine-tuning command:
|
```bash
TOTAL_NUM_UPDATES=20000
WARMUP_UPDATES=500
LR=3e-05
MAX_TOKENS=2048
UPDATE_FREQ=4
BART_PATH=/path/to/bart/model.pt

CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 python train.py cnn_dm-bin \
    --restore-file $BART_PATH \
    --max-tokens $MAX_TOKENS \
    --task translation \
    --source-lang source --target-lang target \
    --truncate-source \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --arch bart_large \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
    --clip-norm 0.1 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --update-freq $UPDATE_FREQ \
    --skip-invalid-size-inputs-valid-test \
    --find-unused-parameters;
``` |
|
The above is expected to run on `1` node with `8` 32GB V100 GPUs.
|
Expected training time is about `5 hours`. Training time can be reduced with distributed training on `4` nodes and `--update-freq 1`; a sketch follows.
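
A minimal multi-node sketch, assuming 8 GPUs per node (32 GPUs total), that your fairseq version supports launching via `torch.distributed.launch`, and that the environment variables from step 4 are set on every node. It reuses the argument list from step 4 with `--update-freq 1`, so the effective batch size stays roughly the same; `NODE_RANK` and `MASTER_ADDR` are placeholders you must set per node:

```bash
# Run once on each of the 4 nodes, with NODE_RANK set to 0, 1, 2, 3,
# MASTER_ADDR pointing at the first node, and any free port as master_port;
# fairseq reads the rendezvous settings that torch.distributed.launch puts
# in the environment.
python -m torch.distributed.launch --nproc_per_node=8 \
    --nnodes=4 --node_rank=$NODE_RANK \
    --master_addr=$MASTER_ADDR --master_port=12345 \
    train.py cnn_dm-bin \
    --restore-file $BART_PATH \
    --max-tokens $MAX_TOKENS \
    --task translation \
    --source-lang source --target-lang target \
    --truncate-source \
    --layernorm-embedding \
    --share-all-embeddings \
    --share-decoder-input-output-embed \
    --reset-optimizer --reset-dataloader --reset-meters \
    --required-batch-size-multiple 1 \
    --arch bart_large \
    --criterion label_smoothed_cross_entropy \
    --label-smoothing 0.1 \
    --dropout 0.1 --attention-dropout 0.1 \
    --weight-decay 0.01 --optimizer adam --adam-betas "(0.9, 0.999)" --adam-eps 1e-08 \
    --clip-norm 0.1 \
    --lr-scheduler polynomial_decay --lr $LR --total-num-update $TOTAL_NUM_UPDATES --warmup-updates $WARMUP_UPDATES \
    --fp16 --update-freq 1 \
    --skip-invalid-size-inputs-valid-test \
    --find-unused-parameters;
```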
|
|
|
### 5) Inference on CNN-DM test data using the checkpoint trained above:
|
After training the model as described in the previous step, you can perform inference with the checkpoints in the `checkpoints/` directory using the following Python snippet:
|
|
|
```python
import torch

from fairseq.models.bart import BARTModel

bart = BARTModel.from_pretrained(
    'checkpoints/',
    checkpoint_file='checkpoint_best.pt',
    data_name_or_path='cnn_dm-bin'
)

bart.cuda()
bart.eval()
bart.half()
count = 1
bsz = 32
with open('cnn_dm/test.source') as source, open('cnn_dm/test.hypo', 'w') as fout:
    sline = source.readline().strip()
    slines = [sline]
    for sline in source:
        # Generate whenever a full batch of source lines has been collected.
        if count % bsz == 0:
            with torch.no_grad():
                hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)

            for hypothesis in hypotheses_batch:
                fout.write(hypothesis + '\n')
            fout.flush()
            slines = []

        slines.append(sline.strip())
        count += 1

    # Generate for the final, possibly partial batch.
    if slines != []:
        with torch.no_grad():
            hypotheses_batch = bart.sample(slines, beam=4, lenpen=2.0, max_len_b=140, min_len=55, no_repeat_ngram_size=3)
        for hypothesis in hypotheses_batch:
            fout.write(hypothesis + '\n')
        fout.flush()
``` |
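
As a quick completeness check (not part of the original recipe), the hypothesis file should end up with one summary per test article:

```bash
# Both counts should match (11,490 lines for the standard CNN-DM test split).
wc -l cnn_dm/test.source cnn_dm/test.hypo
```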
|
|