mojad121 commited on
Commit
375c831
·
verified ·
1 Parent(s): 24f225c

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +8 -13
src/streamlit_app.py CHANGED
@@ -1,9 +1,8 @@
1
  import torch
2
- import torchaudio
3
  import os
4
  import streamlit as st
5
- import sounddevice as sd
6
- import soundfile as sf
7
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
 
@@ -17,14 +16,11 @@ whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-
17
  text_model = AutoModelForSequenceClassification.from_pretrained("GroNLP/hateBERT", token=hf_token)
18
  tokenizer = AutoTokenizer.from_pretrained("GroNLP/hateBERT", token=hf_token)
19
 
20
- def record_audio(duration, filename, samplerate=16000):
21
- recording = sd.rec(int(duration * samplerate), samplerate=samplerate, channels=1, dtype='float32')
22
- sd.wait()
23
- sf.write(filename, recording, samplerate)
24
-
25
  def transcribe(audio_path):
26
- waveform, sample_rate = torchaudio.load(audio_path)
27
- input_features = whisper_processor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt").input_features
 
 
28
  predicted_ids = whisper_model.generate(input_features)
29
  transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
30
  return transcription
@@ -36,8 +32,7 @@ def extract_text_features(text):
36
  return "Hate Speech" if predicted_class >= 1 else "Not Hate Speech"
37
 
38
  def predict(text_input):
39
- audio_path = "mic_input.wav"
40
- record_audio(5, audio_path)
41
  transcribed_text = transcribe(audio_path)
42
  prediction = extract_text_features(text_input or transcribed_text)
43
  if text_input:
@@ -47,6 +42,6 @@ def predict(text_input):
47
 
48
  st.title("Hate Speech Detector")
49
  text_input = st.text_input("Enter text (optional):")
50
- if st.button("Start Recording and Predict"):
51
  result = predict(text_input)
52
  st.success(result)
 
1
  import torch
 
2
  import os
3
  import streamlit as st
4
+ from pydub import AudioSegment
5
+ import numpy as np
6
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
7
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
8
 
 
16
  text_model = AutoModelForSequenceClassification.from_pretrained("GroNLP/hateBERT", token=hf_token)
17
  tokenizer = AutoTokenizer.from_pretrained("GroNLP/hateBERT", token=hf_token)
18
 
 
 
 
 
 
19
  def transcribe(audio_path):
20
+ audio = AudioSegment.from_file(audio_path, format="opus")
21
+ audio = audio.set_channels(1).set_frame_rate(16000)
22
+ samples = np.array(audio.get_array_of_samples()).astype(np.float32) / (2**15)
23
+ input_features = whisper_processor(samples, sampling_rate=16000, return_tensors="pt").input_features
24
  predicted_ids = whisper_model.generate(input_features)
25
  transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
26
  return transcription
 
32
  return "Hate Speech" if predicted_class >= 1 else "Not Hate Speech"
33
 
34
  def predict(text_input):
35
+ audio_path = "input.opus"
 
36
  transcribed_text = transcribe(audio_path)
37
  prediction = extract_text_features(text_input or transcribed_text)
38
  if text_input:
 
42
 
43
  st.title("Hate Speech Detector")
44
  text_input = st.text_input("Enter text (optional):")
45
+ if st.button("Run Prediction"):
46
  result = predict(text_input)
47
  st.success(result)