mojad121 committed on
Commit
981c966
·
verified ·
1 Parent(s): 61edcdf

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +18 -12
src/streamlit_app.py CHANGED
@@ -3,16 +3,12 @@ os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface"
3
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
4
  os.environ["XDG_CACHE_HOME"] = "/app/.cache"
5
  os.environ["XDG_CONFIG_HOME"] = "/app/.streamlit"
 
6
  import torch
7
  import torchaudio
8
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
9
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
10
  import streamlit as st
11
- import os
12
- os.environ["TRANSFORMERS_CACHE"] = "/app/.cache/huggingface"
13
- os.environ["HF_HOME"] = "/app/.cache/huggingface"
14
-
15
-
16
 
17
  @st.cache_resource
18
  def load_models():
@@ -36,20 +32,30 @@ def extract_text_features(text):
36
  outputs = text_model(**inputs)
37
  return outputs.logits.argmax(dim=1).item()
38
 
39
- def predict_hate_speech(audio_path, text):
40
- transcription = transcribe(audio_path)
41
- text_input = text if text else transcription
 
 
 
 
 
 
42
  prediction = extract_text_features(text_input)
43
  return "Hate Speech" if prediction == 1 else "Not Hate Speech"
44
 
45
  st.title("Hate Speech Detector with Audio and Text")
46
- audio_file = st.file_uploader("Upload an audio file", type=["wav", "mp3", "flac"])
47
  text_input = st.text_input("Optional text input")
 
48
  if st.button("Predict"):
49
  if audio_file is not None:
50
- with open("temp_audio.wav", "wb") as f:
51
  f.write(audio_file.read())
52
- prediction = predict_hate_speech("temp_audio.wav", text_input)
 
 
 
53
  st.success(prediction)
54
  else:
55
- st.warning("Please upload an audio file.")
 
3
  os.environ["HF_HOME"] = "/app/.cache/huggingface"
4
  os.environ["XDG_CACHE_HOME"] = "/app/.cache"
5
  os.environ["XDG_CONFIG_HOME"] = "/app/.streamlit"
6
+
7
  import torch
8
  import torchaudio
9
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
  import streamlit as st
 
 
 
 
 
12
 
13
  @st.cache_resource
14
  def load_models():
 
32
  outputs = text_model(**inputs)
33
  return outputs.logits.argmax(dim=1).item()
34
 
35
+ def predict_hate_speech(audio_path=None, text=None):
36
+ if text:
37
+ text_input = text
38
+ elif audio_path:
39
+ transcription = transcribe(audio_path)
40
+ text_input = transcription
41
+ else:
42
+ return "Please provide either audio or text input."
43
+
44
  prediction = extract_text_features(text_input)
45
  return "Hate Speech" if prediction == 1 else "Not Hate Speech"
46
 
47
  st.title("Hate Speech Detector with Audio and Text")
48
+ audio_file = st.file_uploader("Upload an audio file (wav, mp3, flac, ogg, opus)", type=["wav", "mp3", "flac", "ogg", "opus"])
49
  text_input = st.text_input("Optional text input")
50
+
51
  if st.button("Predict"):
52
  if audio_file is not None:
53
+ with open("temp_audio", "wb") as f:
54
  f.write(audio_file.read())
55
+ prediction = predict_hate_speech("temp_audio", text_input)
56
+ st.success(prediction)
57
+ elif text_input:
58
+ prediction = predict_hate_speech(text=text_input)
59
  st.success(prediction)
60
  else:
61
+ st.warning("Please provide at least audio or text input.")