mskov commited on
Commit
e729475
·
1 Parent(s): 3f8c47b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -2
app.py CHANGED
@@ -61,11 +61,12 @@ def inference(audio, state=""):
61
 
62
  # length penalty for gpt2.generate???
63
  #Prompt
64
- generated_outputs = gpt2.generate(input_ids, do_sample=True, num_return_sequences=3, output_scores=True, max_length=4)[:-4]
 
65
  # print("outputs generated ", generated_outputs[0])
66
  # only use id's that were generated
67
  # gen_sequences has shape [3, 15]
68
- gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
69
  print("gen sequences: ", gen_sequences)
70
 
71
  # let's stack the logits generated at each step to a tensor and transform
 
61
 
62
  # length penalty for gpt2.generate???
63
  #Prompt
64
+ generated_outputs = gpt2.generate(input_ids, do_sample=True, num_return_sequences=3, output_scores=True, max_length=4)
65
+ outputs = [generated_output[-4:] for generated_output in generated_outputs.tolist()]
66
  # print("outputs generated ", generated_outputs[0])
67
  # only use id's that were generated
68
  # gen_sequences has shape [3, 15]
69
+ gen_sequences = outputs.sequences[:, input_ids.shape[-1]:]
70
  print("gen sequences: ", gen_sequences)
71
 
72
  # let's stack the logits generated at each step to a tensor and transform