import spaces  # import early on ZeroGPU Spaces so CUDA initialization can be patched

import gradio as gr
import random
import os
import time
import torch
from diffusers import FluxPipeline

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
DEFAULT_HEIGHT = 1024
DEFAULT_WIDTH = 1024
DEFAULT_GUIDANCE_SCALE = 3.5
DEFAULT_NUM_INFERENCE_STEPS = 15
DEFAULT_MAX_SEQUENCE_LENGTH = 512
HF_TOKEN = os.environ.get("HF_ACCESS_TOKEN")  # optional; not referenced elsewhere in this script
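
# Cache the pipeline at module level so it is only loaded once per process.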
CACHED_PIPE = None


def load_bnb_4bit_pipeline():
    """Load the 4-bit quantized pipeline, reusing the cached instance if available."""
    global CACHED_PIPE
    if CACHED_PIPE is not None:
        return CACHED_PIPE
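
    # First call: download and initialize the pre-quantized checkpoint.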
    print("Loading 4-bit BNB pipeline...")
    MODEL_ID = "derekl35/FLUX.1-dev-nf4"

    start_time = time.time()
    try:
        pipe = FluxPipeline.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
        )
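        # Offload idle submodules to CPU and move them to the GPU only when
        # needed, trading some speed for a much lower peak VRAM footprint.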
        pipe.enable_model_cpu_offload()
        end_time = time.time()
        mem_reserved = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
        print(f"4-bit BNB pipeline loaded in {end_time - start_time:.2f}s. Memory reserved: {mem_reserved:.2f} GB")
        CACHED_PIPE = pipe
        return pipe
    except Exception as e:
        print(f"Error loading 4-bit BNB pipeline: {e}")
        raise
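

# ZeroGPU allocates a GPU for this call, for at most `duration` seconds.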
@spaces.GPU(duration=240)
def generate_image(prompt, progress=gr.Progress(track_tqdm=True)):
    """Generate an image from the prompt using the 4-bit quantized model."""
    if not prompt:
        return None, "Please enter a prompt."

    progress(0.2, desc="Loading 4-bit quantized model...")

    try:
        pipe = load_bnb_4bit_pipeline()

        pipe_kwargs = {
            "prompt": prompt,
            "height": DEFAULT_HEIGHT,
            "width": DEFAULT_WIDTH,
            "guidance_scale": DEFAULT_GUIDANCE_SCALE,
            "num_inference_steps": DEFAULT_NUM_INFERENCE_STEPS,
            "max_sequence_length": DEFAULT_MAX_SEQUENCE_LENGTH,
        }
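
        # Draw a fresh 64-bit seed for each request and echo it back so the
        # result can be reproduced.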
        seed = random.getrandbits(64)
        print(f"Using seed: {seed}")

        progress(0.5, desc="Generating image...")

        gen_start_time = time.time()
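        # torch.manual_seed() seeds and returns the global torch.Generator,
        # which the pipeline accepts via its `generator` argument.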
        image = pipe(**pipe_kwargs, generator=torch.manual_seed(seed)).images[0]
        gen_end_time = time.time()

        print(f"Image generated in {gen_end_time - gen_start_time:.2f} seconds")
        mem_reserved = torch.cuda.memory_reserved(0) / 1024**3 if DEVICE == "cuda" else 0
        print(f"Memory reserved: {mem_reserved:.2f} GB")

        return image, f"Generation complete! (Seed: {seed})"

    except Exception as e:
        print(f"Error during generation: {e}")
        return None, f"Error: {e}"


with gr.Blocks(title="FLUXllama", theme=gr.themes.Soft()) as demo:
    gr.HTML(
        """
        <div style='text-align: center; margin-bottom: 20px;'>
            <h1>FLUXllama</h1>
            <p>FLUX.1-dev 4-bit Quantized Version</p>
        </div>
        """
    )
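
    # External badge links (OpenFree services and the project Discord).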
    gr.HTML(
        """
        <div class='container' style='display:flex; justify-content:center; gap:12px; margin-bottom: 20px;'>
            <a href="https://huggingface.co/spaces/openfree/Best-AI" target="_blank">
                <img src="https://img.shields.io/static/v1?label=OpenFree&message=BEST%20AI%20Services&color=%230000ff&labelColor=%23000080&logo=huggingface&logoColor=%23ffa500&style=for-the-badge" alt="OpenFree badge">
            </a>
            <a href="https://discord.gg/openfreeai" target="_blank">
                <img src="https://img.shields.io/static/v1?label=Discord&message=Openfree%20AI&color=%230000ff&labelColor=%23800080&logo=discord&logoColor=white&style=for-the-badge" alt="Discord badge">
            </a>
        </div>
        """
    )

    with gr.Row():
        prompt_input = gr.Textbox(
            label="Enter your prompt",
            placeholder="e.g., A photorealistic portrait of an astronaut on Mars",
            lines=2,
            scale=4,
        )
        generate_button = gr.Button("Generate", variant="primary", scale=1)

    output_image = gr.Image(
        label="Generated Image (4-bit Quantized)",
        type="pil",
        height=600,
    )

    status_text = gr.Textbox(
        label="Status",
        interactive=False,
        lines=1,
    )

    generate_button.click(
        fn=generate_image,
        inputs=[prompt_input],
        outputs=[output_image, status_text],
    )
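
    # Pressing Enter in the prompt box triggers the same handler as the button.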
    prompt_input.submit(
        fn=generate_image,
        inputs=[prompt_input],
        outputs=[output_image, status_text],
    )

    gr.Examples(
        examples=[
            "A photorealistic portrait of an astronaut on Mars",
            "Watercolor painting of a cat wearing sunglasses",
            "Neo-Tokyo cyberpunk cityscape at night, rain-soaked streets, 8K",
            "A majestic dragon flying over a medieval castle at sunset",
            "Abstract art representing the concept of time and space",
            "Detailed oil painting of a steampunk clockwork city",
            "Underwater scene with bioluminescent creatures in the deep ocean",
            "Japanese garden in autumn with falling maple leaves",
        ],
        inputs=prompt_input,
    )


if __name__ == "__main__":
    # share=True additionally exposes a temporary public gradio.live link when
    # the app is run outside of Hugging Face Spaces.
    demo.launch(share=True)