fffiloni KingNish committed on
Commit
c654b1b
·
verified ·
1 Parent(s): d8b1d07

Optimized for speed (#7)

Browse files

- A patch of transformers (98915f7a2db0b5b80c135c8ee021c7e73f0488ab)
- Increased stage 2 batch size (c9e6e07806ab611efa8e609c110db1d450a450fa)
- Flash Attention 2 only supports fp16 (e5ae04a10082bfbc1f23931af7e09b9745cc3ae1)


Co-authored-by: Nishith Jain <[email protected]>

Files changed (3) hide show
  1. app.py +1 -1
  2. inference/infer.py +1 -1
  3. requirements.txt +1 -1
app.py CHANGED
@@ -124,7 +124,7 @@ def infer(genre_txt_content, lyrics_txt_content, num_segments, max_new_tokens):
124
  "--genre_txt", f"{genre_txt_path}",
125
  "--lyrics_txt", f"{lyrics_txt_path}",
126
  "--run_n_segments", str(num_segments),
127
- "--stage2_batch_size", "4",
128
  "--output_dir", f"{output_dir}",
129
  "--cuda_idx", "0",
130
  "--max_new_tokens", str(max_new_tokens)
 
124
  "--genre_txt", f"{genre_txt_path}",
125
  "--lyrics_txt", f"{lyrics_txt_path}",
126
  "--run_n_segments", str(num_segments),
127
+ "--stage2_batch_size", "16",
128
  "--output_dir", f"{output_dir}",
129
  "--cuda_idx", "0",
130
  "--max_new_tokens", str(max_new_tokens)
inference/infer.py CHANGED
@@ -76,7 +76,7 @@ print(f"Using device: {device}")
76
  mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
77
  model = AutoModelForCausalLM.from_pretrained(
78
  stage1_model,
79
- torch_dtype=torch.bfloat16,
80
  attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
81
  )
82
  model.to(device)
 
76
  mmtokenizer = _MMSentencePieceTokenizer("./mm_tokenizer_v0.2_hf/tokenizer.model")
77
  model = AutoModelForCausalLM.from_pretrained(
78
  stage1_model,
79
+ torch_dtype=torch.float16,
80
  attn_implementation="flash_attention_2", # To enable flashattn, you have to install flash-attn
81
  )
82
  model.to(device)
requirements.txt CHANGED
@@ -3,7 +3,7 @@ torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118
3
  omegaconf
4
  einops
5
  numpy<2
6
- transformers
7
  sentencepiece
8
  tqdm
9
  tensorboard
 
3
  omegaconf
4
  einops
5
  numpy<2
6
+ git+https://github.com/KingNish24/transformers.git@yue-patch
7
  sentencepiece
8
  tqdm
9
  tensorboard