import redis
import pickle
import torch
from PIL import Image
from diffusers import StableDiffusionPipeline, StableDiffusionImg2ImgPipeline, FluxPipeline, DiffusionPipeline, DPMSolverMultistepScheduler
from diffusers.utils import export_to_video
from transformers import pipeline as transformers_pipeline, AutoModelForCausalLM, AutoTokenizer, Trainer, TrainingArguments
from audiocraft.models import MusicGen
import gradio as gr
from huggingface_hub import snapshot_download, HfApi, HfFolder
import multiprocessing
import io
from dotenv import load_dotenv
import os

# Load the variables from the .env file
load_dotenv()

# Read the required environment variables
hf_token = os.getenv("HF_TOKEN")
redis_host = os.getenv("REDIS_HOST")
redis_port = os.getenv("REDIS_PORT")
redis_password = os.getenv("REDIS_PASSWORD")

# Register the Hugging Face token so authenticated downloads work
HfFolder.save_token(hf_token)

# Redis connection and pickle-based object cache helpers
def connect_to_redis():
    return redis.Redis(host=redis_host, port=int(redis_port), password=redis_password)

def load_object_from_redis(key):
    with connect_to_redis() as redis_client:
        obj_data = redis_client.get(key)
    return pickle.loads(obj_data) if obj_data else None

def save_object_to_redis(key, obj):
    with connect_to_redis() as redis_client:
        redis_client.set(key, pickle.dumps(obj))

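# Return a cached model from Redis if available; otherwise download it with the given loader and cache it.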
def get_model_or_download(model_id, redis_key, loader_func):
    model = load_object_from_redis(redis_key)
    if not model:
        model = loader_func(model_id, use_auth_token=hf_token, torch_dtype=torch.float16)
        save_object_to_redis(redis_key, model)
    return model

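# Generation helpers wired to the Gradio tabs defined at the bottom of the file.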
def generate_image(prompt):
    return text_to_image_pipeline(prompt).images[0]

def edit_image_with_prompt(image, prompt, strength=0.75):
    return img2img_pipeline(prompt=prompt, image=image.convert("RGB"), strength=strength).images[0]

def generate_song(prompt, duration=10):
    music_gen.set_generation_params(duration=duration)
    wav = music_gen.generate([prompt])[0].cpu().numpy().squeeze()
    return music_gen.sample_rate, wav

def generate_text(prompt):
    return text_gen_pipeline([{"role": "user", "content": prompt}], max_new_tokens=256)[0]["generated_text"][-1]["content"].strip()

def generate_flux_image(prompt):
    return flux_pipeline(
        prompt,
        guidance_scale=0.0,
        num_inference_steps=4,
        max_sequence_length=256,
        generator=torch.Generator("cpu").manual_seed(0)
    ).images[0]

def generate_code(prompt):
    inputs = starcoder_tokenizer.encode(prompt, return_tensors="pt").to(starcoder_model.device)
    outputs = starcoder_model.generate(inputs, max_new_tokens=256)
    return starcoder_tokenizer.decode(outputs[0], skip_special_tokens=True)

def generate_video(prompt):
    pipe = DiffusionPipeline.from_pretrained("damo-vilab/text-to-video-ms-1.7b", torch_dtype=torch.float16, variant="fp16")
    pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
    pipe.enable_model_cpu_offload()
    return export_to_video(pipe(prompt, num_inference_steps=25).frames)

def test_model_meta_llama(prompt):
    messages = [
        {"role": "system", "content": "You are a pirate chatbot who always responds in pirate speak!"},
        {"role": "user", "content": prompt or "Who are you?"}
    ]
    return meta_llama_pipeline(messages, max_new_tokens=256)[0]["generated_text"][-1]["content"]

def train_model(model, dataset, epochs, batch_size, learning_rate):
    output_dir = "./training_output"  # TrainingArguments expects a filesystem path, not a buffer
    training_args = TrainingArguments(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        learning_rate=learning_rate,
    )
    trainer = Trainer(model=model, args=training_args, train_dataset=dataset)
    trainer.train()
    save_object_to_redis("trained_model", model)
    save_object_to_redis("training_results", trainer.state.log_history)

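# Worker loop: pull (func, args, kwargs) tuples from the queue until a None sentinel arrives.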
def run_task(task_queue):
    while True:
        task = task_queue.get()
        if task is None:
            break
        func, args, kwargs = task
        func(*args, **kwargs)

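# Spin up one background worker process per CPU core.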
task_queue = multiprocessing.Queue()
num_processes = multiprocessing.cpu_count()

processes = []
for _ in range(num_processes):
    p = multiprocessing.Process(target=run_task, args=(task_queue,))
    p.start()
    processes.append(p)

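# Load all models up front, reusing pickled copies from Redis when they exist.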
device = "cuda" if torch.cuda.is_available() else "cpu"
text_to_image_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "text_to_image_model", StableDiffusionPipeline.from_pretrained).to(device)
img2img_pipeline = get_model_or_download("CompVis/stable-diffusion-v1-4", "img2img_model", StableDiffusionImg2ImgPipeline.from_pretrained).to(device)  # img2img reuses the base SD checkpoint; an inpainting checkpoint is not compatible with this pipeline
flux_pipeline = get_model_or_download("black-forest-labs/FLUX.1-schnell", "flux_model", FluxPipeline.from_pretrained)
flux_pipeline.enable_model_cpu_offload()
music_gen = load_object_from_redis("music_gen") or MusicGen.get_pretrained('melody')  # auth is handled by the token saved above
save_object_to_redis("music_gen", music_gen)
text_gen_pipeline = load_object_from_redis("text_gen_pipeline") or transformers_pipeline(
    "text-generation",
    model="google/gemma-2-2b-it",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device=device,
    use_auth_token=hf_token,
)
save_object_to_redis("text_gen_pipeline", text_gen_pipeline)
starcoder_tokenizer = AutoTokenizer.from_pretrained("bigcode/starcoder2-15b", use_auth_token=hf_token)
starcoder_model = AutoModelForCausalLM.from_pretrained("bigcode/starcoder2-15b", device_map="auto", torch_dtype=torch.bfloat16, use_auth_token=hf_token)
meta_llama_pipeline = transformers_pipeline(
    "text-generation",
    model="meta-llama/Meta-Llama-3.1-8B-Instruct",
    model_kwargs={"torch_dtype": torch.bfloat16},
    device_map="auto",
    use_auth_token=hf_token
)

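# Gradio UI: one Interface per capability, grouped into a TabbedInterface.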
gen_image_tab = gr.Interface(generate_image, gr.Textbox(label="Prompt:"), gr.Image(type="pil"), title="Generate Images")
edit_image_tab = gr.Interface(edit_image_with_prompt, [gr.Image(type="pil", label="Image:"), gr.Textbox(label="Prompt:"), gr.Slider(0.1, 1.0, value=0.75, step=0.05, label="Strength:")], gr.Image(type="pil"), title="Edit Images")
generate_song_tab = gr.Interface(generate_song, [gr.Textbox(label="Prompt:"), gr.Slider(5, 60, value=10, step=1, label="Duration (s):")], gr.Audio(type="numpy"), title="Generate Songs")
generate_text_tab = gr.Interface(generate_text, gr.Textbox(label="Prompt:"), gr.Textbox(label="Generated Text:"), title="Generate Text")
generate_flux_image_tab = gr.Interface(generate_flux_image, gr.Textbox(label="Prompt:"), gr.Image(type="pil"), title="Generate FLUX Images")
model_meta_llama_test_tab = gr.Interface(test_model_meta_llama, gr.Textbox(label="Test Input:"), gr.Textbox(label="Model Output:"), title="Test Meta-Llama")

app = gr.TabbedInterface(
    [gen_image_tab, edit_image_tab, generate_song_tab, generate_text_tab, generate_flux_image_tab, model_meta_llama_test_tab],
    ["Generate Image", "Edit Image", "Generate Song", "Generate Text", "Generate FLUX Image", "Test Meta-Llama"]
)

app.launch(share=True)

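# Shut down the worker pool once the Gradio server exits.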
for _ in range(num_processes):
    task_queue.put(None)
for p in processes:
    p.join()