KingNish commited on
Commit
4bdbda7
·
verified ·
1 Parent(s): 2f32d3a

Update inference/infer.py

Browse files
Files changed (1) hide show
  1. inference/infer.py +4 -0
inference/infer.py CHANGED
@@ -139,6 +139,9 @@ repetition_penalty = 1.2
139
  # special tokens
140
  start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
141
  end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
 
 
 
142
  # Format text prompt
143
  run_n_segments = min(args.run_n_segments+1, len(lyrics))
144
 
@@ -197,6 +200,7 @@ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
197
  raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
198
  else:
199
  raw_output = output_seq
 
200
 
201
  # save raw output and check sanity
202
  ids = raw_output[0].cpu().numpy()
 
139
  # special tokens
140
  start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
141
  end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
142
+
143
+ raw_output = None
144
+
145
  # Format text prompt
146
  run_n_segments = min(args.run_n_segments+1, len(lyrics))
147
 
 
200
  raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
201
  else:
202
  raw_output = output_seq
203
+ print(len(raw_output))
204
 
205
  # save raw output and check sanity
206
  ids = raw_output[0].cpu().numpy()