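"""Multi-model demo Space: image, music, text, code, and video generation
behind a tabbed Gradio UI, with models cached in Redis and a small
multiprocessing worker pool for background tasks."""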
import redis
import pickle
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video
from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from audiocraft.models import MusicGen
import gradio as gr
from huggingface_hub import login
import multiprocessing
from dotenv import load_dotenv
import os
# Load variables from the .env file
load_dotenv()
# Read the environment variables
hf_token = os.getenv("HF_TOKEN")
redis_host = os.getenv("REDIS_HOST")
redis_port = int(os.getenv("REDIS_PORT", "6379"))
redis_password = os.getenv("REDIS_PASSWORD")
# Log in to the Hugging Face Hub
login(token=hf_token)
# Redis helpers: models are pickled into Redis so they can be reused across restarts
def connect_to_redis():
    return redis.Redis(host=redis_host, port=redis_port, password=redis_password)
def load_object_from_redis(key):
    with connect_to_redis() as redis_client:
        obj_data = redis_client.get(key)
        return pickle.loads(obj_data) if obj_data else None
def save_object_to_redis(key, obj):
    with connect_to_redis() as redis_client:
        redis_client.set(key, pickle.dumps(obj))
def get_model_or_download(model_id, redis_key, loader_func):
    # Reuse the cached model from Redis when available; otherwise download and cache it.
    model = load_object_from_redis(redis_key)
    if not model:
        model = loader_func(model_id, token=hf_token, torch_dtype=torch.float16)
        save_object_to_redis(redis_key, model)
    return model
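# --- Generation helpers ---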
def generate_image(prompt):
    return text_to_image_pipeline(prompt).images[0]
def edit_image_with_prompt(image, prompt, strength=0.75):
    # Recent diffusers versions take `image=`; the old `init_image=` argument was removed.
    return img2img_pipeline(prompt=prompt, image=image.convert("RGB"), strength=strength).images[0]
def generate_song(prompt, duration=10):
    # MusicGen expects a list of descriptions; duration is set via generation params.
    music_gen.set_generation_params(duration=duration)
    wav = music_gen.generate([prompt])[0].cpu().numpy()
    # Return (sample_rate, samples) so Gradio's Audio(type="numpy") output can play it.
    return music_gen.sample_rate, wav.squeeze()
def generate_text(prompt):
    return text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"][-1]["content"].strip()
def generate_flux_image(prompt):
    return flux_pipeline(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=torch.Generator("cpu").manual_seed(0),
    ).images[0]
def generate_code(prompt):
    # Use the model's own device so this also works without CUDA.
    inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(starcoder_model.device)
    outputs = starcoder_model.generate(inputs, max_new_tokens=256)
    return starcoder_tokenizer.decode(outputs[0], skip_special_tokens=True)
def generate_video(prompt):
    pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    # Recent diffusers return one frame list per prompt, hence frames[0].
    return export_to_video(pipe(prompt, num_inference_steps=25).frames[0])
def test_model_meta_llama(prompt):
    # Wire the UI's test input through as the user message (the original ignored it).
    messages = [
        {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
        {"role": "user", "content": prompt},
    ]
    # The chat pipeline returns the whole conversation; the last message is the reply.
    return meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"][-1]["content"]
def train_model(model, dataset, epochs, batch_size, learning_rate):
    # TrainingArguments needs a filesystem path, not an in-memory buffer.
    output_dir = "./training_output"
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.train()
    save_object_to_redis("trained_model", model)
    save_object_to_redis("training_results", trainer.state.log_history)
def run_task(task_queue):
    while True:
        task = task_queue.get()
        if task is None:
            break
        func, args, kwargs = task
        func(*args, **kwargs)
task_queue = multiprocessing.Queue()
num_processes = multiprocessing.cpu_count()
processes = []
for _ in range(num_processes):
    p = multiprocessing.Process(target=run_task, args=(task_queue,))
    p.start()
    processes.append(p)
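# Load every model up front (from the Redis cache when possible) so the UI handlers stay fast.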
device = "cuda" if torch.cuda.is_available() else "cpu"
text_to_image_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "text_to_image_model", StableDiffusionPipeline.from_pretrained).to(device)
# Img2img needs a base checkpoint, not the inpainting variant the original pointed at.
img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained).to(device)
flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
flux_pipeline.enable_model_cpu_offload()
# audiocraft's MusicGen.get_pretrained takes a model name, not an auth token.
music_gen = load_object_from_redis("music_gen") or MusicGen.get_pretrained("facebook/musicgen-melody")
save_object_to_redis("music_gen", music_gen)
text_gen_pipeline = load_object_from_redis("text_gen_pipeline") or transformers_pipeline(
    "text-generation",
    model="google/gemma-2-2b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device=device,
    token=hf_token,
)
save_object_to_redis("text_gen_pipeline", text_gen_pipeline)
starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b", token=hf_token)
starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder2-15b", device_map="auto", torch_dtype=torch.bfloat16, token=hf_token)
meta_llama_pipeline = transformers_pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
    token=hf_token,
)
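# Gradio front end: one Interface per capability, combined into a tabbed app below.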
gen_image_tab = gr.Interface(generate_image, gr.Textbox(label="Prompt:"), gr.Image(type="pil"), title="Generate Images")
edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.Image(type="pil", label="Image:"), gr.Textbox(label="Prompt:"), gr.Slider(0.1, 1.0, value=0.75, step=0.05, label="Strength:")], gr.Image(type="pil"), title="Edit Images")
generate_song_tab = gr.Interface(generate_song, [gr.Textbox(label="Prompt:"), gr.Slider(5, 60, value=10, step=1, label="Duration (s):")], gr.Audio(type="numpy"), title="Generate Songs")
generate_text_tab = gr.Interface(generate_text, gr.Textbox(label="Prompt:"), gr.Textbox(label="Generated Text:"), title="Generate Text")
generate_flux_image_tab = gr.Interface(generate_flux_image, gr.Textbox(label="Prompt:"), gr.Image(type="pil"), title="Generate FLUX Images")
model_meta_llama_test_tab = gr.Interface(test_model_meta_llama, gr.Textbox(label="Test Input:"), gr.Textbox(label="Model Output:"), title="Test Meta-Llama")
app = gr.TabbedInterface(
    [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, model_meta_llama_test_tab],
    ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Test Meta-Llama"],
)
app.launch(share=True)
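# Shut the worker pool down: one None sentinel per process, then join them all.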
for _ in range(num_processes):
    task_queue.put(None)
for p in processes:
    p.join()