Chroma-Extra / app.py
gokaygokay's picture
chroma
e25da44
raw
history blame
2.58 kB
import functools
import os

import gradio as gr
import torch
from diffusers import FluxTransformer2DModel, FluxPipeline, BitsAndBytesConfig
from transformers import T5EncoderModel, BitsAndBytesConfig as BitsAndBytesConfigTF
@functools.lru_cache(maxsize=1)
def _load_flux_pipeline():
    """Build the quantized Chroma/FLUX pipeline and cache it for reuse.

    The result is memoized, so the text encoder, transformer and pipeline
    are constructed on the first call only; later calls return the same
    pipeline object. If construction raises, nothing is cached and the next
    call tries again.

    Returns:
        The ``FluxPipeline`` instance after ``.to(device)`` has been called.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    dtype = torch.bfloat16
    single_file_base_model = "camenduru/FLUX.1-dev-diffusers"
    file_url = "https://huggingface.co/lodestones/Chroma/resolve/main/chroma-unlocked-v31.safetensors"

    # Second text encoder (T5), loaded with the 8-bit bitsandbytes config.
    quantization_config_tf = BitsAndBytesConfigTF(load_in_8bit=True, bnb_8bit_compute_dtype=torch.bfloat16)
    text_encoder_2 = T5EncoderModel.from_pretrained(
        single_file_base_model,
        subfolder="text_encoder_2",
        torch_dtype=dtype,
        config=single_file_base_model,
        quantization_config=quantization_config_tf,
    )

    # Chroma transformer weights come from a single checkpoint file; the
    # model config is taken from the base FLUX repository.
    quantization_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    )
    transformer = FluxTransformer2DModel.from_single_file(
        file_url,
        subfolder="transformer",
        torch_dtype=dtype,
        config=single_file_base_model,
        quantization_config=quantization_config,
    )

    flux_pipeline = FluxPipeline.from_pretrained(
        single_file_base_model,
        transformer=transformer,
        text_encoder_2=text_encoder_2,
        torch_dtype=dtype,
    )
    flux_pipeline.to(device)
    return flux_pipeline


def generate_image(prompt, negative_prompt="", num_inference_steps=30, guidance_scale=7.5):
    """Generate a single image for *prompt* with the cached FLUX pipeline.

    Args:
        prompt: Text describing the image to generate.
        negative_prompt: Text forwarded to the pipeline as ``negative_prompt``.
        num_inference_steps: Number of denoising steps; coerced to ``int``
            before being passed to the pipeline.
        guidance_scale: Guidance scale forwarded to the pipeline.

    Returns:
        The first entry of the pipeline output's ``images`` list.
    """
    # Reuse the already-built pipeline instead of reloading every model on
    # each request; only the very first call pays the loading cost.
    flux_pipeline = _load_flux_pipeline()

    # NOTE(review): depending on the installed diffusers version, FluxPipeline
    # may only apply negative_prompt when true_cfg_scale > 1 -- confirm.
    image = flux_pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        # UI widgets may deliver the step count as a float (e.g. 30.0).
        num_inference_steps=int(num_inference_steps),
        guidance_scale=guidance_scale,
    ).images[0]
    return image
# Build the Gradio UI around generate_image.
# Input widgets are listed in the same order as generate_image's parameters.
_prompt_input = gr.Textbox(
    label="Prompt",
    placeholder="Enter your prompt here...",
)
_negative_prompt_input = gr.Textbox(
    label="Negative Prompt",
    placeholder="Enter negative prompt here...",
    value="",
)
_steps_input = gr.Slider(
    minimum=1,
    maximum=100,
    value=30,
    step=1,
    label="Number of Inference Steps",
)
_guidance_input = gr.Slider(
    minimum=1.0,
    maximum=20.0,
    value=7.5,
    step=0.1,
    label="Guidance Scale",
)
_image_output = gr.Image(label="Generated Image")

# Each example row supplies one value per input widget, in widget order.
_example_rows = [
    ["A beautiful sunset over mountains, photorealistic, 8k", "blurry, low quality, distorted", 30, 7.5],
    ["A futuristic cityscape at night, neon lights, cyberpunk style", "ugly, deformed, low resolution", 30, 7.5],
]

iface = gr.Interface(
    fn=generate_image,
    inputs=[_prompt_input, _negative_prompt_input, _steps_input, _guidance_input],
    outputs=_image_output,
    title="Chroma Image Generator",
    description="Generate images using the Chroma model with FLUX pipeline",
    examples=_example_rows,
)

# Start the web app only when this file is run as a script.
if __name__ == "__main__":
    iface.launch()