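"""Gradio demo: transcribe live microphone audio with OpenAI Whisper and use
GPT-2 to predict how the speaker might continue the transcript."""
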
import os
from pprint import pprint

# Install Whisper at startup (common pattern for Hugging Face Spaces)
os.system("pip install git+https://github.com/openai/whisper.git")

import time

import gradio as gr
import torch
import whisper
from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

# import streaming.py
# from next_word_prediction import GPT2



### code snippet
# GPT-2 loaded with return_dict_in_generate=True so generate() also returns per-step scores
gpt2 = AutoModelForCausalLM.from_pretrained("gpt2", return_dict_in_generate=True)
tokenizer = AutoTokenizer.from_pretrained("gpt2")
### /code snippet


# GPT-2 text-generation pipeline
generator = pipeline('text-generation', model='gpt2')

# whisper model specification
model = whisper.load_model("tiny")


        
def inference(audio, state=""):

    # time.sleep(2)
    # text = p(audio)["text"]
    # state += text + " "

    # load audio data
    audio = whisper.load_audio(audio)
    # pad/trim the sample into the format Whisper expects for inference
    audio = whisper.pad_or_trim(audio)

    # generate a log-mel spectrogram of the audio data
    mel = whisper.log_mel_spectrogram(audio).to(model.device)

    # detect the spoken language (probabilities unused downstream)
    _, lang_probs = model.detect_language(mel)

    # decode audio data
    options = whisper.DecodingOptions(fp16=False)
    # transcribe speech to text
    result = whisper.decode(model, mel, options)
    result_len = len(result.text)
    
    # Added prompt below
    input_prompt = "The following is a transcript of someone talking, please predict what they will say next. \n"
    ### code 
    input_total = input_prompt + result.text
    input_ids = tokenizer(input_total, return_tensors="pt").input_ids
    print("inputs ", input_ids)

    # prompt length
    # prompt_length = len(tokenizer.decode(input_ids[0]))

    # length penalty for gpt2.generate???
    # Prompt
    generated_outputs = gpt2.generate(input_ids, do_sample=True, num_return_sequences=3, output_scores=True)
    print("outputs generated ", generated_outputs.sequences)
    # only use ids that were generated (drop the prompt tokens)
    # gen_sequences has shape [num_return_sequences, generated_length]
    gen_sequences = generated_outputs.sequences[:, input_ids.shape[-1]:]
    print("gen sequences: ", gen_sequences)
    
    # stack the logits generated at each step into one tensor and turn the
    # logits into probs; shape [num_return_sequences, generated_length, vocab_size]
    probs = torch.stack(generated_outputs.scores, dim=1).softmax(-1)
    
    # now we need to collect the probability of the generated token
    # we need to add a dummy dim in the end to make gather work
    gen_probs = torch.gather(probs, 2, gen_sequences[:, :, None]).squeeze(-1)
    print("gen probs result: ", gen_probs)
    # now we can do all kinds of things with the probs
    
    # 1) the probs that exactly those sequences are generated again
    # those are normally going to be very small
    # unique_prob_per_sequence = gen_probs.prod(-1)
    
    # 2) normalize the probs over the three sequences
    # normed_gen_probs = gen_probs / gen_probs.sum(0)
    # assert normed_gen_probs[:, 0].sum() == 1.0, "probs should be normalized"
    
    # 3) compare normalized probs to each other like in 1)
    # unique_normed_prob_per_sequence = normed_gen_probs.prod(-1)
    
    ### end code
    # print audio data as text
    # print(result.text)
    # prompt
    # generate several short continuations of the transcript with the pipeline
    getText = generator(result.text, max_new_tokens=10, num_return_sequences=5)
    state = getText
    print(state)
    gt = [out['generated_text'] for out in state]
    print(type(gt))

    # result.text
    # return getText, gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)
    return result.text, state, gt



# get audio from microphone

gr.Interface(
    fn=inference,
    inputs=[
        gr.inputs.Audio(source="microphone", type="filepath"),
        "state"
    ],
    outputs=[
        "textbox",
        "state",
        "textbox"
    ],
    live=True).launch()