Ffftdtd5dtft commited on
Commit
3ae460d
·
verified ·
1 Parent(s): 03c9d48

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +360 -816
app.py CHANGED
@@ -1,823 +1,367 @@
1
  import os
2
- import pickle
 
3
  import torch
4
- from PIL import Image
5
- from diffusers import (
6
- StableDiffusionPipeline,
7
- StableDiffusionImg2ImgPipeline,
8
- FluxPipeline,
9
- DiffusionPipeline,
10
- DPMSolverMultistepScheduler,
11
- )
12
  from transformers import (
13
- pipeline as transformers_pipeline,
14
- AutoModelForCausalLM,
15
- AutoTokenizer,
16
- GPT2Tokenizer,
17
- GPT2Model,
18
- AutoModel
19
  )
20
- from audiocraft.models import musicgen
 
 
 
 
 
 
21
  import gradio as gr
22
- from huggingface_hub import snapshot_download, HfApi, HfFolder
23
- import io
24
- import time
25
- from tqdm import tqdm
26
- from google.cloud import storage
27
- import json
28
-
29
- hf_token = os.getenv("HF_TOKEN")
30
- gcs_credentials = json.loads(os.getenv("GCS_CREDENTIALS"))
31
- gcs_bucket_name = os.getenv("GCS_BUCKET_NAME")
32
-
33
- HfFolder.save_token(hf_token)
34
-
35
- storage_client = storage.Client.from_service_account_info(gcs_credentials)
36
- bucket = storage_client.bucket(gcs_bucket_name)
37
-
38
-
39
- def load_object_from_gcs(blob_name):
40
- blob = bucket.blob(blob_name)
41
- if blob.exists():
42
- return pickle.loads(blob.download_as_bytes())
43
- return None
44
-
45
-
46
- def save_object_to_gcs(blob_name, obj):
47
- blob = bucket.blob(blob_name)
48
- blob.upload_from_string(pickle.dumps(obj))
49
-
50
-
51
- def get_model_or_download(model_id, blob_name, loader_func):
52
- model = load_object_from_gcs(blob_name)
53
- if model:
54
- return model
55
- try:
56
- with tqdm(total=1, desc=f"Downloading {model_id}") as pbar:
57
- model = loader_func(model_id, torch_dtype=torch.float16)
58
- pbar.update(1)
59
- save_object_to_gcs(blob_name, model)
60
- return model
61
- except Exception as e:
62
- print(f"Failed to load or save model: {e}")
63
- return None
64
-
65
-
66
- def generate_image(prompt):
67
- blob_name = f"diffusers/generated_image:{prompt}"
68
- image_bytes = load_object_from_gcs(blob_name)
69
- if not image_bytes:
70
- try:
71
- with tqdm(total=1, desc="Generating image") as pbar:
72
- image = text_to_image_pipeline(prompt).images[0]
73
- pbar.update(1)
74
- buffered = io.BytesIO()
75
- image.save(buffered, format="JPEG")
76
- image_bytes = buffered.getvalue()
77
- save_object_to_gcs(blob_name, image_bytes)
78
- except Exception as e:
79
- print(f"Failed to generate image: {e}")
80
- return None
81
- return image_bytes
82
-
83
-
84
- def edit_image_with_prompt(image_bytes, prompt, strength=0.75):
85
- blob_name = f"diffusers/edited_image:{prompt}:{strength}"
86
- edited_image_bytes = load_object_from_gcs(blob_name)
87
- if not edited_image_bytes:
88
- try:
89
- image = Image.open(io.BytesIO(image_bytes))
90
- with tqdm(total=1, desc="Editing image") as pbar:
91
- edited_image = img2img_pipeline(
92
- prompt=prompt, image=image, strength=strength
93
- ).images[0]
94
- pbar.update(1)
95
- buffered = io.BytesIO()
96
- edited_image.save(buffered, format="JPEG")
97
- edited_image_bytes = buffered.getvalue()
98
- save_object_to_gcs(blob_name, edited_image_bytes)
99
- except Exception as e:
100
- print(f"Failed to edit image: {e}")
101
- return None
102
- return edited_image_bytes
103
-
104
-
105
- def generate_song(prompt, duration=10):
106
- blob_name = f"music/generated_song:{prompt}:{duration}"
107
- song_bytes = load_object_from_gcs(blob_name)
108
- if not song_bytes:
109
- try:
110
- with tqdm(total=1, desc="Generating song") as pbar:
111
- song = music_gen(prompt, duration=duration)
112
- pbar.update(1)
113
- song_bytes = song[0].getvalue()
114
- save_object_to_gcs(blob_name, song_bytes)
115
- except Exception as e:
116
- print(f"Failed to generate song: {e}")
117
- return None
118
- return song_bytes
119
-
120
-
121
- def generate_text(prompt):
122
- blob_name = f"transformers/generated_text:{prompt}"
123
- text = load_object_from_gcs(blob_name)
124
- if not text:
125
- try:
126
- with tqdm(total=1, desc="Generating text") as pbar:
127
- text = text_gen_pipeline(prompt, max_new_tokens=256)[0][
128
- "generated_text"
129
- ].strip()
130
- pbar.update(1)
131
- save_object_to_gcs(blob_name, text)
132
- except Exception as e:
133
- print(f"Failed to generate text: {e}")
134
- return None
135
- return text
136
-
137
-
138
- def generate_flux_image(prompt):
139
- blob_name = f"diffusers/generated_flux_image:{prompt}"
140
- flux_image_bytes = load_object_from_gcs(blob_name)
141
- if not flux_image_bytes:
142
- try:
143
- with tqdm(total=1, desc="Generating FLUX image") as pbar:
144
- flux_image = flux_pipeline(
145
- prompt,
146
- guidance_scale=0.0,
147
- num_inference_steps=4,
148
- max_length=256,
149
- generator=torch.Generator("cpu").manual_seed(0),
150
- ).images[0]
151
- pbar.update(1)
152
- buffered = io.BytesIO()
153
- flux_image.save(buffered, format="JPEG")
154
- flux_image_bytes = buffered.getvalue()
155
- save_object_to_gcs(blob_name, flux_image_bytes)
156
- except Exception as e:
157
- print(f"Failed to generate flux image: {e}")
158
- return None
159
- return flux_image_bytes
160
-
161
-
162
- def generate_code(prompt):
163
- blob_name = f"transformers/generated_code:{prompt}"
164
- code = load_object_from_gcs(blob_name)
165
- if not code:
166
- try:
167
- with tqdm(total=1, desc="Generating code") as pbar:
168
- inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt")
169
- outputs = starcoder_model.generate(inputs, max_new_tokens=256)
170
- code = starcoder_tokenizer.decode(outputs[0])
171
- pbar.update(1)
172
- save_object_to_gcs(blob_name, code)
173
- except Exception as e:
174
- print(f"Failed to generate code: {e}")
175
- return None
176
- return code
177
-
178
-
179
- def test_model_meta_llama():
180
- blob_name = "transformers/meta_llama_test_response"
181
- response = load_object_from_gcs(blob_name)
182
- if not response:
183
- try:
184
- messages = [
185
- {
186
- "role": "system",
187
- "content": "You are a pirate chatbot who always responds in pirate speak!",
188
- },
189
- {"role": "user", "content": "Who are you?"},
190
- ]
191
- with tqdm(total=1, desc="Testing Meta-Llama") as pbar:
192
- response = meta_llama_pipeline(messages, max_new_tokens=256)[0][
193
- "generated_text"
194
- ].strip()
195
- pbar.update(1)
196
- save_object_to_gcs(blob_name, response)
197
- except Exception as e:
198
- print(f"Failed to test Meta-Llama: {e}")
199
- return None
200
- return response
201
-
202
-
203
- def generate_image_sdxl(prompt):
204
- blob_name = f"diffusers/generated_image_sdxl:{prompt}"
205
- image_bytes = load_object_from_gcs(blob_name)
206
- if not image_bytes:
207
- try:
208
- with tqdm(total=1, desc="Generating SDXL image") as pbar:
209
- image = base(
210
- prompt=prompt,
211
- num_inference_steps=40,
212
- denoising_end=0.8,
213
- output_type="latent",
214
- ).images
215
- image = refiner(
216
- prompt=prompt,
217
- num_inference_steps=40,
218
- denoising_start=0.8,
219
- image=image,
220
- ).images[0]
221
- pbar.update(1)
222
- buffered = io.BytesIO()
223
- image.save(buffered, format="JPEG")
224
- image_bytes = buffered.getvalue()
225
- save_object_to_gcs(blob_name, image_bytes)
226
- except Exception as e:
227
- print(f"Failed to generate SDXL image: {e}")
228
- return None
229
- return image_bytes
230
-
231
-
232
- def generate_musicgen_melody(prompt):
233
- blob_name = f"music/generated_musicgen_melody:{prompt}"
234
- song_bytes = load_object_from_gcs(blob_name)
235
- if not song_bytes:
236
- try:
237
- with tqdm(total=1, desc="Generating MusicGen melody") as pbar:
238
- melody, sr = torchaudio.load("./assets/bach.mp3")
239
- wav = music_gen_melody.generate_with_chroma(
240
- [prompt], melody[None].expand(3, -1, -1), sr
241
- )
242
- pbar.update(1)
243
- song_bytes = wav[0].getvalue()
244
- save_object_to_gcs(blob_name, song_bytes)
245
- except Exception as e:
246
- print(f"Failed to generate MusicGen melody: {e}")
247
- return None
248
- return song_bytes
249
-
250
-
251
- def generate_musicgen_large(prompt):
252
- blob_name = f"music/generated_musicgen_large:{prompt}"
253
- song_bytes = load_object_from_gcs(blob_name)
254
- if not song_bytes:
255
- try:
256
- with tqdm(total=1, desc="Generating MusicGen large") as pbar:
257
- wav = music_gen_large.generate([prompt])
258
- pbar.update(1)
259
- song_bytes = wav[0].getvalue()
260
- save_object_to_gcs(blob_name, song_bytes)
261
- except Exception as e:
262
- print(f"Failed to generate MusicGen large: {e}")
263
- return None
264
- return song_bytes
265
-
266
-
267
- def transcribe_audio(audio_sample):
268
- blob_name = f"transformers/transcribed_audio:{hash(audio_sample.tobytes())}"
269
- text = load_object_from_gcs(blob_name)
270
- if not text:
271
- try:
272
- with tqdm(total=1, desc="Transcribing audio") as pbar:
273
- text = whisper_pipeline(audio_sample.copy(), batch_size=8)["text"]
274
- pbar.update(1)
275
- save_object_to_gcs(blob_name, text)
276
- except Exception as e:
277
- print(f"Failed to transcribe audio: {e}")
278
- return None
279
- return text
280
-
281
-
282
- def generate_mistral_instruct(prompt):
283
- blob_name = f"transformers/generated_mistral_instruct:{prompt}"
284
- response = load_object_from_gcs(blob_name)
285
- if not response:
286
- try:
287
- conversation = [{"role": "user", "content": prompt}]
288
- with tqdm(total=1, desc="Generating Mistral Instruct response") as pbar:
289
- inputs = mistral_instruct_tokenizer.apply_chat_template(
290
- conversation,
291
- tools=tools,
292
- add_generation_prompt=True,
293
- return_dict=True,
294
- return_tensors="pt",
295
- )
296
- outputs = mistral_instruct_model.generate(
297
- **inputs, max_new_tokens=1000
298
- )
299
- response = mistral_instruct_tokenizer.decode(
300
- outputs[0], skip_special_tokens=True
301
- )
302
- pbar.update(1)
303
- save_object_to_gcs(blob_name, response)
304
- except Exception as e:
305
- print(f"Failed to generate Mistral Instruct response: {e}")
306
- return None
307
- return response
308
-
309
-
310
- def generate_mistral_nemo(prompt):
311
- blob_name = f"transformers/generated_mistral_nemo:{prompt}"
312
- response = load_object_from_gcs(blob_name)
313
- if not response:
314
- try:
315
- conversation = [{"role": "user", "content": prompt}]
316
- with tqdm(total=1, desc="Generating Mistral Nemo response") as pbar:
317
- inputs = mistral_nemo_tokenizer.apply_chat_template(
318
- conversation,
319
- tools=tools,
320
- add_generation_prompt=True,
321
- return_dict=True,
322
- return_tensors="pt",
323
- )
324
- outputs = mistral_nemo_model.generate(**inputs, max_new_tokens=1000)
325
- response = mistral_nemo_tokenizer.decode(
326
- outputs[0], skip_special_tokens=True
327
- )
328
- pbar.update(1)
329
- save_object_to_gcs(blob_name, response)
330
- except Exception as e:
331
- print(f"Failed to generate Mistral Nemo response: {e}")
332
- return None
333
- return response
334
-
335
-
336
- def generate_gpt2_xl(prompt):
337
- blob_name = f"transformers/generated_gpt2_xl:{prompt}"
338
- response = load_object_from_gcs(blob_name)
339
- if not response:
340
- try:
341
- with tqdm(total=1, desc="Generating GPT-2 XL response") as pbar:
342
- inputs = gpt2_xl_tokenizer(prompt, return_tensors="pt")
343
- outputs = gpt2_xl_model(**inputs)
344
- response = gpt2_xl_tokenizer.decode(
345
- outputs[0][0], skip_special_tokens=True
346
- )
347
- pbar.update(1)
348
- save_object_to_gcs(blob_name, response)
349
- except Exception as e:
350
- print(f"Failed to generate GPT-2 XL response: {e}")
351
- return None
352
  return response
353
 
354
-
355
- def store_user_question(question):
356
- blob_name = "user_questions.txt"
357
- blob = bucket.blob(blob_name)
358
- if blob.exists():
359
- blob.download_to_filename("user_questions.txt")
360
- with open("user_questions.txt", "a") as f:
361
- f.write(question + "\n")
362
- blob.upload_from_filename("user_questions.txt")
363
-
364
-
365
- def retrain_models():
366
- pass
367
-
368
-
369
- def generate_text_to_video_ms_1_7b(prompt, num_frames=200):
370
- blob_name = f"diffusers/text_to_video_ms_1_7b:{prompt}:{num_frames}"
371
- video_bytes = load_object_from_gcs(blob_name)
372
- if not video_bytes:
373
- try:
374
- with tqdm(total=1, desc="Generating video") as pbar:
375
- video_frames = text_to_video_ms_1_7b_pipeline(
376
- prompt, num_inference_steps=25, num_frames=num_frames
377
- ).frames
378
- pbar.update(1)
379
- video_path = export_to_video(video_frames)
380
- with open(video_path, "rb") as f:
381
- video_bytes = f.read()
382
- save_object_to_gcs(blob_name, video_bytes)
383
- os.remove(video_path)
384
- except Exception as e:
385
- print(f"Failed to generate video: {e}")
386
- return None
387
- return video_bytes
388
-
389
-
390
- def generate_text_to_video_ms_1_7b_short(prompt):
391
- blob_name = f"diffusers/text_to_video_ms_1_7b_short:{prompt}"
392
- video_bytes = load_object_from_gcs(blob_name)
393
- if not video_bytes:
394
- try:
395
- with tqdm(total=1, desc="Generating short video") as pbar:
396
- video_frames = text_to_video_ms_1_7b_short_pipeline(
397
- prompt, num_inference_steps=25
398
- ).frames
399
- pbar.update(1)
400
- video_path = export_to_video(video_frames)
401
- with open(video_path, "rb") as f:
402
- video_bytes = f.read()
403
- save_object_to_gcs(blob_name, video_bytes)
404
- os.remove(video_path)
405
- except Exception as e:
406
- print(f"Failed to generate short video: {e}")
407
- return None
408
- return video_bytes
409
-
410
-
411
- text_to_image_pipeline = get_model_or_download(
412
- "stabilityai/stable-diffusion-2",
413
- "diffusers/text_to_image_model",
414
- StableDiffusionPipeline.from_pretrained,
415
- )
416
- img2img_pipeline = get_model_or_download(
417
- "CompVis/stable-diffusion-v1-4",
418
- "diffusers/img2img_model",
419
- StableDiffusionImg2ImgPipeline.from_pretrained,
420
- )
421
- flux_pipeline = get_model_or_download(
422
- "black-forest-labs/FLUX.1-schnell",
423
- "diffusers/flux_model",
424
- FluxPipeline.from_pretrained,
425
- )
426
- text_gen_pipeline = transformers_pipeline(
427
- "text-generation", model="google/gemma-2-9b", tokenizer="google/gemma-2-9b"
428
- )
429
- music_gen = (
430
- load_object_from_gcs("music/music_gen")
431
- or musicgen.MusicGen.get_pretrained("melody")
432
- )
433
- meta_llama_pipeline = get_model_or_download(
434
- "meta-llama/Meta-Llama-3.1-8B-Instruct",
435
- "transformers/meta_llama_model",
436
- transformers_pipeline,
437
- )
438
- starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder")
439
- starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder")
440
-
441
- base = DiffusionPipeline.from_pretrained(
442
- "stabilityai/stable-diffusion-xl-base-1.0",
443
- torch_dtype=torch.float16,
444
- variant="fp16",
445
- use_safetensors=True,
446
- )
447
- refiner = DiffusionPipeline.from_pretrained(
448
- "stabilityai/stable-diffusion-xl-refiner-1.0",
449
- text_encoder_2=base.text_encoder_2,
450
- vae=base.vae,
451
- torch_dtype=torch.float16,
452
- use_safetensors=True,
453
- variant="fp16",
454
- )
455
- music_gen_melody = musicgen.MusicGen.get_pretrained("melody")
456
- music_gen_melody.set_generation_params(duration=8)
457
- music_gen_large = musicgen.MusicGen.get_pretrained("large")
458
- music_gen_large.set_generation_params(duration=8)
459
- whisper_pipeline = transformers_pipeline(
460
- "automatic-speech-recognition",
461
- model="openai/whisper-small",
462
- chunk_length_s=30,
463
- )
464
- mistral_instruct_model = AutoModelForCausalLM.from_pretrained(
465
- "mistralai/Mistral-Large-Instruct-2407",
466
- torch_dtype=torch.bfloat16,
467
- device_map="auto",
468
- )
469
- mistral_instruct_tokenizer = AutoTokenizer.from_pretrained(
470
- "mistralai/Mistral-Large-Instruct-2407"
471
- )
472
- mistral_nemo_model = AutoModelForCausalLM.from_pretrained(
473
- "mistralai/Mistral-Nemo-Instruct-2407",
474
- torch_dtype=torch.bfloat16,
475
- device_map="auto",
476
- )
477
- mistral_nemo_tokenizer = AutoTokenizer.from_pretrained(
478
- "mistralai/Mistral-Nemo-Instruct-2407"
479
- )
480
- gpt2_xl_tokenizer = GPT2Tokenizer.from_pretrained("gpt2-xl")
481
- gpt2_xl_model = GPT2Model.from_pretrained("gpt2-xl")
482
-
483
- llama_3_groq_70b_tool_use_pipeline = transformers_pipeline(
484
- "text-generation", model="Groq/Llama-3-Groq-70B-Tool-Use"
485
- )
486
- phi_3_5_mini_instruct_model = AutoModelForCausalLM.from_pretrained(
487
- "microsoft/Phi-3.5-mini-instruct", torch_dtype="auto", trust_remote_code=True
488
- )
489
- phi_3_5_mini_instruct_tokenizer = AutoTokenizer.from_pretrained(
490
- "microsoft/Phi-3.5-mini-instruct"
491
- )
492
- phi_3_5_mini_instruct_pipeline = transformers_pipeline(
493
- "text-generation",
494
- model=phi_3_5_mini_instruct_model,
495
- tokenizer=phi_3_5_mini_instruct_tokenizer,
496
- )
497
- meta_llama_3_1_8b_pipeline = transformers_pipeline(
498
- "text-generation",
499
- model="meta-llama/Meta-Llama-3.1-8B",
500
- model_kwargs={"torch_dtype": torch.bfloat16},
501
- )
502
- meta_llama_3_1_70b_pipeline = transformers_pipeline(
503
- "text-generation",
504
- model="meta-llama/Meta-Llama-3.1-70B",
505
- model_kwargs={"torch_dtype": torch.bfloat16},
506
- )
507
- medical_text_summarization_pipeline = transformers_pipeline(
508
- "summarization", model="your/medical_text_summarization_model"
509
- )
510
- bart_large_cnn_summarization_pipeline = transformers_pipeline(
511
- "summarization", model="facebook/bart-large-cnn"
512
- )
513
- flux_1_dev_pipeline = FluxPipeline.from_pretrained(
514
- "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
515
- )
516
- flux_1_dev_pipeline.enable_model_cpu_offload()
517
- gemma_2_9b_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-9b")
518
- gemma_2_9b_it_pipeline = transformers_pipeline(
519
- "text-generation",
520
- model="google/gemma-2-9b-it",
521
- model_kwargs={"torch_dtype": torch.bfloat16},
522
- )
523
- gemma_2_2b_pipeline = transformers_pipeline("text-generation", model="google/gemma-2-2b")
524
- gemma_2_2b_it_pipeline = transformers_pipeline(
525
- "text-generation",
526
- model="google/gemma-2-2b-it",
527
- model_kwargs={"torch_dtype": torch.bfloat16},
528
- )
529
- gemma_2_27b_tokenizer = AutoTokenizer.from_pretrained("google/gemma-2-27b")
530
- gemma_2_27b_model = AutoModelForCausalLM.from_pretrained("google/gemma-2-27b")
531
- gemma_2_27b_it_pipeline = transformers_pipeline(
532
- "text-generation",
533
- model="google/gemma-2-27b-it",
534
- model_kwargs={"torch_dtype": torch.bfloat16},
535
- )
536
- text_to_video_ms_1_7b_pipeline = DiffusionPipeline.from_pretrained(
537
- "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
538
- )
539
- text_to_video_ms_1_7b_pipeline.scheduler = DPMSolverMultistepScheduler.from_config(
540
- text_to_video_ms_1_7b_pipeline.scheduler.config
541
- )
542
- text_to_video_ms_1_7b_pipeline.enable_model_cpu_offload()
543
- text_to_video_ms_1_7b_pipeline.enable_vae_slicing()
544
- text_to_video_ms_1_7b_short_pipeline = DiffusionPipeline.from_pretrained(
545
- "damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16"
546
- )
547
- text_to_video_ms_1_7b_short_pipeline.scheduler = (
548
- DPMSolverMultistepScheduler.from_config(
549
- text_to_video_ms_1_7b_short_pipeline.scheduler.config
550
- )
551
- )
552
- text_to_video_ms_1_7b_short_pipeline.enable_model_cpu_offload()
553
-
554
- tools = []
555
-
556
- gen_image_tab = gr.Interface(
557
- fn=generate_image,
558
- inputs=gr.Textbox(label="Prompt:"),
559
- outputs=gr.Image(type="pil"),
560
- title="Generate Image",
561
- )
562
- edit_image_tab = gr.Interface(
563
- fn=edit_image_with_prompt,
564
- inputs=[
565
- gr.Image(type="pil", label="Image:"),
566
- gr.Textbox(label="Prompt:"),
567
- gr.Slider(0.1, 1.0, 0.75, step=0.05, label="Strength:"),
568
- ],
569
- outputs=gr.Image(type="pil"),
570
- title="Edit Image",
571
- )
572
- generate_song_tab = gr.Interface(
573
- fn=generate_song,
574
- inputs=[
575
- gr.Textbox(label="Prompt:"),
576
- gr.Slider(5, 60, 10, step=1, label="Duration (s):"),
577
- ],
578
- outputs=gr.Audio(type="numpy"),
579
- title="Generate Songs",
580
- )
581
- generate_text_tab = gr.Interface(
582
- fn=generate_text,
583
- inputs=gr.Textbox(label="Prompt:"),
584
- outputs=gr.Textbox(label="Generated Text:"),
585
- title="Generate Text",
586
- )
587
- generate_flux_image_tab = gr.Interface(
588
- fn=generate_flux_image,
589
- inputs=gr.Textbox(label="Prompt:"),
590
- outputs=gr.Image(type="pil"),
591
- title="Generate FLUX Images",
592
- )
593
- generate_code_tab = gr.Interface(
594
- fn=generate_code,
595
- inputs=gr.Textbox(label="Prompt:"),
596
- outputs=gr.Textbox(label="Generated Code:"),
597
- title="Generate Code",
598
- )
599
- model_meta_llama_test_tab = gr.Interface(
600
- fn=test_model_meta_llama,
601
- inputs=None,
602
- outputs=gr.Textbox(label="Model Output:"),
603
- title="Test Meta-Llama",
604
- )
605
- generate_image_sdxl_tab = gr.Interface(
606
- fn=generate_image_sdxl,
607
- inputs=gr.Textbox(label="Prompt:"),
608
- outputs=gr.Image(type="pil"),
609
- title="Generate SDXL Image",
610
- )
611
- generate_musicgen_melody_tab = gr.Interface(
612
- fn=generate_musicgen_melody,
613
- inputs=gr.Textbox(label="Prompt:"),
614
- outputs=gr.Audio(type="numpy"),
615
- title="Generate MusicGen Melody",
616
- )
617
- generate_musicgen_large_tab = gr.Interface(
618
- fn=generate_musicgen_large,
619
- inputs=gr.Textbox(label="Prompt:"),
620
- outputs=gr.Audio(type="numpy"),
621
- title="Generate MusicGen Large",
622
- )
623
- transcribe_audio_tab = gr.Interface(
624
- fn=transcribe_audio,
625
- inputs=gr.Audio(type="numpy", label="Audio Sample:"),
626
- outputs=gr.Textbox(label="Transcribed Text:"),
627
- title="Transcribe Audio",
628
- )
629
- generate_mistral_instruct_tab = gr.Interface(
630
- fn=generate_mistral_instruct,
631
- inputs=gr.Textbox(label="Prompt:"),
632
- outputs=gr.Textbox(label="Mistral Instruct Response:"),
633
- title="Generate Mistral Instruct Response",
634
- )
635
- generate_mistral_nemo_tab = gr.Interface(
636
- fn=generate_mistral_nemo,
637
- inputs=gr.Textbox(label="Prompt:"),
638
- outputs=gr.Textbox(label="Mistral Nemo Response:"),
639
- title="Generate Mistral Nemo Response",
640
- )
641
- generate_gpt2_xl_tab = gr.Interface(
642
- fn=generate_gpt2_xl,
643
- inputs=gr.Textbox(label="Prompt:"),
644
- outputs=gr.Textbox(label="GPT-2 XL Response:"),
645
- title="Generate GPT-2 XL Response",
646
- )
647
- answer_question_minicpm_tab = gr.Interface(
648
- fn=answer_question_minicpm,
649
- inputs=[
650
- gr.Image(type="pil", label="Image:"),
651
- gr.Textbox(label="Question:"),
652
- ],
653
- outputs=gr.Textbox(label="MiniCPM Answer:"),
654
- title="Answer Question with MiniCPM",
655
- )
656
- llama_3_groq_70b_tool_use_tab = gr.Interface(
657
- fn=llama_3_groq_70b_tool_use_pipeline,
658
- inputs=[gr.Textbox(label="Prompt:")],
659
- outputs=gr.Textbox(label="Llama 3 Groq 70B Tool Use Response:"),
660
- title="Llama 3 Groq 70B Tool Use",
661
- )
662
- phi_3_5_mini_instruct_tab = gr.Interface(
663
- fn=phi_3_5_mini_instruct_pipeline,
664
- inputs=[gr.Textbox(label="Prompt:")],
665
- outputs=gr.Textbox(label="Phi 3.5 Mini Instruct Response:"),
666
- title="Phi 3.5 Mini Instruct",
667
- )
668
- meta_llama_3_1_8b_tab = gr.Interface(
669
- fn=meta_llama_3_1_8b_pipeline,
670
- inputs=[gr.Textbox(label="Prompt:")],
671
- outputs=gr.Textbox(label="Meta Llama 3.1 8B Response:"),
672
- title="Meta Llama 3.1 8B",
673
- )
674
- meta_llama_3_1_70b_tab = gr.Interface(
675
- fn=meta_llama_3_1_70b_pipeline,
676
- inputs=[gr.Textbox(label="Prompt:")],
677
- outputs=gr.Textbox(label="Meta Llama 3.1 70B Response:"),
678
- title="Meta Llama 3.1 70B",
679
- )
680
- medical_text_summarization_tab = gr.Interface(
681
- fn=medical_text_summarization_pipeline,
682
- inputs=[gr.Textbox(label="Medical Document:")],
683
- outputs=gr.Textbox(label="Medical Text Summarization:"),
684
- title="Medical Text Summarization",
685
- )
686
- bart_large_cnn_summarization_tab = gr.Interface(
687
- fn=bart_large_cnn_summarization_pipeline,
688
- inputs=[gr.Textbox(label="Article:")],
689
- outputs=gr.Textbox(label="Bart Large CNN Summarization:"),
690
- title="Bart Large CNN Summarization",
691
- )
692
- flux_1_dev_tab = gr.Interface(
693
- fn=flux_1_dev_pipeline,
694
- inputs=[gr.Textbox(label="Prompt:")],
695
- outputs=gr.Image(type="pil"),
696
- title="FLUX 1 Dev",
697
- )
698
- gemma_2_9b_tab = gr.Interface(
699
- fn=gemma_2_9b_pipeline,
700
- inputs=[gr.Textbox(label="Prompt:")],
701
- outputs=gr.Textbox(label="Gemma 2 9B Response:"),
702
- title="Gemma 2 9B",
703
- )
704
- gemma_2_9b_it_tab = gr.Interface(
705
- fn=gemma_2_9b_it_pipeline,
706
- inputs=[gr.Textbox(label="Prompt:")],
707
- outputs=gr.Textbox(label="Gemma 2 9B IT Response:"),
708
- title="Gemma 2 9B IT",
709
- )
710
- gemma_2_2b_tab = gr.Interface(
711
- fn=gemma_2_2b_pipeline,
712
- inputs=[gr.Textbox(label="Prompt:")],
713
- outputs=gr.Textbox(label="Gemma 2 2B Response:"),
714
- title="Gemma 2 2B",
715
- )
716
- gemma_2_2b_it_tab = gr.Interface(
717
- fn=gemma_2_2b_it_pipeline,
718
- inputs=[gr.Textbox(label="Prompt:")],
719
- outputs=gr.Textbox(label="Gemma 2 2B IT Response:"),
720
- title="Gemma 2 2B IT",
721
- )
722
-
723
-
724
- def generate_gemma_2_27b(prompt):
725
- input_ids = gemma_2_27b_tokenizer(prompt, return_tensors="pt")
726
- outputs = gemma_2_27b_model.generate(**input_ids, max_new_tokens=32)
727
- return gemma_2_27b_tokenizer.decode(outputs[0])
728
-
729
-
730
- gemma_2_27b_tab = gr.Interface(
731
- fn=generate_gemma_2_27b,
732
- inputs=[gr.Textbox(label="Prompt:")],
733
- outputs=gr.Textbox(label="Gemma 2 27B Response:"),
734
- title="Gemma 2 27B",
735
- )
736
- gemma_2_27b_it_tab = gr.Interface(
737
- fn=gemma_2_27b_it_pipeline,
738
- inputs=[gr.Textbox(label="Prompt:")],
739
- outputs=gr.Textbox(label="Gemma 2 27B IT Response:"),
740
- title="Gemma 2 27B IT",
741
- )
742
- text_to_video_ms_1_7b_tab = gr.Interface(
743
- fn=generate_text_to_video_ms_1_7b,
744
- inputs=[
745
- gr.Textbox(label="Prompt:"),
746
- gr.Slider(50, 200, 200, step=1, label="Number of Frames:"),
747
- ],
748
- outputs=gr.Video(),
749
- title="Text to Video MS 1.7B",
750
- )
751
- text_to_video_ms_1_7b_short_tab = gr.Interface(
752
- fn=generate_text_to_video_ms_1_7b_short,
753
- inputs=[gr.Textbox(label="Prompt:")],
754
- outputs=gr.Video(),
755
- title="Text to Video MS 1.7B Short",
756
- )
757
-
758
- app = gr.TabbedInterface(
759
- [
760
- gen_image_tab,
761
- edit_image_tab,
762
- generate_song_tab,
763
- generate_text_tab,
764
- generate_flux_image_tab,
765
- generate_code_tab,
766
- model_meta_llama_test_tab,
767
- generate_image_sdxl_tab,
768
- generate_musicgen_melody_tab,
769
- generate_musicgen_large_tab,
770
- transcribe_audio_tab,
771
- generate_mistral_instruct_tab,
772
- generate_mistral_nemo_tab,
773
- generate_gpt2_xl_tab,
774
- llama_3_groq_70b_tool_use_tab,
775
- phi_3_5_mini_instruct_tab,
776
- meta_llama_3_1_8b_tab,
777
- meta_llama_3_1_70b_tab,
778
- medical_text_summarization_tab,
779
- bart_large_cnn_summarization_tab,
780
- flux_1_dev_tab,
781
- gemma_2_9b_tab,
782
- gemma_2_9b_it_tab,
783
- gemma_2_2b_tab,
784
- gemma_2_2b_it_tab,
785
- gemma_2_27b_tab,
786
- gemma_2_27b_it_tab,
787
- text_to_video_ms_1_7b_tab,
788
- text_to_video_ms_1_7b_short_tab,
789
- ],
790
- [
791
- "Generate Image",
792
- "Edit Image",
793
- "Generate Song",
794
- "Generate Text",
795
- "Generate FLUX Image",
796
- "Generate Code",
797
- "Test Meta-Llama",
798
- "Generate SDXL Image",
799
- "Generate MusicGen Melody",
800
- "Generate MusicGen Large",
801
- "Transcribe Audio",
802
- "Generate Mistral Instruct Response",
803
- "Generate Mistral Nemo Response",
804
- "Generate GPT-2 XL Response",
805
- "Llama 3 Groq 70B Tool Use",
806
- "Phi 3.5 Mini Instruct",
807
- "Meta Llama 3.1 8B",
808
- "Meta Llama 3.1 70B",
809
- "Medical Text Summarization",
810
- "Bart Large CNN Summarization",
811
- "FLUX 1 Dev",
812
- "Gemma 2 9B",
813
- "Gemma 2 9B IT",
814
- "Gemma 2 2B",
815
- "Gemma 2 2B IT",
816
- "Gemma 2 27B",
817
- "Gemma 2 27B IT",
818
- "Text to Video MS 1.7B",
819
- "Text to Video MS 1.7B Short",
820
- ],
821
- )
822
-
823
- app.launch(share=True)
 
1
  import os
2
+ import uuid
3
+ import redis
4
  import torch
5
+ import scipy
 
 
 
 
 
 
 
6
  from transformers import (
7
+ pipeline, AutoTokenizer, AutoModelForCausalLM, AutoProcessor,
8
+ MusicgenForConditionalGeneration, WhisperProcessor, WhisperForConditionalGeneration,
9
+ MarianMTModel, MarianTokenizer, BartTokenizer, BartForConditionalGeneration
 
 
 
10
  )
11
+ from diffusers import (
12
+ FluxPipeline, StableDiffusionPipeline, DPMSolverMultistepScheduler,
13
+ StableDiffusionImg2ImgPipeline, DiffusionPipeline
14
+ )
15
+ from diffusers.utils import export_to_video
16
+ from datasets import load_dataset
17
+ from PIL import Image
18
  import gradio as gr
19
+ from dotenv import load_dotenv
20
+ import multiprocessing
21
+
22
+ load_dotenv()
23
+
24
+ redis_client = redis.Redis(
25
+ host=os.getenv('REDIS_HOST'),
26
+ port=os.getenv('REDIS_PORT'),
27
+ redis_password=os.getenv("REDIS_PASSWORD")
28
+ )
29
+
30
+ huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
31
+
32
+ def generate_unique_id():
33
+ return str(uuid.uuid4())
34
+
35
+ def store_special_tokens(tokenizer, model_name):
36
+ special_tokens = {
37
+ 'pad_token': tokenizer.pad_token,
38
+ 'pad_token_id': tokenizer.pad_token_id,
39
+ 'eos_token': tokenizer.eos_token,
40
+ 'eos_token_id': tokenizer.eos_token_id,
41
+ 'unk_token': tokenizer.unk_token,
42
+ 'unk_token_id': tokenizer.unk_token_id,
43
+ 'bos_token': tokenizer.bos_token,
44
+ 'bos_token_id': tokenizer.bos_token_id
45
+ }
46
+ redis_client.hmset(f"tokenizer_special_tokens:{model_name}", special_tokens)
47
+
48
+ def load_special_tokens(tokenizer, model_name):
49
+ special_tokens = redis_client.hgetall(f"tokenizer_special_tokens:{model_name}")
50
+ if special_tokens:
51
+ tokenizer.pad_token = special_tokens.get('pad_token')
52
+ tokenizer.pad_token_id = int(special_tokens.get('pad_token_id', -1))
53
+ tokenizer.eos_token = special_tokens.get('eos_token')
54
+ tokenizer.eos_token_id = int(special_tokens.get('eos_token_id', -1))
55
+ tokenizer.unk_token = special_tokens.get('unk_token')
56
+ tokenizer.unk_token_id = int(special_tokens.get('unk_token_id', -1))
57
+ tokenizer.bos_token = special_tokens.get('bos_token')
58
+ tokenizer.bos_token_id = int(special_tokens.get('bos_token_id', -1))
59
+
60
+ def train_and_store_transformers_model(model_name, data):
61
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
62
+ model = AutoModelForCausalLM.from_pretrained(model_name)
63
+ model.train()
64
+ store_special_tokens(tokenizer, model_name)
65
+ torch.save(model.state_dict(), "transformers_model.pt")
66
+ with open("transformers_model.pt", "rb") as f:
67
+ model_data = f.read()
68
+ redis_client.set(f"transformers_model:{model_name}:state_dict", model_data)
69
+ tokenizer_data = tokenizer.save_pretrained("transformers_tokenizer")
70
+ redis_client.set(f"transformers_tokenizer:{model_name}", tokenizer_data)
71
+
72
+ def generate_transformers_response_from_redis(model_name, prompt):
73
+ unique_id = generate_unique_id()
74
+ model_data = redis_client.get(f"transformers_model:{model_name}:state_dict")
75
+ with open("transformers_model.pt", "wb") as f:
76
+ f.write(model_data)
77
+ model = AutoModelForCausalLM.from_pretrained(model_name)
78
+ model.load_state_dict(torch.load("transformers_model.pt"))
79
+ tokenizer_data = redis_client.get(f"transformers_tokenizer:{model_name}")
80
+ tokenizer = AutoTokenizer.from_pretrained("transformers_tokenizer")
81
+ load_special_tokens(tokenizer, model_name)
82
+ inputs = tokenizer(prompt, return_tensors="pt")
83
+ outputs = model.generate(inputs.input_ids, max_length=50)
84
+ response = tokenizer.decode(outputs[0], skip_special_tokens=True)
85
+ redis_client.set(f"transformers_response:{unique_id}", response)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
  return response
87
 
88
+ def train_and_store_diffusers_model(model_name, data):
89
+ pipe = FluxPipeline.from_pretrained(model_name, torch_dtype=torch.bfloat16)
90
+ pipe.enable_model_cpu_offload()
91
+ pipe.train()
92
+ pipe.save_pretrained("diffusers_model")
93
+ with open("diffusers_model/flux_pipeline.pt", "rb") as f:
94
+ model_data = f.read()
95
+ redis_client.set(f"diffusers_model:{model_name}", model_data)
96
+
97
+ def generate_diffusers_image_from_redis(model_name, prompt):
98
+ unique_id = generate_unique_id()
99
+ model_data = redis_client.get(f"diffusers_model:{model_name}")
100
+ with open("diffusers_model/flux_pipeline.pt", "wb") as f:
101
+ f.write(model_data)
102
+ pipe = FluxPipeline.from_pretrained("diffusers_model", torch_dtype=torch.bfloat16)
103
+ pipe.enable_model_cpu_offload()
104
+ image = pipe(prompt, guidance_scale=0.0, num_inference_steps=4, max_sequence_length=256, generator=torch.Generator("cpu").manual_seed(0)).images[0]
105
+ image_path = f"images/diffusers_{unique_id}.png"
106
+ image.save(image_path)
107
+ redis_client.set(f"diffusers_image:{unique_id}", image_path)
108
+ return image
109
+
110
+ def train_and_store_musicgen_model(model_name, data):
111
+ processor = AutoProcessor.from_pretrained(model_name)
112
+ model = MusicgenForConditionalGeneration.from_pretrained(model_name)
113
+ model.train()
114
+ torch.save(model.state_dict(), "musicgen_model.pt")
115
+ with open("musicgen_model.pt", "rb") as f:
116
+ model_data = f.read()
117
+ redis_client.set(f"musicgen_model:{model_name}:state_dict", model_data)
118
+ processor_data = processor.save_pretrained("musicgen_processor")
119
+ redis_client.set(f"musicgen_processor:{model_name}", processor_data)
120
+
121
+ def generate_musicgen_audio_from_redis(model_name, text_prompts):
122
+ unique_id = generate_unique_id()
123
+ model_data = redis_client.get(f"musicgen_model:{model_name}:state_dict")
124
+ with open("musicgen_model.pt", "wb") as f:
125
+ f.write(model_data)
126
+ model = MusicgenForConditionalGeneration.from_pretrained(model_name)
127
+ model.load_state_dict(torch.load("musicgen_model.pt"))
128
+ processor_data = redis_client.get(f"musicgen_processor:{model_name}")
129
+ processor = AutoProcessor.from_pretrained("musicgen_processor")
130
+ inputs = processor(text=text_prompts, padding=True, return_tensors="pt")
131
+ audio_values = model.generate(**inputs, max_new_tokens=256)
132
+ audio_path = f"audio/musicgen_{unique_id}.wav"
133
+ scipy.io.wavfile.write(audio_path, rate=audio_values["sampling_rate"], data=audio_values["audio"])
134
+ redis_client.set(f"musicgen_audio:{unique_id}", audio_path)
135
+ return audio_path
136
+
137
+ def train_and_store_stable_diffusion_model(model_name, data):
138
+ pipe = StableDiffusionPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
139
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
140
+ pipe = pipe.to("cuda")
141
+ pipe.train()
142
+ pipe.save_pretrained("stable_diffusion_model")
143
+ with open("stable_diffusion_model/stable_diffusion_pipeline.pt", "rb") as f:
144
+ model_data = f.read()
145
+ redis_client.set(f"stable_diffusion_model:{model_name}", model_data)
146
+
147
+ def generate_stable_diffusion_image_from_redis(model_name, prompt):
148
+ unique_id = generate_unique_id()
149
+ model_data = redis_client.get(f"stable_diffusion_model:{model_name}")
150
+ with open("stable_diffusion_model/stable_diffusion_pipeline.pt", "wb") as f:
151
+ f.write(model_data)
152
+ pipe = StableDiffusionPipeline.from_pretrained("stable_diffusion_model", torch_dtype=torch.float16)
153
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
154
+ pipe = pipe.to("cuda")
155
+ image = pipe(prompt).images[0]
156
+ image_path = f"images/stable_diffusion_{unique_id}.png"
157
+ image.save(image_path)
158
+ redis_client.set(f"stable_diffusion_image:{unique_id}", image_path)
159
+ return image
160
+
161
+ def train_and_store_img2img_model(model_name, data):
162
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained(model_name, torch_dtype=torch.float16)
163
+ pipe = pipe.to("cuda")
164
+ pipe.train()
165
+ pipe.save_pretrained("img2img_model")
166
+ with open("img2img_model/img2img_pipeline.pt", "rb") as f:
167
+ model_data = f.read()
168
+ redis_client.set(f"img2img_model:{model_name}", model_data)
169
+
170
+ def generate_img2img_from_redis(model_name, init_image, prompt, strength=0.75):
171
+ unique_id = generate_unique_id()
172
+ model_data = redis_client.get(f"img2img_model:{model_name}")
173
+ with open("img2img_model/img2img_pipeline.pt", "wb") as f:
174
+ f.write(model_data)
175
+ pipe = StableDiffusionImg2ImgPipeline.from_pretrained("img2img_model", torch_dtype=torch.float16)
176
+ pipe = pipe.to("cuda")
177
+ init_image = Image.open(init_image).convert("RGB")
178
+ image = pipe(prompt=prompt, init_image=init_image, strength=strength).images[0]
179
+ image_path = f"images/img2img_{unique_id}.png"
180
+ image.save(image_path)
181
+ redis_client.set(f"img2img_image:{unique_id}", image_path)
182
+ return image
183
+
184
+ def train_and_store_marianmt_model(model_name, data):
185
+ tokenizer = MarianTokenizer.from_pretrained(model_name)
186
+ model = MarianMTModel.from_pretrained(model_name)
187
+ model.train()
188
+ torch.save(model.state_dict(), "marianmt_model.pt")
189
+ with open("marianmt_model.pt", "rb") as f:
190
+ model_data = f.read()
191
+ redis_client.set(f"marianmt_model:{model_name}:state_dict", model_data)
192
+ tokenizer_data = tokenizer.save_pretrained("marianmt_tokenizer")
193
+ redis_client.set(f"marianmt_tokenizer:{model_name}", tokenizer_data)
194
+
195
+ def translate_text_from_redis(model_name, text, src_lang, tgt_lang):
196
+ unique_id = generate_unique_id()
197
+ model_data = redis_client.get(f"marianmt_model:{model_name}:state_dict")
198
+ with open("marianmt_model.pt", "wb") as f:
199
+ f.write(model_data)
200
+ model = MarianMTModel.from_pretrained(model_name)
201
+ model.load_state_dict(torch.load("marianmt_model.pt"))
202
+ tokenizer_data = redis_client.get(f"marianmt_tokenizer:{model_name}")
203
+ tokenizer = MarianTokenizer.from_pretrained("marianmt_tokenizer")
204
+ inputs = tokenizer(text, return_tensors="pt", src_lang=src_lang, tgt_lang=tgt_lang)
205
+ translated_tokens = model.generate(**inputs)
206
+ translation = tokenizer.decode(translated_tokens[0], skip_special_tokens=True)
207
+ redis_client.set(f"marianmt_translation:{unique_id}", translation)
208
+ return translation
209
+
210
+ def train_and_store_bart_model(model_name, data):
211
+ tokenizer = BartTokenizer.from_pretrained(model_name)
212
+ model = BartForConditionalGeneration.from_pretrained(model_name)
213
+ model.train()
214
+ torch.save(model.state_dict(), "bart_model.pt")
215
+ with open("bart_model.pt", "rb") as f:
216
+ model_data = f.read()
217
+ redis_client.set(f"bart_model:{model_name}:state_dict", model_data)
218
+ tokenizer_data = tokenizer.save_pretrained("bart_tokenizer")
219
+ redis_client.set(f"bart_tokenizer:{model_name}", tokenizer_data)
220
+
221
+ def summarize_text_from_redis(model_name, text):
222
+ unique_id = generate_unique_id()
223
+ model_data = redis_client.get(f"bart_model:{model_name}:state_dict")
224
+ with open("bart_model.pt", "wb") as f:
225
+ f.write(model_data)
226
+ model = BartForConditionalGeneration.from_pretrained(model_name)
227
+ model.load_state_dict(torch.load("bart_model.pt"))
228
+ tokenizer_data = redis_client.get(f"bart_tokenizer:{model_name}")
229
+ tokenizer = BartTokenizer.from_pretrained("bart_tokenizer")
230
+ load_special_tokens(tokenizer, model_name)
231
+ inputs = tokenizer(text, return_tensors="pt", truncation=True)
232
+ summary_ids = model.generate(inputs["input_ids"], max_length=150, min_length=40, length_penalty=2.0, num_beams=4)
233
+ summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
234
+ redis_client.set(f"bart_summary:{unique_id}", summary)
235
+ return summary
236
+
237
+ def auto_train_and_store(model_name, task, data):
238
+ if task == "text-generation":
239
+ train_and_store_transformers_model(model_name, data)
240
+ elif task == "diffusers":
241
+ train_and_store_diffusers_model(model_name, data)
242
+ elif task == "musicgen":
243
+ train_and_store_musicgen_model(model_name, data)
244
+ elif task == "stable-diffusion":
245
+ train_and_store_stable_diffusion_model(model_name, data)
246
+ elif task == "img2img":
247
+ train_and_store_img2img_model(model_name, data)
248
+ elif task == "translation":
249
+ train_and_store_marianmt_model(model_name, data)
250
+ elif task == "summarization":
251
+ train_and_store_bart_model(model_name, data)
252
+
253
+ def transcribe_audio_from_redis(audio_file):
254
+ audio_file_path = "audio_file.wav"
255
+ with open(audio_file_path, "wb") as f:
256
+ f.write(audio_file)
257
+ processor = WhisperProcessor.from_pretrained("openai/whisper-small")
258
+ model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-small")
259
+ model.config.forced_decoder_ids = None
260
+ sample = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")[0]["audio"]
261
+ input_features = processor(sample["array"], sampling_rate=sample["sampling_rate"], return_tensors="pt").input_features
262
+ predicted_ids = model.generate(input_features)
263
+ transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
264
+ return transcription[0]
265
+
266
+ def generate_image_from_redis(model_name, prompt, model_type):
267
+ if model_type == "diffusers":
268
+ image = generate_diffusers_image_from_redis(model_name, prompt)
269
+ elif model_type == "stable-diffusion":
270
+ image = generate_stable_diffusion_image_from_redis(model_name, prompt)
271
+ elif model_type == "img2img":
272
+ image = generate_img2img_from_redis(model_name, "init_image.png", prompt)
273
+ return image
274
+
275
+ def generate_video_from_redis(prompt):
276
+ pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
277
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
278
+ pipe.enable_model_cpu_offload()
279
+ video_frames = pipe(prompt, num_inference_steps=25).frames
280
+ video_path = export_to_video(video_frames)
281
+ unique_id = generate_unique_id()
282
+ redis_client.set(f"video_{unique_id}", video_path)
283
+ return video_path
284
+
285
+ def generate_random_response(prompts, generator):
286
+ responses = []
287
+ for prompt in prompts:
288
+ response = generator(prompt, max_length=50)[0]['generated_text']
289
+ responses.append(response)
290
+ return responses
291
+
292
+ def process_parallel(tasks):
293
+ with multiprocessing.Pool() as pool:
294
+ results = pool.map(lambda task: task(), tasks)
295
+ return results
296
+
297
+ def generate_response_from_prompt(prompt, generator):
298
+ responses = generate_random_response([prompt], generator)
299
+ return responses[0]
300
+
301
+ def generate_image_from_prompt(prompt, image_type):
302
+ if image_type == "diffusers":
303
+ image = generate_diffusers_image_from_redis("diffusers_model_name", prompt)
304
+ elif image_type == "stable-diffusion":
305
+ image = generate_stable_diffusion_image_from_redis("stable_diffusion_model_name", prompt)
306
+ elif image_type == "img2img":
307
+ image = generate_img2img_from_redis("img2img_model_name", "init_image.png", prompt)
308
+ return image
309
+
310
+ def gradio_app():
311
+ with gr.Blocks() as app:
312
+ gr.Markdown("## Generación de Texto con Transformers")
313
+ with gr.Row():
314
+ prompt_text = gr.Textbox(label="Texto de Entrada")
315
+ text_output = gr.Textbox(label="Respuesta")
316
+ text_button = gr.Button("Generar Texto")
317
+ text_button.click(generate_response_from_prompt, inputs=prompt_text, outputs=text_output)
318
+
319
+ gr.Markdown("## Generación de Imágenes con Diffusers, Stable Diffusion e Img2Img")
320
+ with gr.Row():
321
+ prompt_image = gr.Textbox(label="Prompt de Imagen")
322
+ image_type = gr.Dropdown(["diffusers", "stable-diffusion", "img2img"], label="Tipo de Imagen")
323
+ image_output = gr.Image(type="pil", label="Imagen Generada")
324
+ image_button = gr.Button("Generar Imagen")
325
+ image_button.click(generate_image_from_prompt, inputs=[prompt_image, image_type], outputs=image_output)
326
+
327
+ gr.Markdown("## Generación de Video")
328
+ with gr.Row():
329
+ prompt_video = gr.Textbox(label="Prompt de Video")
330
+ video_output = gr.Video(type="file", label="Video Generado")
331
+ video_button = gr.Button("Generar Video")
332
+ video_button.click(generate_video_from_redis, inputs=prompt_video, outputs=video_output)
333
+
334
+ gr.Markdown("## Generación de Audio con MusicGen")
335
+ with gr.Row():
336
+ text_prompts_audio = gr.Textbox(label="Prompts de Audio")
337
+ audio_output = gr.Audio(type="file", label="Audio Generado")
338
+ audio_button = gr.Button("Generar Audio")
339
+ audio_button.click(generate_musicgen_audio_from_redis, inputs=text_prompts_audio, outputs=audio_output)
340
+
341
+ gr.Markdown("## Transcripción de Audio con Whisper")
342
+ with gr.Row():
343
+ audio_file = gr.Audio(type="file", label="Archivo de Audio")
344
+ transcription_output = gr.Textbox(label="Transcripción")
345
+ audio_button = gr.Button("Transcribir Audio")
346
+ audio_button.click(transcribe_audio_from_redis, inputs=audio_file, outputs=transcription_output)
347
+
348
+ gr.Markdown("## Traducción de Texto")
349
+ with gr.Row():
350
+ text_input = gr.Textbox(label="Texto a Traducir")
351
+ translation_output = gr.Textbox(label="Traducción")
352
+ src_lang_input = gr.Textbox(label="Idioma de Origen", value="en")
353
+ tgt_lang_input = gr.Textbox(label="Idioma de Destino", value="es")
354
+ translate_button = gr.Button("Traducir Texto")
355
+ translate_button.click(translate_text_from_redis, inputs=[text_input, src_lang_input, tgt_lang_input], outputs=translation_output)
356
+
357
+ gr.Markdown("## Resumen de Texto")
358
+ with gr.Row():
359
+ text_to_summarize = gr.Textbox(label="Texto para Resumir")
360
+ summary_output = gr.Textbox(label="Resumen")
361
+ summarize_button = gr.Button("Generar Resumen")
362
+ summarize_button.click(summarize_text_from_redis, inputs=text_to_summarize, outputs=summary_output)
363
+
364
+ app.launch()
365
+
366
+ if __name__ == "__main__":
367
+ gradio_app()