# Hugging Face Space — Running on Zero (ZeroGPU hardware)
# `spaces` is imported before `torch` because ZeroGPU requires the package to be
# loaded before CUDA is initialised.
import spaces
import torch
import gradio as gr
from diffusers import FluxTransformer2DModel, FluxPipeline, BitsAndBytesConfig
from transformers import T5EncoderModel, BitsAndBytesConfig as BitsAndBytesConfigTF

# Initialize the model once at import time (outside the request handler) so every
# request reuses the same pipeline.
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16

# Diffusers-format base repo: FluxPipeline.from_pretrained below loads every
# component not passed explicitly from here, and it also supplies the configs for
# the text encoder and transformer.
single_file_base_model = "camenduru/FLUX.1-dev-diffusers"
# Single-file Chroma checkpoint that provides the transformer weights.
file_url = "https://huggingface.co/lodestones/Chroma/resolve/main/chroma-unlocked-v31.safetensors"

# 8-bit T5 text encoder. transformers' BitsAndBytesConfig has no 8-bit compute-dtype
# option (a `bnb_8bit_compute_dtype` kwarg is only reported as unused), so none is
# passed. The `text_encoder_2` subfolder already selects the matching config.
quantization_config_tf = BitsAndBytesConfigTF(load_in_8bit=True)
text_encoder_2 = T5EncoderModel.from_pretrained(
    single_file_base_model,
    subfolder="text_encoder_2",
    torch_dtype=dtype,
    quantization_config=quantization_config_tf,
)

# NF4 4-bit transformer with double quantization and bf16 compute.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=dtype,
)
# `config` + `subfolder` point diffusers at the base repo's transformer/ config
# for the architecture of the single-file checkpoint.
# NOTE(review): Chroma checkpoints are not weight-identical to FLUX.1 (they drop
# FLUX's guidance/vector modulation embedders); newer diffusers releases provide
# ChromaTransformer2DModel / ChromaPipeline for them. Confirm this FLUX-based load
# works with the pinned diffusers version.
transformer = FluxTransformer2DModel.from_single_file(
    file_url,
    subfolder="transformer",
    torch_dtype=dtype,
    config=single_file_base_model,
    quantization_config=quantization_config,
)

flux_pipeline = FluxPipeline.from_pretrained(
    single_file_base_model,
    transformer=transformer,
    text_encoder_2=text_encoder_2,
    torch_dtype=dtype,
)
flux_pipeline.to(device)
@spaces.GPU
def generate_image(prompt, negative_prompt="", num_inference_steps=30, guidance_scale=7.5):
    """Run the module-level FLUX pipeline once and return the first image.

    Decorated with ``spaces.GPU`` so that, on a ZeroGPU Space, a GPU is attached
    for the duration of this call.

    Args:
        prompt: Text prompt describing the desired image.
        negative_prompt: Forwarded to the pipeline as ``negative_prompt``.
            NOTE(review): FluxPipeline appears to apply a negative prompt only
            when ``true_cfg_scale > 1`` (its default is 1.0 and it is not set
            here), so this may currently have no effect. Confirm against the
            pinned diffusers version.
        num_inference_steps: Number of denoising steps. It is coerced to ``int``
            because Gradio sliders may deliver floats, and the step count must be
            an integer.
        guidance_scale: Forwarded unchanged as the pipeline's ``guidance_scale``.

    Returns:
        The first entry of the pipeline output's ``images`` list (a PIL image
        under the pipeline's default ``output_type``).
    """
    result = flux_pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
    )
    return result.images[0]
# Gradio UI. The four input widgets map positionally onto
# generate_image(prompt, negative_prompt, num_inference_steps, guidance_scale).
_prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
_negative_box = gr.Textbox(label="Negative Prompt", placeholder="Enter negative prompt here...", value="")
_steps_slider = gr.Slider(minimum=1, maximum=100, value=30, step=1, label="Number of Inference Steps")
_guidance_slider = gr.Slider(minimum=1.0, maximum=20.0, value=7.5, step=0.1, label="Guidance Scale")

# Each example row lists values in the same order as the input widgets above.
_example_rows = [
    ["A beautiful sunset over mountains, photorealistic, 8k", "blurry, low quality, distorted", 30, 7.5],
    ["A futuristic cityscape at night, neon lights, cyberpunk style", "ugly, deformed, low resolution", 30, 7.5],
]

iface = gr.Interface(
    fn=generate_image,
    inputs=[_prompt_box, _negative_box, _steps_slider, _guidance_slider],
    outputs=gr.Image(label="Generated Image"),
    title="Chroma Image Generator",
    description="Generate images using the Chroma model with FLUX pipeline",
    examples=_example_rows,
)

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed as a script.
    iface.launch()