Update inference/infer.py
Browse files- inference/infer.py +4 -0
inference/infer.py
CHANGED
@@ -139,6 +139,9 @@ repetition_penalty = 1.2
 # special tokens
 start_of_segment = mmtokenizer.tokenize('[start_of_segment]')
 end_of_segment = mmtokenizer.tokenize('[end_of_segment]')
+
+raw_output = None
+
 # Format text prompt
 run_n_segments = min(args.run_n_segments+1, len(lyrics))

@@ -197,6 +200,7 @@ for i, p in enumerate(tqdm(prompt_texts[:run_n_segments])):
         raw_output = torch.cat([raw_output, prompt_ids, output_seq[:, input_ids.shape[-1]:]], dim=1)
     else:
         raw_output = output_seq
+    print(len(raw_output))

 # save raw output and check sanity
 ids = raw_output[0].cpu().numpy()