Spaces: Running on Zero
import spaces
from typing import Optional
import logging
import time
import threading
import torch
import librosa
from transformers import pipeline, AutoProcessor, AutoModelForCausalLM, Pipeline
from accelerate import Accelerator
# Set up logging
logging.basicConfig(
    level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)
# Try to install Flash Attention at runtime; fall back to standard attention on failure
try:
    import subprocess

    subprocess.run(
        "pip install flash-attn --no-build-isolation",
        env={"FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
        shell=True,
        check=True,  # raise if the install fails so the except branch is taken
    )
    logger.info("Flash Attention installed successfully.")
    USE_FA = True
except Exception:
    USE_FA = False
    logger.warning("Flash Attention not available. Using standard attention instead.")
# Model constants
MODEL_ID = "JacobLinCool/whisper-large-v3-turbo-common_voice_19_0-zh-TW"
PHI_MODEL_ID = "JacobLinCool/Phi-4-multimodal-instruct-commonvoice-zh-tw"

# Model instances (initialized lazily)
pipe: Optional[Pipeline] = None
phi_model = None
phi_processor = None

# Lock for thread-safe model loading
model_loading_lock = threading.Lock()
def load_model() -> None:
    """
    Load the Whisper model for transcription.

    Uses GPU if available.
    """
    global pipe
    if pipe is not None:
        return  # Model already loaded

    with model_loading_lock:
        if pipe is not None:
            return  # Another thread loaded the model while we waited for the lock

        try:
            start_time = time.time()
            logger.info(f"Loading Whisper model {MODEL_ID}...")
            device = Accelerator().device
            pipe = pipeline(
                "automatic-speech-recognition", model=MODEL_ID, device=device
            )
            logger.info(
                f"Model loaded successfully in {time.time() - start_time:.2f} seconds"
            )
        except Exception as e:
            logger.error(f"Failed to load Whisper model: {str(e)}")
            raise
def get_gpu_duration(audio: str) -> int:
    """
    Calculate required GPU allocation time based on audio duration.

    Args:
        audio: Path to audio file

    Returns:
        GPU allocation time in seconds, rounded up to whole minutes
    """
    try:
        y, sr = librosa.load(audio)
        duration = librosa.get_duration(y=y, sr=sr)  # seconds
        # Round up to whole minutes, with a minimum allocation of one minute
        gpu_duration = max(1.0, (duration + 59.0) // 60.0) * 60.0
        logger.info(
            f"Audio duration: {duration / 60.0:.2f} min, allocated GPU time: {gpu_duration / 60.0:.2f} min"
        )
        return int(gpu_duration)
    except Exception as e:
        logger.error(f"Failed to calculate GPU duration: {str(e)}")
        return 60  # Default to 1 minute if calculation fails
@spaces.GPU(duration=get_gpu_duration)  # Request a ZeroGPU slot sized to the audio length
def transcribe_audio_local(audio: str) -> str:
    """
    Transcribe audio using the Whisper model.

    Args:
        audio: Path to audio file

    Returns:
        Transcribed text
    """
    try:
        logger.info(f"Transcribing audio with Whisper: {audio}")
        if pipe is None:
            load_model()
        out = pipe(audio, return_timestamps=True)
        return out.get("text", "No transcription generated")
    except Exception as e:
        logger.error(f"Whisper transcription error: {str(e)}")
        raise
def load_phi_model() -> None:
    """
    Load the Phi-4 model and processor.

    Uses GPU with Flash Attention if available.
    """
    global phi_model, phi_processor
    if phi_model is not None and phi_processor is not None:
        return  # Model already loaded

    with model_loading_lock:
        if phi_model is not None and phi_processor is not None:
            return  # Another thread loaded the model while we waited for the lock

        try:
            start_time = time.time()
            logger.info(f"Loading Phi-4 model {PHI_MODEL_ID}...")
            phi_processor = AutoProcessor.from_pretrained(
                PHI_MODEL_ID, trust_remote_code=True
            )
            device = "cuda" if torch.cuda.is_available() else "cpu"
            dtype = torch.bfloat16 if USE_FA else torch.float32
            attn_implementation = "flash_attention_2" if USE_FA else "sdpa"
            phi_model = AutoModelForCausalLM.from_pretrained(
                PHI_MODEL_ID,
                torch_dtype=dtype,
                _attn_implementation=attn_implementation,
                trust_remote_code=True,
            ).to(device)
            logger.info(
                f"Phi-4 model loaded successfully in {time.time() - start_time:.2f} seconds"
            )
        except Exception as e:
            logger.error(f"Failed to load Phi-4 model: {str(e)}")
            raise
@spaces.GPU(duration=get_gpu_duration)  # Request a ZeroGPU slot sized to the audio length
def transcribe_audio_phi(audio: str) -> str:
    """
    Transcribe audio using the Phi-4 model.

    Args:
        audio: Path to audio file

    Returns:
        Transcribed text
    """
    try:
        logger.info(f"Transcribing audio with Phi-4: {audio}")
        load_phi_model()

        # Load and resample audio to 16 kHz
        y, sr = librosa.load(audio, sr=16000)

        # Prepare the user message and generate the prompt
        user_message = {
            "role": "user",
            "content": "<|audio_1|> Transcribe the audio clip into text.",
        }
        prompt = phi_processor.tokenizer.apply_chat_template(
            [user_message], tokenize=False, add_generation_prompt=True
        )

        # Build inputs for the model
        inputs = phi_processor(text=prompt, audios=[(y, sr)], return_tensors="pt")
        inputs = {
            k: v.to(phi_model.device) if hasattr(v, "to") else v
            for k, v in inputs.items()
        }

        # Generate transcription without gradients
        with torch.no_grad():
            generated_ids = phi_model.generate(
                **inputs,
                eos_token_id=phi_processor.tokenizer.eos_token_id,
                max_new_tokens=256,  # Increased for longer transcriptions
                do_sample=False,
            )

        # Decode only the newly generated tokens into text
        transcription = phi_processor.decode(
            generated_ids[0, inputs["input_ids"].shape[1]:],
            skip_special_tokens=True,
            clean_up_tokenization_spaces=False,
        )

        logger.info("Phi-4 transcription completed successfully")
        return transcription
    except Exception as e:
        logger.error(f"Phi-4 transcription error: {str(e)}")
        raise
def preload_models() -> None:
    """
    Preload models into memory to reduce cold start time.

    This function can be called at application startup.
    """
    try:
        logger.info("Preloading models to reduce cold start time")
        # Load Whisper model first as it's the default
        load_model()
        # Then load Phi model
        load_phi_model()
        logger.info("All models preloaded successfully")
    except Exception as e:
        logger.error(f"Error during model preloading: {str(e)}")