mojad121 committed on
Commit
55c917c
·
verified ·
1 Parent(s): a44744a

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +36 -40
src/streamlit_app.py CHANGED
@@ -2,58 +2,54 @@ import os
2
 
3
  os.environ["TRANSFORMERS_CACHE"] = "/app/.cache"
4
  os.environ["HF_HOME"] = "/app/.cache"
 
5
  import torch
6
  import torchaudio
 
7
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
8
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
9
  import streamlit as st
10
 
11
-
12
- def load_models():
13
- whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
14
- whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
15
- text_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
16
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
17
- return whisper_processor, whisper_model, text_model, tokenizer
18
-
19
- whisper_processor, whisper_model, text_model, tokenizer = load_models()
20
-
21
- def transcribe(audio_path):
22
- waveform, sample_rate = torchaudio.load(audio_path)
23
- input_features = whisper_processor(
24
- waveform.squeeze().numpy(),
25
- sampling_rate=sample_rate,
26
- return_tensors="pt"
27
- ).input_features
28
  predicted_ids = whisper_model.generate(input_features)
29
  transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
 
30
  return transcription
31
 
32
  def extract_text_features(text):
33
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
34
  outputs = text_model(**inputs)
35
- predicted_class = outputs.logits.argmax(dim=1).item()
36
- return "Hate Speech" if predicted_class == 1 else "Not Hate Speech"
37
-
38
- def predict(audio_file, text_input):
39
- if not audio_file and not text_input:
40
- return "Please provide either an audio file or some text."
41
- if audio_file is not None:
42
- audio_path = "temp_audio.wav"
43
- with open(audio_path, "wb") as f:
44
- f.write(audio_file.read())
45
- transcribed_text = transcribe(audio_path)
46
- prediction = extract_text_features(text_input or transcribed_text)
47
- return f"Predicted: {prediction} \n\n(Transcribed: {transcribed_text})" if not text_input else f"Predicted: {prediction}"
48
- elif text_input:
49
- prediction = extract_text_features(text_input)
50
- return f"Predicted: {prediction}"
51
-
52
- st.title("Hate Speech Detector")
53
-
54
- uploaded_audio = st.file_uploader("Upload Audio File (.mp3, .wav, .ogg, .flac, .opus)", type=["mp3", "wav", "ogg", "flac", "opus"])
55
- text_input = st.text_input("Or enter text:")
56
 
57
  if st.button("Predict"):
58
- result = predict(uploaded_audio, text_input)
59
- st.success(result)
 
 
 
 
 
2
 
3
  os.environ["TRANSFORMERS_CACHE"] = "/app/.cache"
4
  os.environ["HF_HOME"] = "/app/.cache"
5
+ import os
6
  import torch
7
  import torchaudio
8
+ import tempfile
9
  from transformers import WhisperProcessor, WhisperForConditionalGeneration
10
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
11
  import streamlit as st
12
 
13
+ whisper_processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
14
+ whisper_model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
15
+ text_model = AutoModelForSequenceClassification.from_pretrained("bert-base-uncased")
16
+ tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
17
+
18
+ def transcribe(audio_bytes):
19
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
20
+ tmp.write(audio_bytes)
21
+ tmp_path = tmp.name
22
+ waveform, sample_rate = torchaudio.load(tmp_path)
23
+ input_features = whisper_processor(waveform.squeeze().numpy(), sampling_rate=sample_rate, return_tensors="pt").input_features
 
 
 
 
 
 
24
  predicted_ids = whisper_model.generate(input_features)
25
  transcription = whisper_processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]
26
+ os.remove(tmp_path)
27
  return transcription
28
 
29
  def extract_text_features(text):
30
  inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
31
  outputs = text_model(**inputs)
32
+ return outputs.logits.argmax(dim=1).item()
33
+
34
+ def predict_hate_speech(audio_bytes, text):
35
+ if audio_bytes:
36
+ transcription = transcribe(audio_bytes)
37
+ text_input = text if text else transcription
38
+ elif text:
39
+ text_input = text
40
+ else:
41
+ return "Please provide audio or text"
42
+ prediction = extract_text_features(text_input)
43
+ return "Hate Speech" if prediction == 1 else "Not Hate Speech"
44
+
45
+ st.title("Hate Speech Detection")
46
+ audio_file = st.file_uploader("Upload audio file", type=["wav", "mp3", "flac", "ogg", "opus"])
47
+ text_input = st.text_input("Or enter text")
 
 
 
 
 
48
 
49
  if st.button("Predict"):
50
+ if audio_file is not None or text_input:
51
+ audio_bytes = audio_file.read() if audio_file else None
52
+ result = predict_hate_speech(audio_bytes, text_input)
53
+ st.success(result)
54
+ else:
55
+ st.warning("Please provide either audio or text input")