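# NOTE: the lines below appear to be file-viewer page residue captured with
# the source; they are kept verbatim as comments so the module parses.
# Chroma-Extra / app.py
# gokaygokay's picture
# chroma
# fc6a8b1
# raw
# history blame
# 2.88 kB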
import spaces
import torch
import gradio as gr
import os
from diffusers import FluxPipeline, FluxTransformer2DModel, BitsAndBytesConfig
from transformers import T5EncoderModel, BitsAndBytesConfig as BitsAndBytesConfigTF
# Module-level settings, evaluated once at import time (outside the request
# handler).

# Repo that supplies the model configs and the remaining pipeline components.
single_file_base_model = "camenduru/FLUX.1-dev-diffusers"
# Single-file checkpoint that is loaded as the transformer.
file_url = "https://huggingface.co/lodestones/Chroma/blob/main/chroma-unlocked-v31.safetensors"
# Hub access token read from the environment; None when the variable is unset.
huggingface_token = os.environ.get("HUGGINGFACE_TOKEN")
dtype = torch.bfloat16
# Use CUDA when torch reports it as available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Initialize text encoder
# The T5 encoder is loaded from the base repo's "text_encoder_2" subfolder
# with 8-bit bitsandbytes quantization. Only `load_in_8bit` is set:
# BitsAndBytesConfig exposes a compute-dtype option for 4-bit loading
# (`bnb_4bit_compute_dtype`) but has no 8-bit counterpart, so none is passed.
quantization_config_tf = BitsAndBytesConfigTF(load_in_8bit=True)
text_encoder_2 = T5EncoderModel.from_pretrained(
    single_file_base_model,
    subfolder="text_encoder_2",
    torch_dtype=dtype,
    config=single_file_base_model,
    quantization_config=quantization_config_tf,
    token=huggingface_token,
)
# Initialize transformer
# 4-bit NF4 quantization with double quantization and bfloat16 compute.
# The Hub token is deliberately NOT passed here: it is a download credential,
# not a quantization option, so it belongs only on the loader call below.
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.bfloat16,
)
# Load the transformer weights from the single-file checkpoint, taking the
# model config from the base repo's "transformer" subfolder.
transformer = FluxTransformer2DModel.from_single_file(
    file_url,
    subfolder="transformer",
    torch_dtype=dtype,
    config=single_file_base_model,
    quantization_config=quantization_config,
    token=huggingface_token,
)
# Build the pipeline from the base repo, substituting the separately loaded
# transformer and text encoder for the repo's own copies.
flux_pipeline = FluxPipeline.from_pretrained(
    single_file_base_model,
    torch_dtype=dtype,
    token=huggingface_token,
    text_encoder_2=text_encoder_2,
    transformer=transformer,
)
# Place the pipeline on the selected device.
flux_pipeline = flux_pipeline.to(device)
@spaces.GPU()
def generate_image(
    prompt: str,
    negative_prompt: str = "",
    num_inference_steps: int = 30,
    guidance_scale: float = 7.5,
):
    """Generate one image for *prompt* using the module-level ``flux_pipeline``.

    Args:
        prompt: Text describing the desired image.
        negative_prompt: Text forwarded to the pipeline as ``negative_prompt``.
        num_inference_steps: Number of inference steps. Coerced to ``int``
            because UI slider values may arrive as floats.
        guidance_scale: Guidance scale forwarded to the pipeline. Coerced to
            ``float``.

    Returns:
        The first element of the pipeline result's ``images``.
    """
    # NOTE(review): diffusers documents FluxPipeline's `negative_prompt` as
    # ignored unless `true_cfg_scale` > 1 — confirm whether the negative
    # prompt is expected to have an effect here.
    result = flux_pipeline(
        prompt=prompt,
        negative_prompt=negative_prompt,
        num_inference_steps=int(num_inference_steps),
        guidance_scale=float(guidance_scale),
    )
    return result.images[0]
def _build_interface():
    """Assemble the Gradio form whose submit handler is ``generate_image``."""
    prompt_box = gr.Textbox(label="Prompt", placeholder="Enter your prompt here...")
    negative_box = gr.Textbox(
        label="Negative Prompt",
        placeholder="Enter negative prompt here...",
        value="",
    )
    steps_slider = gr.Slider(
        minimum=1, maximum=100, value=30, step=1, label="Number of Inference Steps"
    )
    guidance_slider = gr.Slider(
        minimum=1.0, maximum=20.0, value=7.5, step=0.1, label="Guidance Scale"
    )
    # Each example row is ordered like the inputs:
    # prompt, negative prompt, steps, guidance scale.
    sample_rows = [
        ["A beautiful sunset over mountains, photorealistic, 8k", "blurry, low quality, distorted", 30, 7.5],
        ["A futuristic cityscape at night, neon lights, cyberpunk style", "ugly, deformed, low resolution", 30, 7.5],
    ]
    return gr.Interface(
        fn=generate_image,
        inputs=[prompt_box, negative_box, steps_slider, guidance_slider],
        outputs=gr.Image(label="Generated Image"),
        title="Chroma Image Generator",
        description="Generate images using the Chroma model",
        examples=sample_rows,
    )


# Create Gradio interface
iface = _build_interface()

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()