mskov committed 5b7e87f · 1 Parent(s): e729475

Update app.py
Files changed (1): app.py (+13 -10)
app.py CHANGED
@@ -13,7 +13,7 @@ import time
 
 
 
-### code snippet
+
 gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
 tokenizer = AutoTokenizer.from_pretrained("gpt2")
 
@@ -61,22 +61,25 @@ def inference(audio, state=""):
 
     # length penalty for gpt2.generate???
    #Prompt
-    generated_outputs = gpt2.generate(input_ids, do_sample=True, num_return_sequences=3, output_scores=True, max_length=4)
-    outputs = [generated_output[-4:] for generated_output in generated_outputs.tolist()]
+    #generated_outputs = gpt2.generate(input_ids, do_sample=True, num_return_sequences=3, output_scores=True, max_length=4)
+    output = model.generate(input_ids, max_length=5, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=5)
+    print("output ", output)
+    #outputs = [output[-4:] for output in output.tolist()]
     # print("outputs generated ", generated_outputs[0])
     # only use id's that were generated
     # gen_sequences has shape [3, 15]
-    gen_sequences = outputs.sequences[:, input_ids.shape[-1]:]
-    print("gen sequences: ", gen_sequences)
+
+    #gen_sequences = outputs.sequences[:, input_ids.shape[-1]:]
+    #print("gen sequences: ", gen_sequences)
 
     # let's stack the logits generated at each step to a tensor and transform
     # logits to probs
-    probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1) # -> shape [3, 15, vocab_size]
+    #probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1) # -> shape [3, 15, vocab_size]
 
     # now we need to collect the probability of the generated token
     # we need to add a dummy dim in the end to make gather work
-    gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
-    print("gen probs result: ", gen_probs)
+    #gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
+    #print("gen probs result: ", gen_probs)
     # now we can do all kinds of things with the probs
 
     # 1) the probs that exactly those sequences are generated again
@@ -99,11 +102,11 @@ def inference(audio, state=""):
     print(state)
     gt = [gt['generated_text'] for gt in state]
     print(type(gt))
-
+    gtTrim = [gt.lstrip(result) for val in gt]
 
     # result.text
     #return getText, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
-    return result.text, state, gt
+    return result.text, state, gtTrim
 
 
 
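
For reference, the per-token probability bookkeeping that this commit comments out (and that the new top_k/top_p sampling call replaces) follows the usual transformers recipe: generate with return_dict_in_generate=True and output_scores=True, stack the per-step scores, softmax them, and gather the probabilities of the sampled ids. Below is a minimal, self-contained sketch of that recipe; the prompt text and variable names are illustrative and not taken from app.py (note that the diff itself calls model.generate even though the loaded object is named gpt2).

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# illustrative prompt; app.py builds input_ids from the transcribed audio instead
input_ids = tokenizer("The weather is", return_tensors="pt").input_ids

outputs = model.generate(
    input_ids,
    do_sample=True,
    top_k=50,
    top_p=0.95,
    num_return_sequences=3,
    max_length=input_ids.shape[-1] + 5,
    output_scores=True,
    return_dict_in_generate=True,
)

# keep only the newly generated ids: shape [num_return_sequences, new_tokens]
gen_sequences = outputs.sequences[:, input_ids.shape[-1]:]

# stack the per-step scores and turn them into probabilities:
# shape [num_return_sequences, new_tokens, vocab_size]
probs = torch.stack(outputs.scores, dim=1).softmax(-1)

# pick out the probability of each token that was actually sampled
gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
print(gen_probs)

Multiplying along the last dimension (gen_probs.prod(-1)) gives the probability of sampling each sequence again exactly, which appears to be what the "# 1) the probs that exactly those sequences are generated again" comment refers to.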
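
The new gtTrim line looks intended to strip the prompt from each generated_text, but str.lstrip(result) removes a set of characters rather than a prefix, and the comprehension loops over val while operating on gt. A small sketch of prefix removal under that reading; trim_prompt, texts, and prompt are hypothetical names standing in for the generated_text values and result.text.

def trim_prompt(texts, prompt):
    # remove the prompt as a prefix (lstrip would strip a character set instead)
    trimmed = []
    for t in texts:
        if t.startswith(prompt):
            trimmed.append(t[len(prompt):].lstrip())
        else:
            trimmed.append(t)
    return trimmed

print(trim_prompt(["hello world again"], "hello world"))  # ['again']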