from TTS.api import TTS
from transformers import BlipProcessor, BlipForConditionalGeneration
import torchaudio
from torchaudio.transforms import Resample
import gradio as gr
# Initialize the XTTS text-to-speech model from the Coqui TTS library
tts_model_path = "tts_models/multilingual/multi-dataset/xtts_v1"
tts = TTS(tts_model_path, gpu=True)
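# Note: gpu=True assumes a CUDA-capable device is available; set gpu=False on CPU-only hosts.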
# Initialize the BLIP model for image captioning
model_id = "dblasko/blip-dalle3-img2prompt"
blip_model = BlipForConditionalGeneration.from_pretrained(model_id)
blip_processor = BlipProcessor.from_pretrained(model_id)
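# This checkpoint is used here purely as an image-to-text captioner; the caption it emits
# is what gets passed to the TTS model below.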
def generate_caption(image):
    # Generate a caption from the image using the BLIP model
    inputs = blip_processor(images=image, return_tensors="pt")
    pixel_values = inputs.pixel_values
    # Sampling parameters belong to generate(), not batch_decode()
    generated_ids = blip_model.generate(
        pixel_values=pixel_values,
        max_length=50,
        do_sample=True,
        temperature=0.8,
        top_k=40,
        top_p=0.9,
    )
    generated_caption = blip_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
    # Use the TTS model to convert the generated caption to audio
    # (speaker_wav is a placeholder; replace it with the path to a real reference clip)
    tts.tts_to_file(text=generated_caption,
                    file_path="generated_audio.wav",
                    speaker_wav="/path/to/target/speaker.wav",
                    language="en")
    # Resample the audio to the expected 24 kHz sampling rate
    waveform, sample_rate = torchaudio.load("generated_audio.wav")
    resampler = Resample(orig_freq=sample_rate, new_freq=24_000)
    waveform_resampled = resampler(waveform)
    # Save the resampled audio
    torchaudio.save("generated_audio_resampled.wav", waveform_resampled, 24_000)
    return generated_caption, "generated_audio_resampled.wav"
# Create a Gradio interface with an image input, a caption textbox output, and an audio player;
# the two outputs match the two values returned by generate_caption
demo = gr.Interface(
    fn=generate_caption,
    inputs=gr.Image(),
    outputs=[
        gr.Textbox(label="Generated caption"),
        gr.Audio(type="filepath", label="Generated Audio"),
    ],
    live=True,
)
demo.launch(share=True)