Tringles's picture
Update app.py
2dc1a61
import os
import io
import requests
import numpy as np
import gradio as gr
from PIL import Image
from loguru import logger
# URL of the model endpoint queried by riffusion(); taken from the
# MODEL_ENDPOINT environment variable. Indexing os.environ directly means a
# missing variable raises KeyError as soon as this module is imported.
riffusion_url = os.environ["MODEL_ENDPOINT"]
def bta(audio_content: bytes, sample_rate: int = 44100):
    """Convert raw audio bytes into a ``(sample_rate, samples)`` pair.

    The buffer is reinterpreted as a flat array of 16-bit signed integers in
    the platform's native byte order; no container header is parsed.

    NOTE(review): if the endpoint serves a container format (e.g. WAV), the
    header bytes would be decoded as samples too -- confirm the payload format.

    Args:
        audio_content: Raw sample bytes. The length must be a multiple of 2,
            otherwise ``np.frombuffer`` raises ``ValueError``.
        sample_rate: Sampling rate reported alongside the samples. Defaults
            to 44100, the value that was previously hard-coded.

    Returns:
        A tuple ``(sample_rate, samples)`` where ``samples`` is a read-only
        ``np.int16`` array viewing ``audio_content``.
    """
    samples = np.frombuffer(audio_content, dtype=np.int16)
    return (sample_rate, samples)
def bti(image_content: bytes):
    """Open encoded image bytes as a PIL image.

    The bytes are wrapped in an in-memory binary stream and handed to
    ``Image.open``, whose result is returned unchanged.
    """
    stream = io.BytesIO(image_content)
    return Image.open(stream)
# (connect, read) timeout in seconds for every outbound HTTP call. Without a
# timeout `requests` waits forever on a stalled server. The read timeout is
# generous because the model endpoint may take a while to respond.
_REQUEST_TIMEOUT = (10, 300)

# Number of image/audio pairs returned by riffusion().
_NUM_RESULTS = 3


def _fetch_bytes(url):
    """Download *url* and return the response body, raising on HTTP errors."""
    response = requests.get(url, timeout=_REQUEST_TIMEOUT)
    response.raise_for_status()
    return response.content


def riffusion(query):
    """Query the model endpoint and return three image/audio result pairs.

    Sends ``query`` to ``riffusion_url``, reads the ``audio_url`` and
    ``image_url`` lists from the JSON reply, downloads the first three
    entries of each and converts them with ``bta`` / ``bti``.

    Args:
        query: Text sent to the endpoint as the ``query`` request parameter.

    Returns:
        A 6-tuple ``(image0, audio0, image1, audio1, image2, audio2)``.

    Raises:
        requests.HTTPError: If the endpoint or any asset URL answers with an
            HTTP error status.
        requests.Timeout: If a request exceeds ``_REQUEST_TIMEOUT``.
        KeyError: If the reply lacks ``audio_url`` or ``image_url``.
        ValueError: If the reply lists fewer than three audio or image URLs.
    """
    response = requests.get(
        url=riffusion_url,
        params={'query': query},
        timeout=_REQUEST_TIMEOUT,
    )
    # Fail loudly on 4xx/5xx instead of trying to parse an error page as JSON.
    response.raise_for_status()
    res = response.json()
    audio_url = res['audio_url']
    image_url = res['image_url']
    logger.info(f'request query = {query} \naudio url = {audio_url} \nimage url = {image_url}')

    if len(audio_url) < _NUM_RESULTS or len(image_url) < _NUM_RESULTS:
        raise ValueError(
            f'expected at least {_NUM_RESULTS} audio and image urls, '
            f'got {len(audio_url)} audio and {len(image_url)} image'
        )

    # Only the first _NUM_RESULTS entries are returned, so do not download
    # any extras. Audio is fetched before images, as before.
    audio_content = [_fetch_bytes(a) for a in audio_url[:_NUM_RESULTS]]
    image_content = [_fetch_bytes(i) for i in image_url[:_NUM_RESULTS]]

    outputs = []
    for image_bytes, audio_bytes in zip(image_content, audio_content):
        outputs.append(bti(image_bytes))
        outputs.append(bta(audio_bytes))
    return tuple(outputs)
# Gradio UI: one text input fed to riffusion(), whose six return values fill
# three consecutive image/audio output pairs.
demo = gr.Interface(
    fn=riffusion,
    inputs="text",
    outputs=["image", "audio"] * 3,
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()