mojad121 commited on
Commit
16d93a1
·
verified ·
1 Parent(s): cf31213

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +25 -30
src/streamlit_app.py CHANGED
@@ -2,6 +2,8 @@ import torch
2
  import torchaudio
3
  import os
4
  import streamlit as st
 
 
5
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
6
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
7
 
@@ -10,12 +12,15 @@ os.environ["HF_HOME"] = "/app/.cache/huggingface"
10
  os.environ["TORCH_HOME"] = "/app/.cache/torch"
11
  hf_token = os.getenv("HateSpeechMujtabatoken")
12
 
13
- whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-small")
14
- whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
15
- text_model = AutoModelForSequenceClassification.from_pretrained("Hate-speech-CNERG/bert-base-uncased-hatexplain")
16
- tokenizer = AutoTokenizer.from_pretrained("Hate-speech-CNERG/bert-base-uncased-hatexplain")
17
 
18
- label_map = {0: "Not Hate Speech", 1: "Hate Speech", 2: "Hate Speech"}
 
 
 
19
 
20
  def transcribe(audio_path):
21
  waveform, sample_rate = torchaudio.load(audio_path)
@@ -25,33 +30,23 @@ def transcribe(audio_path):
25
  return transcription
26
 
27
  def extract_text_features(text):
28
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
29
  outputs = text_model(**inputs)
30
- pred_label = outputs.logits.argmax(dim=1).item()
31
- return label_map.get(pred_label, "Unknown")
32
 
33
- def predict_hate_speech(audio_path=None, text=None):
34
- if audio_path:
35
- transcription = transcribe(audio_path)
36
- text_input = text if text else transcription
37
- elif text:
38
- text_input = text
 
39
  else:
40
- return "No input provided"
41
- prediction = extract_text_features(text_input)
42
- return prediction
43
 
44
  st.title("Hate Speech Detector")
45
- audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac", "ogg", "opus"])
46
- text_input = st.text_input("Optional text input")
47
- if st.button("Predict"):
48
- if audio_file is not None:
49
- with open("temp_audio.wav", "wb") as f:
50
- f.write(audio_file.read())
51
- prediction = predict_hate_speech("temp_audio.wav", text_input)
52
- st.success(prediction)
53
- elif text_input:
54
- prediction = predict_hate_speech(text=text_input)
55
- st.success(prediction)
56
- else:
57
- st.warning("Please upload an audio file or enter text.")
 
2
  import torchaudio
3
  import os
4
  import streamlit as st
5
+ import sounddevice as sd
6
+ import soundfile as sf
7
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
 
 
12
  os.environ["TORCH_HOME"] = "/app/.cache/torch"
13
  hf_token = os.getenv("HateSpeechMujtabatoken")
14
 
15
+ whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny", token=hf_token)
16
+ whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny", token=hf_token)
17
+ text_model = AutoModelForSequenceClassification.from_pretrained("GroNLP/hateBERT", token=hf_token)
18
+ tokenizer = AutoTokenizer.from_pretrained("GroNLP/hateBERT", token=hf_token)
19
 
20
+ def record_audio(duration, filename, samplerate=16000):
21
+ recording = sd.rec(int(duration * samplerate), samplerate=samplerate, channels=1, dtype='float32')
22
+ sd.wait()
23
+ sf.write(filename, recording, samplerate)
24
 
25
  def transcribe(audio_path):
26
  waveform, sample_rate = torchaudio.load(audio_path)
 
30
  return transcription
31
 
32
  def extract_text_features(text):
33
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
34
  outputs = text_model(**inputs)
35
+ predicted_class = outputs.logits.argmax(dim=1).item()
36
+ return "Hate Speech" if predicted_class >= 1 else "Not Hate Speech"
37
 
38
+ def predict(text_input):
39
+ audio_path = "mic_input.wav"
40
+ record_audio(5, audio_path)
41
+ transcribed_text = transcribe(audio_path)
42
+ prediction = extract_text_features(text_input or transcribed_text)
43
+ if text_input:
44
+ return f"Predicted: {prediction}"
45
  else:
46
+ return f"Predicted: {prediction} \n\n(Transcribed: {transcribed_text})"
 
 
47
 
48
  st.title("Hate Speech Detector")
49
+ text_input = st.text_input("Enter text (optional):")
50
+ if st.button("Start Recording and Predict"):
51
+ result = predict(text_input)
52
+ st.success(result)