diff --git a/README.md b/README.md
index 2c13254cadd2424eca8efeb8a3aca455d60b85c4..bbfe7a7020020b8faf07527e0f1bf24746902dd4 100644
--- a/README.md
+++ b/README.md
@@ -1,14 +1,14 @@
 ---
 title: Audio Flamingo 2
-emoji: 🐢
-colorFrom: purple
+emoji: 🏃
+colorFrom: yellow
 colorTo: red
 sdk: gradio
 sdk_version: 5.15.0
 app_file: app.py
 pinned: false
 license: apache-2.0
-short_description: NVIDIA Audio Flamingo 2 Demo
+short_description: Audio Flamingo 2 Demo
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc3418ef74c7b6c4ad0f9f02bc839cd78c2e9395
--- /dev/null
+++ b/app.py
@@ -0,0 +1,232 @@
+import os
+import yaml
+import json
+import torch
+import spaces
+import librosa
+import argparse
+import numpy as np
+import gradio as gr
+from tqdm import tqdm
+import soundfile as sf
+from pydub import AudioSegment
+from safetensors.torch import load_file
+from huggingface_hub import snapshot_download
+
+from data.data import get_audiotext_dataloader
+from src.factory import create_model_and_transforms
+from train.train_utils import Dict2Class, get_autocast, get_cast_dtype
+
+def int16_to_float32(x):
+    return (x / 32767.0).astype(np.float32)
+
+def float32_to_int16(x):
+    x = np.clip(x, a_min=-1., a_max=1.)
+    return (x * 32767.).astype(np.int16)
+
+os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+snapshot_download(repo_id="SreyanG-NVIDIA/audio-flamingo-2", local_dir="./")
+
+config = yaml.load(open("configs/inference.yaml"), Loader=yaml.FullLoader)
+
+data_config = config['data_config']
+model_config = config['model_config']
+clap_config = config['clap_config']
+args = Dict2Class(config['train_config'])
+
+model, tokenizer = create_model_and_transforms(
+    **model_config,
+    clap_config=clap_config,
+    use_local_files=args.offline,
+    gradient_checkpointing=args.gradient_checkpointing,
+    freeze_lm_embeddings=args.freeze_lm_embeddings,
+)
+
+device_id = 0
+model = model.to(device_id)
+model.eval()
+
+# Load metadata
+with open("safe_ckpt/metadata.json", "r") as f:
+    metadata = json.load(f)
+
+# Reconstruct the full state_dict
+state_dict = {}
+
+# Load each SafeTensors chunk
+for chunk_name in metadata:
+    chunk_path = f"safe_ckpt/{chunk_name}.safetensors"
+    chunk_tensors = load_file(chunk_path)
+
+    # Merge tensors into state_dict
+    state_dict.update(chunk_tensors)
+
+x,y = model.load_state_dict(state_dict, False)
+
+autocast = get_autocast(
+    args.precision, cache_enabled=(not args.fsdp)
+)
+
+cast_dtype = get_cast_dtype(args.precision)
+
+def get_num_windows(T, sr):
+
+    window_length = int(float(clap_config["window_length"]) * sr)
+    window_overlap = int(float(clap_config["window_overlap"]) * sr)
+    max_num_window = int(clap_config["max_num_window"])
+
+    num_windows = 1
+    if T <= window_length:
+        num_windows = 1
+        full_length = window_length
+    elif T >= (max_num_window * window_length - (max_num_window - 1) * window_overlap):
+        num_windows = max_num_window
+        full_length = (max_num_window * window_length - (max_num_window - 1) * window_overlap)
+    else:
+        num_windows = 1 + int(np.ceil((T - window_length) / float(window_length - window_overlap)))
+        full_length = num_windows * window_length - (num_windows - 1) * window_overlap
+
+    return num_windows, full_length
+
+
+def read_audio(file_path, target_sr=16000, duration=30.0, start=0.0):
+
+    if file_path.endswith('.mp3'):
+        audio = AudioSegment.from_file(file_path)
+        if len(audio) > (start + duration) * 1000:
+            audio = audio[start * 1000:(start + duration) * 1000]
+
+        if audio.frame_rate != target_sr:
+            audio = audio.set_frame_rate(target_sr)
+
+        if audio.channels > 1:
+            audio = audio.set_channels(1)
+
+        data = np.array(audio.get_array_of_samples())
+        if audio.sample_width == 2:
+            data = data.astype(np.float32) / np.iinfo(np.int16).max
+        elif audio.sample_width == 4:
+            data = data.astype(np.float32) / np.iinfo(np.int32).max
+        else:
+            raise ValueError("Unsupported bit depth: {}".format(audio.sample_width))
+
+    else:
+        with sf.SoundFile(file_path) as audio:
+            original_sr = audio.samplerate
+            channels = audio.channels
+
+            max_frames = int((start + duration) * original_sr)
+
+            audio.seek(int(start * original_sr))
+            frames_to_read = min(max_frames, len(audio))
+            data = audio.read(frames_to_read)
+
+            if data.max() > 1 or data.min() < -1:
+                data = data / max(abs(data.max()), abs(data.min()))
+
+        if original_sr != target_sr:
+            if channels == 1:
+                data = librosa.resample(data.flatten(), orig_sr=original_sr, target_sr=target_sr)
+            else:
+                data = librosa.resample(data.T, orig_sr=original_sr, target_sr=target_sr)[0]
+        else:
+            if channels != 1:
+                data = data.T[0]
+
+    if data.min() >= 0:
+        data = 2 * data / abs(data.max()) - 1.0
+    else:
+        data = data / max(abs(data.max()), abs(data.min()))
+
+    assert len(data.shape) == 1, data.shape
+    return data
+
+def load_audio(audio_path):
+
+    sr = 16000
+    window_length = int(float(clap_config["window_length"]) * sr)
+    window_overlap = int(float(clap_config["window_overlap"]) * sr)
+    max_num_window = int(clap_config["max_num_window"])
+    duration = max_num_window * (clap_config["window_length"] - clap_config["window_overlap"]) + clap_config["window_overlap"]
+
+    audio_data = read_audio(audio_path, sr, duration, 0.0) # hard code audio start to 0.0
+    T = len(audio_data)
+    num_windows, full_length = get_num_windows(T, sr)
+
+    # pads to the nearest multiple of window_length
+    if full_length > T:
+        audio_data = np.append(audio_data, np.zeros(full_length - T))
+
+    audio_data = audio_data.reshape(1, -1)
+    audio_data_tensor = torch.from_numpy(int16_to_float32(float32_to_int16(audio_data))).float()
+
+    audio_clips = []
+    audio_embed_mask = torch.ones(num_windows)
+    for i in range(num_windows):
+        start = i * (window_length - window_overlap)
+        audio_data_tensor_this = audio_data_tensor[:, start:start+window_length]
+        audio_clips.append(audio_data_tensor_this)
+
+    if len(audio_clips) < max_num_window:
+        audio_clips = audio_clips[:max_num_window]
+        audio_embed_mask = audio_embed_mask[:max_num_window]
+
+    audio_clips = torch.cat(audio_clips)
+
+    return audio_clips, audio_embed_mask
+
+@spaces.GPU
+def predict(filepath, question):
+
+    audio_clips, audio_embed_mask = load_audio(filepath)
+    audio_clips = audio_clips.to(device_id, dtype=cast_dtype, non_blocking=True)
+    audio_embed_mask = audio_embed_mask.to(device_id, dtype=cast_dtype, non_blocking=True)
+
+    text_prompt = str(question).lower()
+    text_output = str(question).lower()
+
+    sample = f"