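"""Gradio demo: text-to-image generation with FLUX pipelines via diffusers.

Pipelines are loaded once at startup and kept CPU-offloaded to save VRAM;
each generation call borrows a GPU through the `spaces.GPU` (ZeroGPU) decorator.
"""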
import spaces  # import before torch, as recommended for ZeroGPU Spaces
import torch
import gradio as gr
from diffusers import FluxPipeline

MODELS = {
    # 'FLUX.1 [dev]': 'black-forest-labs/FLUX.1-dev',
    'FLUX.1 [schnell]': 'black-forest-labs/FLUX.1-schnell',
    'OpenFLUX.1': 'ostris/OpenFLUX.1',
}
MODEL_CACHE = {}
for name, repo_id in MODELS.items():
    print(f"Loading model {repo_id}...")
    MODEL_CACHE[name] = FluxPipeline.from_pretrained(repo_id, torch_dtype=torch.bfloat16)
    # Offload weights to CPU between passes to save VRAM; remove this if you have enough GPU memory.
    MODEL_CACHE[name].enable_model_cpu_offload()
    print(f"Loaded model {repo_id}")

@spaces.GPU
def generate(prompt):
    # Run the OpenFLUX.1 pipeline on the user's prompt; the fixed CPU-seeded
    # generator keeps results reproducible across calls.
    image = MODEL_CACHE['OpenFLUX.1'](
        prompt,
        height=1024,
        width=1024,
        guidance_scale=3.5,
        num_inference_steps=50,
        max_sequence_length=512,
        generator=torch.Generator("cpu").manual_seed(0)
    ).images[0]
    return image

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    btn = gr.Button("Generate", variant="primary")
    out = gr.Image(label="Generated image", interactive=False)
    btn.click(generate, inputs=prompt, outputs=out)
demo.queue().launch()