Tringles commited on
Commit
861aa19
·
1 Parent(s): 40062f2

feat: riffusion-demo

Browse files
Files changed (1) hide show
  1. main.py +35 -0
main.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import io
2
+ import requests
3
+ import numpy as np
4
+ import gradio as gr
5
+ from PIL import Image
6
+ from loguru import logger
7
+
8
+ riffusion_url = 'https://rrsowhl2evenygl2el4rm3hplm0mgelx.lambda-url.ap-northeast-2.on.aws/'
9
+
10
+ def bta(audio_content: bytes):
11
+ return (44100, np.frombuffer(audio_content, dtype=np.int16))
12
+
13
+ def bti(image_content: bytes):
14
+ return Image.open(io.BytesIO(image_content))
15
+
16
+
17
+ def riffusion(query):
18
+ res = requests.get(url=riffusion_url, params={'query': query}).json()
19
+ audio_url = res['audio_url']
20
+ image_url = res['image_url']
21
+ image_content = []
22
+ audio_content = []
23
+ logger.info(f'request query = {query} \naudio url = {audio_url} \nimage url = {image_url}')
24
+ for a in audio_url:
25
+ audio_content.append(requests.get(a).content)
26
+ for i in image_url:
27
+ image_content.append(requests.get(i).content)
28
+ return bti(image_content[0]), bta(audio_content[0]), bti(image_content[1]), bta(audio_content[1]), bti(image_content[2]), bta(audio_content[2])
29
+
30
+ demo = gr.Interface(fn=riffusion, inputs="text", outputs=[
31
+ "image", "audio", "image", "audio", "image", "audio"
32
+ ])
33
+
34
+ if __name__ == "__main__":
35
+ demo.launch(share=True)