import os

# install openai-whisper from source at runtime, in case it is not preinstalled in the environment
os.system("pip install git+https://github.com/openai/whisper.git")

import gradio as gr
import whisper
import torch
from transformers import pipeline, AutoModelForCausalLM, AutoTokenizer




# GPT-2 for next-token prediction; return_dict_in_generate=True lets generate()
# return a structured output (sequences, plus per-step scores when output_scores=True)
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")

# GPT-2 text-generation pipeline used to produce candidate continuations
generator = pipeline('text-generation', model='gpt2')

# Whisper speech-to-text model
model = whisper.load_model("tiny")
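# Note: Whisper also ships larger checkpoints ("base", "small", "medium", "large");
# "tiny" is used here to keep transcription latency low for the live microphone interface.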


        
def inference(audio, state=""):

    #time.sleep(2)
    #text = p(audio)["text"]
    #state += text + " "
    # load audio data
    audio = whisper.load_audio(audio)
    # ensure sample is in correct format for inference
    audio = whisper.pad_or_trim(audio)

    # generate a log-mel spetrogram of the audio data
    mel = whisper.log_mel_spectrogram(audio).to(model.device)
    
    _, probs = model.detect_language(mel)

    # decode audio data
    options = whisper.DecodingOptions(fp16 = False)
    # transcribe speech to text
    result = whisper.decode(model, mel, options)
    result_len = len(result.text)
    
    # prepend an instruction prompt to the transcript and tokenize the result
    input_prompt = "The following is a transcript of someone talking, please predict what they will say next. \n"
    input_total = input_prompt + result.text
    input_ids = tokenizer(input_total, return_tensors="pt").input_ids
    print("inputs ", input_ids)

    # sample a few short continuations; max_new_tokens caps the generated length
    # (the prompt is already longer than 5 tokens, so max_length would be wrong here)
    output = gpt2.generate(input_ids, max_new_tokens=5, do_sample=True, top_k=50, top_p=0.95, num_return_sequences=5)
    print("output ", output)

    # --- sketch: recovering per-token probabilities from generate() ---
    # With output_scores=True (and return_dict_in_generate=True, set above), generate()
    # also returns the per-step logits, so the probability of each sampled token can be
    # recovered:
    #
    # generated_outputs = gpt2.generate(input_ids, do_sample=True, num_return_sequences=3,
    #                                   output_scores=True, max_new_tokens=4)
    #
    # keep only the ids that were generated, i.e. drop the prompt
    # gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
    #
    # stack the logits generated at each step and transform them into probabilities
    # probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1)  # [num_seqs, gen_len, vocab_size]
    #
    # collect the probability of each generated token (the dummy dim makes gather work)
    # gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
    #
    # 1) probability that exactly these sequences are generated again (normally very small)
    # unique_prob_per_sequence = gen_probs.prod(-1)
    #
    # 2) normalize the probabilities over the returned sequences
    # normed_gen_probs = gen_probs / gen_probs.sum(0)
    #
    # 3) compare the normalized probabilities to each other as in 1)
    # unique_normed_prob_per_sequence = normed_gen_probs.prod(-1)
    # generate several candidate continuations of the raw transcript
    getText = generator(result.text, max_new_tokens=10, num_return_sequences=5)
    state = getText
    print(state)

    # keep only the newly generated words by stripping the transcript prefix
    gt = [g['generated_text'] for g in state]
    gtTrim = [g[len(result.text):] for g in gt]

    return result.text, state, gtTrim



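# Minimal sketch for testing inference() without the UI (assumes a local
# recording named "sample.wav"; the filename is only illustrative):
#
#   transcript, state, continuations = inference("sample.wav")
#   print(transcript)
#   print(continuations)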
# get audio from microphone 

gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath"),
        "state"
    ],
    outputs=[
        "textbox",
        "state",
        "textbox"
    ],
    live=True
).launch()
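# Note: gr.inputs.Audio(source=..., type=...) and the "state"/"textbox" string shorthands
# are the legacy Gradio 3.x API this app assumes; on Gradio 4.x the rough equivalents
# would be gr.Audio(sources=["microphone"], type="filepath"), gr.State(), and gr.Textbox().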