# flx-upscale / app.py — FLUX.1-dev ControlNet upscaler demo (Gradio Space)
import logging
import random
import warnings
import os
import gradio as gr
import numpy as np
import spaces
import torch
from diffusers import FluxControlNetModel
from diffusers.pipelines import FluxControlNetPipeline
from gradio_imageslider import ImageSlider
from PIL import Image
from huggingface_hub import snapshot_download
# gc added for memory management
import gc
gc.collect()
torch.cuda.empty_cache()
css = """
#col-container {
margin: 0 auto;
max-width: 512px;
}
"""
# Device setup with minimal memory usage
if torch.cuda.is_available():
power_device = "GPU"
device = "cuda"
dtype = torch.float16 # Use float16 for minimum memory
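    # Note: official Flux examples typically use torch.bfloat16, which avoids
    # fp16 overflow at a similar memory cost; fp16 is kept here as the
    # smallest-footprint choice.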
# Set CUDA memory fraction to 50%
torch.cuda.set_per_process_memory_fraction(0.5)
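    # The caching allocator will now raise OOM once this process exceeds
    # half of the visible GPU memory, instead of claiming the whole card.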
else:
power_device = "CPU"
device = "cpu"
dtype = torch.float32
huggingface_token = os.getenv("HUGGINGFACE_TOKEN")  # note: secret name was previously misspelled as HUGGINFACE_TOKEN
# Minimal model configuration
model_config = {
"low_cpu_mem_usage": True,
"torch_dtype": dtype,
"use_safetensors": True,
"variant": "fp16", # Use fp16 variant if available
}
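# FLUX.1-dev does not publish a separate fp16 weight variant; when variant
# files are missing, recent diffusers versions warn and fall back to the
# default (bf16) weights.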
model_path = snapshot_download(
repo_id="black-forest-labs/FLUX.1-dev",
repo_type="model",
    ignore_patterns=["*.md", "*.gitattributes", "*.bin"],  # skip docs, git metadata, and duplicate .bin weights
local_dir="FLUX.1-dev",
token=huggingface_token,
)
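# Downloading one pruned snapshot up front avoids per-component fetches and
# keeps the .bin duplicates of the safetensors weights off disk.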
# Load models with minimal configuration
try:
controlnet = FluxControlNetModel.from_pretrained(
"jasperai/Flux.1-dev-Controlnet-Upscaler",
**model_config
).to(device)
pipe = FluxControlNetPipeline.from_pretrained(
model_path,
controlnet=controlnet,
**model_config
)
    # Enable memory optimizations. enable_model_cpu_offload() and
    # enable_sequential_cpu_offload() are alternative offload strategies and
    # should not be combined; sequential offload is the more memory-frugal
    # of the two, so it is used alone here alongside slicing.
    pipe.enable_attention_slicing(1)
    pipe.vae.enable_slicing()  # Flux exposes VAE slicing on the vae module itself
    if device == "cuda":
        pipe.enable_sequential_cpu_offload()
# Clear memory after loading
gc.collect()
torch.cuda.empty_cache()
except Exception as e:
print(f"Error loading models: {e}")
raise
# Extremely reduced parameters
MAX_SEED = 1000000
MAX_PIXEL_BUDGET = 128 * 128 # Extremely reduced from 256 * 256
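# The pixel budget caps input *area*; sqrt(budget) below (128 px) is the
# longest side allowed, which keeps Flux latents tiny on the shared GPU.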
def check_resources():
    """Free cached GPU memory once allocations exceed ~70% of the reserved pool."""
    if torch.cuda.is_available():
        memory_allocated = torch.cuda.memory_allocated(0)
        memory_reserved = torch.cuda.memory_reserved(0)
        # Guard against division by zero before the allocator has reserved anything
        if memory_reserved > 0 and memory_allocated / memory_reserved > 0.7:  # 70% threshold
            gc.collect()
            torch.cuda.empty_cache()
    return True
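# check_resources() is currently never invoked; calling it at the start of
# each infer() request would proactively relieve memory pressure.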
def process_input(input_image, upscale_factor, **kwargs):
input_image = input_image.convert('RGB')
# Reduce image size more aggressively
w, h = input_image.size
max_size = int(np.sqrt(MAX_PIXEL_BUDGET))
if w > max_size or h > max_size:
if w > h:
new_w = max_size
new_h = int(h * max_size / w)
else:
new_h = max_size
new_w = int(w * max_size / h)
input_image = input_image.resize((new_w, new_h), Image.LANCZOS)
w, h = input_image.size
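    # Snap both sides down to a multiple of 8 so they align with the VAE's
    # 8x spatial downsampling (Flux additionally packs latents 2x2, so
    # multiples of 16 are safest if the pipeline warns about sizes).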
w = w - w % 8
h = h - h % 8
return input_image.resize((w, h)), w, h, True
@spaces.GPU
def infer(
seed,
randomize_seed,
input_image,
num_inference_steps,
upscale_factor,
controlnet_conditioning_scale,
progress=gr.Progress(track_tqdm=True),
):
try:
gc.collect()
torch.cuda.empty_cache()
if randomize_seed:
seed = random.randint(0, MAX_SEED)
input_image, w, h, _ = process_input(input_image, upscale_factor)
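        # upscale_factor is pinned to 1 in this UI, so output size equals the
        # budget-limited input size; the ControlNet acts as a detail restorer
        # rather than a true upscaler at these settings.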
with torch.inference_mode():
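            # A CPU-side generator keeps seeding reproducible across devices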
generator = torch.Generator().manual_seed(seed)
image = pipe(
prompt="",
control_image=input_image,
controlnet_conditioning_scale=controlnet_conditioning_scale,
num_inference_steps=num_inference_steps,
guidance_scale=2.0, # Reduced from 3.5
height=h,
width=w,
generator=generator,
).images[0]
gc.collect()
torch.cuda.empty_cache()
        # ImageSlider expects a (before, after) image pair, so do not append the seed
        return (input_image, image)
    except Exception as e:
        # gr.Error must be raised, not merely instantiated, for Gradio to
        # surface the message in the UI
        raise gr.Error(f"An error occurred: {str(e)}") from e
with gr.Blocks(theme="Yntec/HaleyCH_Theme_Orange", css=css) as demo:
with gr.Row():
run_button = gr.Button(value="Run")
with gr.Row():
with gr.Column(scale=4):
input_im = gr.Image(label="Input Image", type="pil")
with gr.Column(scale=1):
num_inference_steps = gr.Slider(
label="Steps",
minimum=1,
maximum=20, # Reduced from 30
step=1,
value=10, # Reduced from 20
)
upscale_factor = gr.Slider(
label="Scale",
minimum=1,
maximum=1, # Fixed at 1
step=1,
value=1,
)
controlnet_conditioning_scale = gr.Slider(
label="Control Scale",
minimum=0.1,
maximum=0.5, # Reduced from 1.0
step=0.1,
value=0.3, # Reduced from 0.5
)
seed = gr.Slider(
label="Seed",
minimum=0,
maximum=MAX_SEED,
step=1,
value=42,
)
randomize_seed = gr.Checkbox(label="Random Seed", value=True)
with gr.Row():
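        # ImageSlider (gradio_imageslider) shows the input/output pair with a
        # draggable before/after comparison handle.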
result = ImageSlider(label="Result", type="pil", interactive=True)
current_dir = os.path.dirname(os.path.abspath(__file__))
examples = gr.Examples(
examples=[
[42, False, os.path.join(current_dir, "z1.webp"), 10, 1, 0.3],
[42, False, os.path.join(current_dir, "z2.webp"), 10, 1, 0.3],
],
inputs=[
seed,
randomize_seed,
input_im,
num_inference_steps,
upscale_factor,
controlnet_conditioning_scale,
],
fn=infer,
outputs=result,
cache_examples=False, # Disable caching
)
gr.on(
[run_button.click],
fn=infer,
inputs=[
seed,
randomize_seed,
input_im,
num_inference_steps,
upscale_factor,
controlnet_conditioning_scale,
],
outputs=result,
show_api=False,
)
# Launch with minimal resources. enable_queue and cache_examples are not
# launch() arguments in current Gradio: queueing is configured via
# demo.queue() and example caching on gr.Examples itself.
demo.queue(max_size=1).launch(
    share=False,
    debug=True,
    show_error=True,
    max_threads=1,
    quiet=True,
)
)