RamAnanth1 committed
Commit 1f93035 · 1 Parent(s): e0b4905

Switch to HF-based whisper-large-v2 model

Files changed (1):
  app.py: +16 -26
app.py CHANGED
@@ -9,40 +9,30 @@ title="Whisper to Emotion"
 
 ### ————————————————————————————————————————
 
-whisper_model = whisper.load_model("large")
-
 device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 
+whisper_model = pipeline(
+    task="automatic-speech-recognition",
+    model="openai/whisper-large-v2",
+    chunk_length_s=30,
+    device=device,
+)
+
+all_special_ids = whisper_model.tokenizer.all_special_ids
+transcribe_token_id = all_special_ids[-5]
+translate_token_id = all_special_ids[-6]
+
 emotion_classifier = pipeline("text-classification",model='bhadresh-savani/distilbert-base-uncased-emotion')
 
 def translate_and_classify(audio):
+    task = "Transcribe in Spoken Language"
+    whisper_model.model.config.forced_decoder_ids = [[2, transcribe_token_id if task=="Transcribe in Spoken Language" else translate_token_id]]
+    text = whisper_model(audio)["text"]
 
-    print("""
-    —
-    Sending audio to Whisper ...
-    —
-    """)
-    audio = whisper.load_audio(audio)
-    audio = whisper.pad_or_trim(audio)
-
-    mel = whisper.log_mel_spectrogram(audio).to(whisper_model.device)
-
-    _, probs = whisper_model.detect_language(mel)
-
-    transcript_options = whisper.DecodingOptions(task="transcribe", fp16 = False)
-    translate_options = whisper.DecodingOptions(task="translate", fp16 = False)
-
-    transcription = whisper.decode(whisper_model, mel, transcript_options)
-    translation = whisper.decode(whisper_model, mel, translate_options)
-
-    print("Language Spoken: " + transcription.language)
-    print("Transcript: " + transcription.text)
-    print("Translated: " + translation.text)
-
-    emotion = emotion_classifier(translation.text)
+    emotion = emotion_classifier(text)
     detected_emotion = emotion[0]["label"]
     print("Detected Emotion: ", detected_emotion)
-    return transcription.text, detected_emotion
+    return text, detected_emotion
 
 css = """
 .gradio-container {
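
For reference, a minimal standalone sketch of the new pipeline-based path. The model id, chunk_length_s, and the [-5]/[-6] special-token offsets come from the diff above; the audio path "sample.wav" is hypothetical. Decoding the two ids back to token strings is a quick way to confirm the offsets before forcing the task token:

import torch
from transformers import pipeline

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# Same construction as in the commit: long-form audio is handled by
# 30-second chunking inside the pipeline.
whisper_model = pipeline(
    task="automatic-speech-recognition",
    model="openai/whisper-large-v2",
    chunk_length_s=30,
    device=device,
)

# Sanity-check the offsets used in the commit by decoding the ids back
# to token strings. Per the commit's assumption, this should print the
# task tokens ['<|transcribe|>', '<|translate|>'].
ids = whisper_model.tokenizer.all_special_ids
print(whisper_model.tokenizer.convert_ids_to_tokens([ids[-5], ids[-6]]))

# Force the task token at decoder position 2 (position 1 holds the
# language token, which is left free so Whisper can detect the spoken
# language on its own).
whisper_model.model.config.forced_decoder_ids = [[2, ids[-5]]]

# "sample.wav" is a hypothetical local file; any input the ASR pipeline
# accepts (path, URL, or raw array) works here.
print(whisper_model("sample.wav")["text"])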