# Hugging Face Space status: Running on Zero
import logging
import os
import queue
import re
import threading
import time
import traceback
import uuid
from html import escape

import gradio as gr
import numpy as np
import pandas as pd
import torch
import whisper
from scipy.io.wavfile import write as write_wav
from transformers import AutoModelForSeq2SeqLM, AutoModelForSequenceClassification, AutoTokenizer
from transformers import SpeechT5ForTextToSpeech, SpeechT5HifiGan, SpeechT5Processor
# Logging: a single stream handler at INFO level for the whole application.
logging.basicConfig(
    handlers=[logging.StreamHandler()],
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger('profanity_detector')

# Compute device, chosen once at import time: CUDA when present, else CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")

# Model handles; all are populated by load_models(), which also flips models_loaded.
profanity_model = profanity_tokenizer = None
t5_model = t5_tokenizer = None
whisper_model = None
tts_processor = tts_model = vocoder = None
models_loaded = False

# Speaker embedding tensor handed to SpeechT5; created in load_models().
speaker_embeddings = None

# Real-time streaming state. audio_queue is not read by any code shown here;
# processing_active is toggled by start_streaming()/stop_streaming().
audio_queue = queue.Queue()
processing_active = False
# Model loading (weights are cast to FP16 on CUDA where possible; no int8 quantization)
def _to_device_half(model, label):
    """Move *model* to ``device`` and try to cast it to FP16 when CUDA is available.

    Args:
        model: A torch module (e.g. a Hugging Face model).
        label: Short model name used in the success log message.

    Returns:
        The model unchanged when CUDA is unavailable; otherwise the model moved
        to ``device`` and, if the cast succeeds, converted to half precision.
        A failed cast is logged as a warning and the un-cast model is returned.
    """
    if not torch.cuda.is_available():
        return model
    model = model.to(device)
    try:
        model = model.half()  # Convert to FP16 to save GPU memory
        logger.info(f"Successfully converted {label} model to half precision")
    except Exception as e:
        logger.warning(f"Could not convert to half precision: {str(e)}")
    return model


def load_models():
    """Load every model used by the app into the module-level globals.

    Loads, in order: the profanity classifier, the T5 detoxification model,
    Whisper "large", and the SpeechT5 TTS model + HiFi-GAN vocoder, then creates
    an all-zero (1, 512) speaker embedding. On success ``models_loaded`` is set
    to True. Calling this again after a successful load is a no-op.

    Returns:
        A status string: it contains "successfully" on success, or starts with
        "Error loading models" (followed by the traceback) on failure. The UI's
        status badge checks for those words.
    """
    global profanity_model, profanity_tokenizer, t5_model, t5_tokenizer, whisper_model
    global tts_processor, tts_model, vocoder, speaker_embeddings, models_loaded

    if models_loaded:
        # Re-clicking "Initialize Models" should not reload all model weights.
        logger.info("Models already loaded; skipping reload.")
        return "Models already loaded successfully."

    try:
        logger.info("Loading profanity detection model...")
        PROFANITY_MODEL = "parsawar/profanity_model_3.1"
        profanity_tokenizer = AutoTokenizer.from_pretrained(PROFANITY_MODEL)
        profanity_model = _to_device_half(
            AutoModelForSequenceClassification.from_pretrained(PROFANITY_MODEL), "profanity"
        )

        logger.info("Loading detoxification model...")
        T5_MODEL = "s-nlp/t5-paranmt-detox"
        t5_tokenizer = AutoTokenizer.from_pretrained(T5_MODEL)
        t5_model = _to_device_half(AutoModelForSeq2SeqLM.from_pretrained(T5_MODEL), "T5")

        logger.info("Loading Whisper speech-to-text model...")
        # Load straight onto the target device instead of loading then moving.
        whisper_model = whisper.load_model("large", device=device)

        logger.info("Loading Text-to-Speech model...")
        TTS_MODEL = "microsoft/speecht5_tts"
        tts_processor = SpeechT5Processor.from_pretrained(TTS_MODEL)
        # TTS models are kept in full precision; only moved to the device.
        tts_model = SpeechT5ForTextToSpeech.from_pretrained(TTS_MODEL)
        vocoder = SpeechT5HifiGan.from_pretrained("microsoft/speecht5_hifigan")
        if torch.cuda.is_available():
            tts_model = tts_model.to(device)
            vocoder = vocoder.to(device)

        # Fixed all-zero speaker embedding for TTS.
        speaker_embeddings = torch.zeros((1, 512), device=device)

        models_loaded = True
        logger.info("All models loaded successfully.")
        return "Models loaded successfully."
    except Exception as e:
        error_msg = f"Error loading models: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg
def detect_profanity(text: str, threshold: float = 0.5):
    """Detect profanity in *text* using an adjustable threshold.

    The whole text is scored first. Only if that score exceeds *threshold* is
    each distinct word (length >= 2) scored individually to find the profane
    ones. Every distinct word is sent to the model once, even if it repeats.

    Args:
        text: The input text to analyze.
        threshold: Profanity detection threshold (0.0-1.0); a score strictly
            greater than it counts as profane.

    Returns:
        On success, a dict with keys ``text``, ``score`` (probability of class
        index 1 for the whole text), ``profanity`` (bool), ``profane_words``
        (unique lower-cased words, in first-seen order), ``highlighted_text``
        (HTML) and ``word_scores`` (word as written -> score).
        If models are not loaded: ``{"error": ...}``. On an exception: a dict
        with ``error``, ``text``, ``score`` = 0 and ``profanity`` = False.
    """
    if not models_loaded:
        return {"error": "Models not loaded yet. Please wait."}

    def _score(snippet):
        # Softmax probability of class index 1 for a single string.
        inputs = profanity_tokenizer(snippet, return_tensors="pt", truncation=True, max_length=512)
        if torch.cuda.is_available():
            inputs = inputs.to(device)
        with torch.no_grad():
            logits = profanity_model(**inputs).logits
        return torch.nn.functional.softmax(logits, dim=1)[0][1].item()

    try:
        score = _score(text)

        profane_words = []
        word_scores = {}
        if score > threshold:
            for word in re.findall(r'\b\w+\b', text):
                # Skip very short words and words already scored.
                if len(word) < 2 or word in word_scores:
                    continue
                word_score = _score(word)
                word_scores[word] = word_score
                lowered = word.lower()
                # Avoid listing "Damn"/"damn" (or repeats) more than once.
                if word_score > threshold and lowered not in profane_words:
                    profane_words.append(lowered)

        highlighted_text = create_highlighted_text(text, profane_words)
        return {
            "text": text,
            "score": score,
            "profanity": score > threshold,
            "profane_words": profane_words,
            "highlighted_text": highlighted_text,
            "word_scores": word_scores
        }
    except Exception as e:
        error_msg = f"Error in profanity detection: {str(e)}"
        logger.error(error_msg)
        return {"error": error_msg, "text": text, "score": 0, "profanity": False}
def create_highlighted_text(text, profane_words):
    """Return *text* as HTML with each profane word wrapped in a red highlight span.

    Every part of the text, including the highlighted words, is HTML-escaped.
    The result is rendered by a ``gr.HTML`` component, and *text* can come from
    user input or a transcription, so unescaped text would allow markup injection.

    Args:
        text: The original text.
        profane_words: Words to highlight, matched case-insensitively on word
            boundaries. Empty strings are ignored.

    Returns:
        An HTML string.
    """
    words = [word for word in profane_words if word] if profane_words else []
    if not words:
        return escape(text)

    # Regex matching any of the profane words as whole words (case-insensitive).
    pattern = re.compile(
        r'\b(' + '|'.join(re.escape(word) for word in words) + r')\b',
        flags=re.IGNORECASE,
    )
    span_open = '<span style="background-color: rgba(255, 0, 0, 0.3); padding: 0px 2px; border-radius: 3px;">'

    # Escape the gaps between matches and the matches themselves separately so
    # only our own <span> markup reaches the HTML output unescaped.
    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(escape(text[cursor:match.start()]))
        pieces.append(f'{span_open}{escape(match.group(0))}</span>')
        cursor = match.end()
    pieces.append(escape(text[cursor:]))
    return ''.join(pieces)
def rephrase_profanity(text):
    """Rewrite *text* with the T5 detoxification model.

    Generation uses beam search (4 beams). If that raises a CUDA out-of-memory
    RuntimeError, the CUDA cache is cleared and generation is retried with 2 beams.
    Both paths apply the same output check.

    Args:
        text: Text to rephrase.

    Returns:
        The stripped rephrased text. Returns *text* unchanged if the model output
        is shorter than 3 characters or any error occurs. Returns a "not loaded"
        message if models are not loaded.
    """
    if not models_loaded:
        return "Models not loaded yet. Please wait."
    try:
        inputs = t5_tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        if torch.cuda.is_available():
            inputs = inputs.to(device)

        try:
            outputs = t5_model.generate(
                **inputs,
                max_length=512,
                num_beams=4,  # Reduced from 5 to be more memory-efficient
                early_stopping=True,
                no_repeat_ngram_size=2,
                length_penalty=1.0
            )
        except RuntimeError as e:
            if "CUDA out of memory" not in str(e):
                raise  # Not a memory issue: let the outer handler deal with it
            logger.warning("CUDA out of memory in T5 model. Trying with smaller beam size...")
            # Release cached allocator blocks left by the failed attempt before retrying.
            torch.cuda.empty_cache()
            outputs = t5_model.generate(
                **inputs,
                max_length=512,
                num_beams=2,  # Use smaller beam size
                early_stopping=True
            )

        rephrased_text = t5_tokenizer.decode(outputs[0], skip_special_tokens=True).strip()
        # Validate output from either path; fall back to the original text.
        if len(rephrased_text) < 3:
            logger.warning(f"T5 model produced unusable output: '{rephrased_text}'")
            return text
        return rephrased_text
    except Exception as e:
        error_msg = f"Error in rephrasing: {str(e)}"
        logger.error(error_msg)
        return text  # Return original text if rephrasing fails
def text_to_speech(text):
    """Synthesize *text* with SpeechT5 and write the result to a WAV file.

    Args:
        text: Text to speak.

    Returns:
        Relative path of the written WAV file (in the current working directory),
        or None if models are not loaded, *text* is empty/whitespace-only, or
        synthesis fails.
    """
    if not models_loaded:
        return None
    # Nothing to speak; skip the model call entirely.
    if not text or not text.strip():
        return None
    try:
        # A random suffix keeps concurrent or same-second requests (e.g. from the
        # streaming tab) from overwriting each other's output file.
        temp_file = f"temp_tts_output_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"

        inputs = tts_processor(text=text, return_tensors="pt")
        if torch.cuda.is_available():
            inputs = inputs.to(device)

        # Generate speech with the fixed speaker embedding and the HiFi-GAN vocoder.
        speech = tts_model.generate_speech(
            inputs["input_ids"],
            speaker_embeddings,
            vocoder=vocoder
        )

        # Save as WAV at 16 kHz (the rate this app uses for SpeechT5 output).
        write_wav(temp_file, 16000, speech.cpu().numpy())
        return temp_file
    except Exception as e:
        error_msg = f"Error in text-to-speech conversion: {str(e)}"
        logger.error(error_msg)
        return None
def text_analysis(input_text, threshold=0.5):
    """Analyze text for profanity and build the outputs for the Text tab.

    Args:
        input_text: Text entered by the user (may be None or empty, e.g. after
            the Clear button reset the textbox).
        threshold: Detection threshold forwarded to ``detect_profanity``.

    Returns:
        A 3-tuple ``(report, highlighted_html, audio_path)``. On profanity, the
        audio speaks the rephrased text; otherwise it speaks the original text.
        ``highlighted_html`` is None when there is no profanity or on error;
        ``audio_path`` is None on error or when TTS fails.
    """
    if not models_loaded:
        return "Models not loaded yet. Please wait for initialization to complete.", None, None
    # Don't run the models on empty input.
    if not input_text or not input_text.strip():
        return "Please enter some text to analyze.", None, None
    try:
        result = detect_profanity(input_text, threshold=threshold)
        if "error" in result:
            return result["error"], None, None

        if not result["profanity"]:
            # No profanity detected: convert the original text to speech.
            audio_output = text_to_speech(input_text)
            return (
                f"Profanity Score: {result['score']:.4f}\n"
                f"Profane: {result['profanity']}\n"
                f"Classification: No Toxicity"
            ), None, audio_output

        clean_text = rephrase_profanity(input_text)
        profane_words_str = ", ".join(result["profane_words"])
        toxicity_score = result["score"]
        classification = (
            "Severe Toxicity" if toxicity_score >= 0.7 else
            "Moderate Toxicity" if toxicity_score >= 0.5 else
            "Mild Toxicity" if toxicity_score >= 0.35 else
            "Minimal Toxicity" if toxicity_score >= 0.2 else
            "No Toxicity"
        )
        # Speak the rephrased text.
        audio_output = text_to_speech(clean_text)
        return (
            f"Profanity Score: {result['score']:.4f}\n\n"
            f"Profane: {result['profanity']}\n"
            f"Classification: {classification}\n"
            f"Detected Profane Words: {profane_words_str}\n\n"
            f"Reworded: {clean_text}"
        ), result["highlighted_text"], audio_output
    except Exception as e:
        error_msg = f"Error in text analysis: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg, None, None
def analyze_audio(audio_path, threshold=0.5):
    """Transcribe an audio file, screen it for profanity and build Audio-tab outputs.

    Args:
        audio_path: Path to the uploaded/recorded audio (Gradio ``filepath``).
        threshold: Detection threshold forwarded to ``detect_profanity``.

    Returns:
        A 3-tuple ``(report, highlighted_html, audio_path)``. The audio speaks the
        rephrased text when profanity is found, otherwise the transcription.
        ``highlighted_html`` is None when there is no profanity. The other
        elements are None on error, when no speech is found, or when TTS fails.
    """
    if not models_loaded:
        return "Models not loaded yet. Please wait for initialization to complete.", None, None
    if not audio_path:
        return "No audio provided.", None, None
    try:
        result = whisper_model.transcribe(audio_path, fp16=torch.cuda.is_available())
        # Strip surrounding whitespace, as the streaming path also does.
        text = result["text"].strip()
        if not text:
            # Silence: nothing to analyze or speak.
            return "No speech detected in the audio.", None, None

        analysis = detect_profanity(text, threshold=threshold)
        if "error" in analysis:
            return f"Error during analysis: {analysis['error']}\nTranscription: {text}", None, None

        is_profane = analysis["profanity"]
        score = analysis["score"]
        clean_text = rephrase_profanity(text) if is_profane else text
        audio_output = text_to_speech(clean_text)

        if score >= 0.7:
            classification = "Severe Toxicity"
        elif score >= 0.5:
            classification = "Moderate Toxicity"
        elif score >= 0.35:
            classification = "Mild Toxicity"
        elif score >= 0.2:
            classification = "Minimal Toxicity"
        else:
            classification = "No Toxicity"

        profane_words = ", ".join(analysis["profane_words"]) if is_profane else "None"
        report = (
            f"Transcription: {text}\n\n"
            f"Profanity Score: {score:.4f}\n"
            f"Profane: {'Yes' if is_profane else 'No'}\n"
            f"Classification: {classification}\n"
            f"Profane Words: {profane_words}\n\n"
            f"Reworded: {clean_text}"
        )
        highlighted = analysis["highlighted_text"] if is_profane else None
        return report, highlighted, audio_output
    except Exception as e:
        error_msg = f"Error in audio analysis: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        return error_msg, None, None
# Latest real-time streaming output. It is mutated by process_stream_chunk()
# and replaced wholesale by start_streaming().
stream_results = dict(
    transcript="",
    profanity_info="",
    clean_text="",
    audio_output=None,
)
def _new_stream_wav_path():
    """Return a unique relative file name for a temporary streaming-chunk WAV."""
    # Time plus a random suffix: several chunks can arrive within the same second.
    return f"temp_stream_{int(time.time())}_{uuid.uuid4().hex[:8]}.wav"


def _remove_file_quietly(path):
    """Best-effort delete of *path*; an OSError is logged instead of raised."""
    if path and os.path.exists(path):
        try:
            os.remove(path)
        except OSError as e:
            logger.warning(f"Could not remove temporary file {path}: {e}")


def process_stream_chunk(audio_chunk, threshold=0.5):
    """Transcribe, screen and clean one chunk from the streaming microphone.

    Does nothing (returns the current state) unless streaming was started and
    the models are loaded.

    Args:
        audio_chunk: Payload from the Gradio streaming Audio component. Accepted
            forms: an existing file path string; ``(sample_rate, samples)`` or
            ``(samples, sample_rate)``; ``(filepath, sample_rate, samples)``;
            or a bare numpy array (assumed to be 16 kHz).
        threshold: Profanity threshold forwarded to ``detect_profanity``.

    Returns:
        ``(transcript, profanity_info, clean_text, audio_output)`` read from the
        module-level ``stream_results`` after this chunk has been applied.
    """
    global stream_results, processing_active

    def _current_state():
        return (
            stream_results["transcript"],
            stream_results["profanity_info"],
            stream_results["clean_text"],
            stream_results["audio_output"],
        )

    if not processing_active or not models_loaded or audio_chunk is None:
        return _current_state()

    temp_file = None
    created_temp = False  # True only for files this function wrote (and must delete)
    try:
        # Normalize the payload to a WAV file on disk for Whisper.
        if isinstance(audio_chunk, tuple):
            if len(audio_chunk) == 2:
                first, second = audio_chunk
                # Gradio's numpy audio payload is (sample_rate, samples); also
                # accept the reversed (samples, sample_rate) order.
                if isinstance(first, (int, np.integer)):
                    sample_rate, samples = first, second
                else:
                    samples, sample_rate = first, second
                temp_file = _new_stream_wav_path()
                created_temp = True
                write_wav(temp_file, sample_rate, samples)
            elif len(audio_chunk) == 3:
                filepath, sample_rate, samples = audio_chunk
                if filepath and os.path.exists(filepath):
                    temp_file = filepath
                else:
                    temp_file = _new_stream_wav_path()
                    created_temp = True
                    write_wav(temp_file, sample_rate, samples)
        elif isinstance(audio_chunk, np.ndarray):
            # Bare samples: assume 16 kHz.
            temp_file = _new_stream_wav_path()
            created_temp = True
            write_wav(temp_file, 16000, audio_chunk)
        elif isinstance(audio_chunk, str) and os.path.exists(audio_chunk):
            temp_file = audio_chunk
        else:
            stream_results["profanity_info"] = f"Error: Unknown audio format: {type(audio_chunk)}"
            return _current_state()

        if not temp_file or not os.path.exists(temp_file):
            stream_results["profanity_info"] = "Error: Failed to create audio file for processing"
            return _current_state()

        result = whisper_model.transcribe(temp_file, fp16=torch.cuda.is_available())
        transcript = result["text"].strip()
        if not transcript:
            stream_results["profanity_info"] = "No speech detected. Keep talking..."
            return _current_state()

        stream_results["transcript"] = transcript
        analysis = detect_profanity(transcript, threshold=threshold)
        if "error" in analysis:
            # Surface the failure instead of a misleading "No Profanity Detected".
            stream_results["profanity_info"] = f"Error: {analysis['error']}"
            return _current_state()

        if analysis.get("profanity", False):
            profane_words = ", ".join(analysis.get("profane_words", []))
            stream_results["profanity_info"] = f"Profanity Detected (Score: {analysis['score']:.2f})\nProfane Words: {profane_words}"
            spoken_text = rephrase_profanity(transcript)
        else:
            stream_results["profanity_info"] = f"No Profanity Detected (Score: {analysis['score']:.2f})"
            spoken_text = transcript  # Use original text for audio if no profanity
        stream_results["clean_text"] = spoken_text

        audio_file = text_to_speech(spoken_text)
        if audio_file:
            stream_results["audio_output"] = audio_file
        return _current_state()
    except Exception as e:
        error_msg = f"Error processing streaming audio: {str(e)}\n{traceback.format_exc()}"
        logger.error(error_msg)
        stream_results["profanity_info"] = f"Error: {str(e)}"
        return _current_state()
    finally:
        # Runs on every exit path, so chunk files never leak on errors.
        if created_temp:
            _remove_file_quietly(temp_file)
def start_streaming():
    """Arm real-time processing and reset the shared streaming results.

    Returns:
        A status message for the "Streaming Status" textbox. Nothing changes
        when the models are not loaded or streaming is already running.
    """
    global processing_active, stream_results

    refusal = None
    if not models_loaded:
        refusal = "Models not loaded yet. Please wait for initialization to complete."
    elif processing_active:
        refusal = "Streaming is already active."
    if refusal is not None:
        return refusal

    # Fresh results for the new session.
    stream_results = dict(
        transcript="",
        profanity_info="Waiting for audio input...",
        clean_text="",
        audio_output=None,
    )
    processing_active = True
    logger.info("Started real-time audio processing")
    return "Started real-time audio processing. Speak into your microphone."
def stop_streaming():
    """Stop real-time processing; later stream chunks just return the current state.

    Returns:
        A status message for the "Streaming Status" textbox.
    """
    global processing_active
    if not processing_active:
        return "Streaming is not active."
    processing_active = False
    # Mirror the log line emitted by start_streaming().
    logger.info("Stopped real-time audio processing")
    return "Stopped real-time audio processing."
def create_ui():
    """Build the Gradio Blocks interface for the profanity detector.

    The layout has an "Initialize Models" button with a status badge and a
    shared threshold slider, plus three tabs: text analysis, uploaded/recorded
    audio analysis, and real-time microphone streaming.

    Returns:
        The constructed ``gr.Blocks`` app (not launched).
    """
    # Simple CSS for styling
    css = """
    /* Fix for dark mode text visibility */
    .dark .gr-input,
    .dark textarea,
    .dark .gr-textbox,
    .dark [data-testid="textbox"] {
        color: white !important;
        background-color: #2c303b !important;
    }
    .dark .gr-box,
    .dark .gr-form,
    .dark .gr-panel,
    .dark .gr-block {
        color: white !important;
    }
    /* Highlighted text container - with dark mode fixes */
    .highlighted-text {
        border: 1px solid #ddd;
        border-radius: 5px;
        padding: 10px;
        margin: 10px 0;
        background-color: #f9f9f9;
        font-family: sans-serif;
        max-height: 300px;
        overflow-y: auto;
        color: #333 !important; /* Ensure text is dark for light mode */
    }
    /* Dark mode specific styling for highlighted text */
    .dark .highlighted-text {
        background-color: #2c303b !important;
        color: #ffffff !important;
        border-color: #4a4f5a !important;
    }
    /* Make sure text in the highlighted container remains visible in both themes */
    .highlighted-text, .dark .highlighted-text {
        color-scheme: light dark;
    }
    /* Loading animation */
    .loading {
        display: inline-block;
        width: 20px;
        height: 20px;
        border: 3px solid rgba(0,0,0,.3);
        border-radius: 50%;
        border-top-color: #3498db;
        animation: spin 1s ease-in-out infinite;
    }
    @keyframes spin {
        to { transform: rotate(360deg); }
    }
    """
    # Soft theme with blue/gray hues. The theme itself does not disable dark
    # mode, which is why the CSS above carries `.dark` overrides.
    light_theme = gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="blue",
        neutral_hue="gray"
    )

    with gr.Blocks(css=css, theme=light_theme, analytics_enabled=False) as ui:
        gr.Markdown(
            """
            # Profanity Detection & Replacement System
            Detect, rephrase, and listen to cleaned content from text or audio!
            """,
            elem_classes="header"
        )

        # Initialize models button with status indicators
        with gr.Row():
            with gr.Column(scale=3):
                init_button = gr.Button("Initialize Models", variant="primary")
                init_output = gr.Textbox(label="Initialization Status", interactive=False)
            with gr.Column(scale=1):
                model_status = gr.HTML(
                    """<div style="text-align: center; padding: 5px;">
<p><b>Model Status:</b> <span style="color: #e74c3c;">Not Loaded</span></p>
</div>"""
                )

        # Global threshold slider shared by the Text and Audio tabs. A score must
        # exceed this value to count as profane, so higher = more permissive.
        sensitivity = gr.Slider(
            minimum=0.2,
            maximum=0.95,
            value=0.5,
            step=0.05,
            label="Profanity Detection Sensitivity",
            info="Detection threshold: lower values flag more content (stricter), higher values flag less (more permissive)"
        )

        with gr.Row():
            with gr.Column(scale=3):
                gr.Markdown("### Choose an Input Method")
                with gr.Tabs():
                    # Text Analysis
                    with gr.TabItem("Text Analysis", elem_id="text-tab"):
                        with gr.Row():
                            text_input = gr.Textbox(
                                label="Enter Text",
                                placeholder="Type your text here...",
                                lines=5,
                                elem_classes="textbox"
                            )
                        with gr.Row():
                            text_button = gr.Button("Analyze Text", variant="primary")
                            clear_button = gr.Button("Clear", variant="secondary")
                        with gr.Row():
                            with gr.Column(scale=2):
                                text_output = gr.Textbox(label="Results", lines=10)
                                highlighted_output = gr.HTML(label="Detected Profanity", elem_classes="highlighted-text")
                            with gr.Column(scale=1):
                                text_audio_output = gr.Audio(label="Rephrased Audio", type="filepath")

                    # Audio Analysis
                    with gr.TabItem("Audio Analysis", elem_id="audio-tab"):
                        gr.Markdown("### Upload or Record Audio")
                        audio_input = gr.Audio(
                            label="Audio Input",
                            type="filepath",
                            sources=["microphone", "upload"]
                        )
                        with gr.Row():
                            audio_button = gr.Button("Analyze Audio", variant="primary")
                            clear_audio_button = gr.Button("Clear", variant="secondary")
                        with gr.Row():
                            with gr.Column(scale=2):
                                audio_output = gr.Textbox(label="Results", lines=10, show_copy_button=True)
                                audio_highlighted_output = gr.HTML(label="Detected Profanity", elem_classes="highlighted-text")
                            with gr.Column(scale=1):
                                clean_audio_output = gr.Audio(label="Rephrased Audio", type="filepath")

                    # Real-time Streaming
                    with gr.TabItem("Real-time Streaming", elem_id="streaming-tab"):
                        gr.Markdown("### Real-time Audio Processing")
                        gr.Markdown("Enable real-time audio processing to filter profanity as you speak.")
                        with gr.Row():
                            with gr.Column(scale=1):
                                start_stream_button = gr.Button("Start Real-time Processing", variant="primary")
                                stop_stream_button = gr.Button("Stop Real-time Processing", variant="secondary")
                                stream_status = gr.Textbox(label="Streaming Status", value="Inactive", interactive=False)
                                # Microphone input used only for streaming
                                stream_audio_input = gr.Audio(
                                    label="Streaming Microphone Input",
                                    type="filepath",
                                    sources=["microphone"],
                                    streaming=True
                                )
                            with gr.Column(scale=2):
                                # Streaming results
                                stream_transcript = gr.Textbox(label="Live Transcription", lines=2)
                                stream_profanity_info = gr.Textbox(label="Profanity Detection", lines=2)
                                stream_clean_text = gr.Textbox(label="Clean Text", lines=2)
                                stream_audio_output = gr.Audio(label="Clean Audio Output", type="filepath")
                        gr.Markdown("""
                        ### How Real-time Streaming Works
                        1. Click "Start Real-time Processing" to begin
                        2. Use the microphone input to speak
                        3. The system will process audio in real-time, detect and clean profanity
                        4. You'll see the transcription, profanity info, and clean output appear above
                        5. Click "Stop Real-time Processing" when finished
                        Note: This feature requires microphone access and may have some latency.
                        """)

        # Event handlers
        def update_model_status(status_text):
            """Return ``(status_text, status_html)`` for the init textbox and badge."""
            lowered = status_text.lower()
            if "successfully" in lowered:
                status_html = """<div style="text-align: center; padding: 5px;">
<p><b>Model Status:</b> <span style="color: #2ecc71;">Loaded ✓</span></p>
</div>"""
            elif "error" in lowered:
                status_html = """<div style="text-align: center; padding: 5px;">
<p><b>Model Status:</b> <span style="color: #e74c3c;">Error ✗</span></p>
</div>"""
            else:
                status_html = """<div style="text-align: center; padding: 5px;">
<p><b>Model Status:</b> <span style="color: #f39c12;">Loading...</span></p>
</div>"""
            return status_text, status_html

        # Show "Loading..." immediately, load the models, then refresh the badge.
        init_button.click(
            lambda: update_model_status("Loading models, please wait..."),
            inputs=[],
            outputs=[init_output, model_status]
        ).then(
            load_models,
            inputs=[],
            outputs=[init_output]
        ).then(
            update_model_status,
            inputs=[init_output],
            outputs=[init_output, model_status]
        )

        text_button.click(
            text_analysis,
            inputs=[text_input, sensitivity],
            outputs=[text_output, highlighted_output, text_audio_output]
        )
        # Clear must also reset the Results textbox, not just input/HTML/audio.
        clear_button.click(
            lambda: [None, None, None, None],
            inputs=None,
            outputs=[text_input, text_output, highlighted_output, text_audio_output]
        )

        audio_button.click(
            analyze_audio,
            inputs=[audio_input, sensitivity],
            outputs=[audio_output, audio_highlighted_output, clean_audio_output]
        )
        clear_audio_button.click(
            lambda: [None, None, None, None],
            inputs=None,
            outputs=[audio_input, audio_output, audio_highlighted_output, clean_audio_output]
        )

        start_stream_button.click(
            start_streaming,
            inputs=[],
            outputs=[stream_status]
        )
        stop_stream_button.click(
            stop_streaming,
            inputs=[],
            outputs=[stream_status]
        )

        def debug_audio_format(audio_chunk):
            """Log the type/shape of a streaming audio payload and return it unchanged."""
            format_info = f"Type: {type(audio_chunk)}"
            if isinstance(audio_chunk, tuple):
                format_info += f", Length: {len(audio_chunk)}"
                for i, item in enumerate(audio_chunk):
                    format_info += f", Item {i} type: {type(item)}"
            # DEBUG level: this runs for every streamed chunk.
            logger.debug(f"Audio chunk format: {format_info}")
            return audio_chunk

        def process_stream_chunk_logged(audio_chunk):
            """Log the chunk format, then hand it to process_stream_chunk()."""
            return process_stream_chunk(debug_audio_format(audio_chunk))

        # `preprocess` on event listeners is a boolean flag, not a hook, so the
        # debug logger is composed into the handler instead.
        stream_audio_input.stream(
            fn=process_stream_chunk_logged,
            inputs=[stream_audio_input],
            outputs=[stream_transcript, stream_profanity_info, stream_clean_text, stream_audio_output]
        )

    return ui
if __name__ == "__main__":
    # Allow duplicate OpenMP runtimes to coexist (KMP_DUPLICATE_LIB_OK).
    # NOTE(review): torch is already imported by this point; confirm the
    # variable still takes effect for whichever library triggers the conflict.
    os.environ.update(KMP_DUPLICATE_LIB_OK='TRUE')

    # Build the interface and serve it on all interfaces with a share link.
    ui = create_ui()
    ui.launch(share=True, server_name="0.0.0.0")