ylacombe committed on
Commit 4ae5d02 · 1 Parent(s): e31acce

Update app.py

Files changed (1): app.py +118 -37

app.py CHANGED
@@ -1,10 +1,15 @@
+from huggingface_hub import InferenceClient
 from gradio_client import Client
 import torch
 import nltk # we'll use this to split into sentences
 import numpy as np
 from transformers import BarkModel, AutoProcessor
 nltk.download('punkt')
+
 import gradio as gr
+import os
+os.environ["GRADIO_TEMP_DIR"] = "/home/yoach/spaces/tmp"
+
 
 def _grab_best_device(use_gpu=True):
     if torch.cuda.device_count() > 0 and use_gpu:
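Note on the unchanged context above: `nltk.download('punkt')` fetches the Punkt models behind `nltk.sent_tokenize`, which the app uses further down to split the generated story into sentences before synthesis. A minimal sketch of that call in isolation (the sample story text is illustrative):

```python
import nltk

nltk.download('punkt')  # Punkt sentence-tokenizer models used by sent_tokenize

story = "Once upon a time, a panda set off on an adventure. A caterpillar came along."
sentences = nltk.sent_tokenize(story)
# ['Once upon a time, a panda set off on an adventure.', 'A caterpillar came along.']
print(sentences)
```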
@@ -15,11 +20,9 @@ def _grab_best_device(use_gpu=True):
 
 device = _grab_best_device()
 
-
-BATCH_SIZE = 8
 SYST_PROMPT="""You're the storyteller, crafting a short tale for young listeners. Please abide by these guidelines:
-- Keep your sentences concise and easy to understand.
-- There should be only the narrator speaking. No dialogues."""
+- Keep your sentences short, concise and easy to understand.
+- There should be only the narrator speaking. If there are dialogues, they should be indirect."""
 
 #story_prompt = "A panda going on an adventure with a caterpillar. This is a story teaching a wonderful life lesson."
 story_prompt = "A princess breaks free from a dragon's grip. This evokes women's empowerment and freedom."
@@ -27,10 +30,24 @@ temperature = 0.9
 top_p = 0.6
 repetition_penalty = 1.2
 
-text_client = Client("https://ysharma-explore-llamav2-with-tgi.hf.space/")
-image_client = Client("prodia/fast-stable-diffusion")
-image_negative_prompt = "ultrarealistic, soft lighting, 8k, ugly"
-image_positive_prompt = ". Cartoon, anime"
+TIMEOUT = int(os.environ.get("TIMEOUT", 45))
+
+temperature = 0.9
+top_p = 0.6
+repetition_penalty = 1.2
+
+
+
+
+# TODO: requirements: accelerate optimum
+
+text_client = InferenceClient(
+    "mistralai/Mistral-7B-Instruct-v0.1",
+    timeout=TIMEOUT,
+)
+image_client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/ffe2bn2dk/")
+image_negative_prompt = "ultrarealistic, soft lighting, 8k, ugly, text, blurry"
+image_positive_prompt = ""
 image_seed = 9
 
 processor = AutoProcessor.from_pretrained("suno/bark")
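This hunk swaps the LLaMA-2 Gradio Space for a direct `huggingface_hub.InferenceClient` on Mistral-7B-Instruct and points the image client at a fast-SDXL Space. `gradio_client.Client.submit()` (used later in this commit) is non-blocking and returns a `Job`, so the image renders while Bark synthesizes audio. A minimal sketch of that submit/result pattern; the positional arguments mirror the commit's call, and their labels (steps, guidance, width, height, seed) are my assumption about the Space's API:

```python
from gradio_client import Client

client = Client("https://openskyml-fast-sdxl-stable-diffusion-xl.hf.space/--replicas/ffe2bn2dk/")

# submit() returns immediately with a Job handle; work runs in the background
job = client.submit(
    "A princess breaks free from a dragon's grip.",           # prompt
    "ultrarealistic, soft lighting, 8k, ugly, text, blurry",  # negative prompt
    25,    # assumed: inference steps
    7,     # assumed: guidance scale
    1024,  # assumed: width
    1024,  # assumed: height
    9,     # assumed: seed
    fn_index=0,
)
# ... do other work here ...
img = job.result()  # blocks until the image is ready
```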
@@ -41,29 +58,97 @@ voice_preset = "v2/en_speaker_6"
 
 # convert to bettertransformer
 model = model.to_bettertransformer()
+BATCH_SIZE = 16
 
 # enable CPU offload
 model.enable_cpu_offload()
 
+# MISTRAL ONLY
+default_system_understand_message = (
+    "I understand, I am a Mistral chatbot."
+)
+system_understand_message = os.environ.get(
+    "SYSTEM_UNDERSTAND_MESSAGE", default_system_understand_message
+)
+
+# Mistral formatter
+def format_prompt(message):
+    prompt = (
+        "<s>[INST]" + SYST_PROMPT + "[/INST]" + system_understand_message + "</s>"
+    )
+    prompt += f"[INST] {message} [/INST]"
+    return prompt
+
+
+def generate_story(
+    story_prompt,
+    temperature=0.9,
+    max_new_tokens=1024,
+    top_p=0.95,
+    repetition_penalty=1.0,
+):
+    temperature = float(temperature)
+    if temperature < 1e-2:
+        temperature = 1e-2
+    top_p = float(top_p)
+
+    generate_kwargs = dict(
+        temperature=temperature,
+        max_new_tokens=max_new_tokens,
+        top_p=top_p,
+        repetition_penalty=repetition_penalty,
+        do_sample=True,
+        seed=42,
+    )
+
+    try:
+        output = text_client.text_generation(
+            format_prompt(story_prompt),
+            **generate_kwargs,
+            details=False,
+            return_full_text=False,
+        )
+    except Exception as e:
+        if "Too Many Requests" in str(e):
+            print("ERROR: Too many requests on mistral client")
+            gr.Warning("Unfortunately Mistral is unable to process")
+            output = "Unfortunately I am not able to process your request now, too many people are asking me!"
+        elif "Model not loaded on the server" in str(e):
+            print("ERROR: Mistral server down")
+            gr.Warning("Unfortunately Mistral LLM is unable to process")
+            output = "Unfortunately I am not able to process your request now, I have a problem with Mistral!"
+        else:
+            print("Unhandled Exception: ", str(e))
+            gr.Warning("Unfortunately Mistral is unable to process")
+            output = "I do not know what happened but I could not understand you."
+        return output
+
+    return output
 
 
 def generate_audio_and_image(story_prompt, voice_preset=voice_preset):
 
 
-    story = text_client.predict(
-        story_prompt,
-        SYST_PROMPT,
-        temperature,
-        4096,
-        temperature,
-        repetition_penalty,
-        api_name="/chat"
-    )
+    story = generate_story(story_prompt)
 
+    print(story)
 
     model_input = story.replace("\n", " ").strip()
     model_input = nltk.sent_tokenize(model_input)
 
+    print("text generated - now calling for image")
+    job_img = image_client.submit(
+        story_prompt+image_positive_prompt, # str in 'parameter_11' Textbox component
+        image_negative_prompt, # str in 'parameter_12' Textbox component
+        25,
+        7,
+        1024,
+        1024,
+        image_seed,
+        fn_index=0,
+    )
+    print("image called - now generating audio")
+
     pieces = []
     for i in range(0, len(model_input), BATCH_SIZE):
         inputs = model_input[i:min(i + BATCH_SIZE, len(model_input))]
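`format_prompt` above hand-rolls the Mistral-7B-Instruct chat template: the system prompt and a canned acknowledgement form one closed `<s>[INST]...[/INST]...</s>` turn, and the user's story prompt opens the next `[INST]` block for the model to complete. To make the wire format concrete, here is roughly what the client ends up sending (sample values, abbreviated system prompt):

```python
SYST_PROMPT = "You're the storyteller, crafting a short tale for young listeners."
system_understand_message = "I understand, I am a Mistral chatbot."

def format_prompt(message):
    # one completed system turn, then the open user turn
    prompt = "<s>[INST]" + SYST_PROMPT + "[/INST]" + system_understand_message + "</s>"
    prompt += f"[INST] {message} [/INST]"
    return prompt

print(format_prompt("A panda going on an adventure with a caterpillar."))
# <s>[INST]You're the storyteller, crafting a short tale for young listeners.[/INST]I understand, I am a Mistral chatbot.</s>[INST] A panda going on an adventure with a caterpillar. [/INST]
```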
@@ -71,35 +156,26 @@ def generate_audio_and_image(story_prompt, voice_preset=voice_preset):
         if len(inputs) != 0:
             inputs = processor(inputs, voice_preset=voice_preset)
 
-            speech_output = model.generate(**inputs.to(device)).cpu().numpy()
+            speech_output, output_lengths = model.generate(**inputs.to(device), return_output_lengths=True, min_eos_p=0.2)
 
+            speech_output = [output[:length].cpu().numpy() for (output, length) in zip(speech_output, output_lengths)]
+
+            print(f"{i}-th part generated")
             pieces += [*speech_output, silence.copy()]
 
-    #job_img = image_client.submit(
-    #    story_prompt+image_positive_prompt, # str in 'parameter_11' Textbox component
-    #    image_negative_prompt, # str in 'parameter_12' Textbox component
-    #    "absolutereality_v181.safetensors [3d9d4d2b]", # str (Option from: ['absolutereality_V16.safetensors [37db0fc3]', 'absolutereality_v181.safetensors [3d9d4d2b]', 'analog-diffusion-1.0.ckpt [9ca13f02]', 'anythingv3_0-pruned.ckpt [2700c435]', 'anything-v4.5-pruned.ckpt [65745d25]', 'anythingV5_PrtRE.safetensors [893e49b9]', 'AOM3A3_orangemixs.safetensors [9600da17]', 'childrensStories_v13D.safetensors [9dfaabcb]', 'childrensStories_v1SemiReal.safetensors [a1c56dbb]', 'childrensStories_v1ToonAnime.safetensors [2ec7b88b]', 'cyberrealistic_v33.safetensors [82b0d085]', 'deliberate_v2.safetensors [10ec4b29]', 'deliberate_v3.safetensors [afd9d2d4]', 'dreamlike-anime-1.0.safetensors [4520e090]', 'dreamlike-diffusion-1.0.safetensors [5c9fd6e0]', 'dreamlike-photoreal-2.0.safetensors [fdcf65e7]', 'dreamshaper_6BakedVae.safetensors [114c8abb]', 'dreamshaper_7.safetensors [5cf5ae06]', 'dreamshaper_8.safetensors [9d40847d]', 'edgeOfRealism_eorV20.safetensors [3ed5de15]', 'EimisAnimeDiffusion_V1.ckpt [4f828a15]', 'elldreths-vivid-mix.safetensors [342d9d26]', 'epicrealism_naturalSinRC1VAE.safetensors [90a4c676]', 'ICantBelieveItsNotPhotography_seco.safetensors [4e7a3dfd]', 'juggernaut_aftermath.safetensors [5e20c455]', 'lyriel_v16.safetensors [68fceea2]', 'mechamix_v10.safetensors [ee685731]', 'meinamix_meinaV9.safetensors [2ec66ab0]', 'meinamix_meinaV11.safetensors [b56ce717]', 'openjourney_V4.ckpt [ca2f377f]', 'portraitplus_V1.0.safetensors [1400e684]', 'Realistic_Vision_V1.4-pruned-fp16.safetensors [8d21810b]', 'Realistic_Vision_V2.0.safetensors [79587710]', 'Realistic_Vision_V4.0.safetensors [29a7afaa]', 'Realistic_Vision_V5.0.safetensors [614d1063]', 'redshift_diffusion-V10.safetensors [1400e684]', 'revAnimated_v122.safetensors [3f4fefd9]', 'rundiffusionFX25D_v10.safetensors [cd12b0ee]', 'rundiffusionFX_v10.safetensors [cd4e694d]', 'sdv1_4.ckpt [7460a6fa]', 'v1-5-pruned-emaonly.safetensors [d7049739]', 'shoninsBeautiful_v10.safetensors [25d8c546]', 'theallys-mix-ii-churned.safetensors [5d9225a4]', 'timeless-1.0.ckpt [7c4971d4]', 'toonyou_beta6.safetensors [980f6b15]'])
-    #    25,
-    #    "Euler a",
-    #    7,
-    #    512,
-    #    512,
-    #    image_seed,
-    #    "https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png,https://raw.githubusercontent.com/gradio-app/gradio/main/test/test_files/bus.png", # str (path to directory with images and a file associating images with captions called captions.json)
-    #    fn_index=0
-    #)
-
+    print("Calling image")
 
-    #img = job_img.result()
+    # TODO: if error catch it
+    img = job_img.result()
 
-    return story, (sampling_rate, np.concatenate(pieces))
+    return story, (sampling_rate, np.concatenate(pieces)), img
 
 
 
 
 # Gradio blocks demo
 with gr.Blocks() as demo_blocks:
-    gr.Markdown("""<h1 align="center">🐶Children story<</h1>""")
+    gr.Markdown("""<h1 align="center">🐶Children story</h1>""")
     gr.HTML("""<h3 style="text-align:center;">📢Audio Streaming powered by Gradio (v3.40.0 onwards)🦾! </h3>""")
     with gr.Group():
         with gr.Row():
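The audio hunk above switches `model.generate` to `return_output_lengths=True`: with a batch of sentences, Bark pads every waveform to the longest one, so each clip has to be trimmed back to its reported length before concatenation (and `min_eos_p=0.2` lets generation stop earlier, reducing trailing babble). A minimal sketch of the trim-and-stitch step, with stand-in tensors in place of real model output and a hypothetical quarter-second pause (the real app defines its own `silence`):

```python
import numpy as np
import torch

sampling_rate = 24_000  # Bark outputs 24 kHz audio

# Stand-ins for model.generate(..., return_output_lengths=True) output
speech_output = torch.randn(2, 48_000)           # padded (batch, max_len) waveforms
output_lengths = torch.tensor([36_000, 48_000])  # valid samples per clip

# Cut each padded waveform back to its true length
trimmed = [wav[:length].cpu().numpy() for wav, length in zip(speech_output, output_lengths)]

# Stitch clips with a pause, as the app's `pieces` list does
silence = np.zeros(int(0.25 * sampling_rate))  # hypothetical quarter-second gap
audio = np.concatenate([*trimmed, silence.copy()])
print(audio.shape)  # (90000,)
```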
@@ -114,11 +190,16 @@ with gr.Blocks() as demo_blocks:
 
     with gr.Row():
         btn = gr.Button("Create a story")
-
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            image_output = gr.Image(elem_id="gallery")
     with gr.Row():
         out_audio = gr.Audio(
             streaming=False, autoplay=True) # needed to stream output audio
         out_text = gr.Text()
-    btn.click(generate_audio_and_image, [inp_text], [out_text, out_audio] ) #[out_audio]) #, out_count])
+    btn.click(generate_audio_and_image, [inp_text], [out_text, out_audio, image_output] ) #[out_audio]) #, out_count])
+
+
 
 demo_blocks.queue().launch(debug=True)
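For reference, the three return values of `generate_audio_and_image` map positionally onto the `outputs` list in `btn.click`: the story string fills `out_text`, the `(sampling_rate, waveform)` tuple fills `out_audio`, and the image fills `image_output`. A minimal self-contained sketch of the same wiring, with a stand-in handler in place of the app's:

```python
import numpy as np
import gradio as gr

def tell_story(prompt):  # stand-in for generate_audio_and_image
    sampling_rate = 24_000
    waveform = np.zeros(sampling_rate)  # one second of silence
    return f"A story about {prompt}.", (sampling_rate, waveform), None  # None leaves the image empty

with gr.Blocks() as demo:
    inp_text = gr.Textbox(label="Story prompt")
    btn = gr.Button("Create a story")
    image_output = gr.Image()
    out_audio = gr.Audio(autoplay=True)
    out_text = gr.Text()
    # outputs are filled in the order the handler returns them
    btn.click(tell_story, [inp_text], [out_text, out_audio, image_output])

demo.queue().launch()
```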
 