Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from PIL import Image | |
import os | |
from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast | |
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler | |
from flux.transformer_flux import FluxTransformer2DModel | |
from flux.pipeline_flux_chameleon import FluxPipeline | |
import torch.nn as nn | |
import math | |
import logging | |
import sys | |
from huggingface_hub import snapshot_download | |
from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel | |
import spaces | |
# 设置日志 | |
logging.basicConfig( | |
level=logging.INFO, | |
format='%(asctime)s - %(levelname)s - %(message)s', | |
handlers=[logging.StreamHandler(sys.stdout)] | |
) | |
logger = logging.getLogger(__name__) | |
MODEL_ID = "Djrango/Qwen2vl-Flux" | |
MODEL_CACHE_DIR = "model_cache" | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
dtype = torch.bfloat16 | |
# 预下载模型 | |
if not os.path.exists(MODEL_CACHE_DIR): | |
logger.info("Starting model download...") | |
try: | |
snapshot_download( | |
repo_id=MODEL_ID, | |
local_dir=MODEL_CACHE_DIR, | |
local_dir_use_symlinks=False | |
) | |
logger.info("Model download completed successfully") | |
except Exception as e: | |
logger.error(f"Error downloading models: {str(e)}") | |
raise | |
# 加载小模型到 GPU | |
logger.info("Loading small models to GPU...") | |
tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer")) | |
text_encoder = CLIPTextModel.from_pretrained( | |
os.path.join(MODEL_CACHE_DIR, "flux/text_encoder") | |
).to(dtype).to(device) | |
text_encoder_two = T5EncoderModel.from_pretrained( | |
os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2") | |
).to(dtype).to(device) | |
tokenizer_two = T5TokenizerFast.from_pretrained( | |
os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2")) | |
# 大模型初始加载到 CPU | |
logger.info("Loading large models to CPU...") | |
vae = AutoencoderKL.from_pretrained( | |
os.path.join(MODEL_CACHE_DIR, "flux/vae") | |
).to(dtype).cpu() | |
transformer = FluxTransformer2DModel.from_pretrained( | |
os.path.join(MODEL_CACHE_DIR, "flux/transformer") | |
).to(dtype).cpu() | |
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained( | |
os.path.join(MODEL_CACHE_DIR, "flux/scheduler"), | |
shift=1 | |
) | |
qwen2vl = Qwen2VLSimplifiedModel.from_pretrained( | |
os.path.join(MODEL_CACHE_DIR, "qwen2-vl") | |
).to(dtype).cpu() | |
qwen2vl_processor = AutoProcessor.from_pretrained( | |
MODEL_ID, | |
subfolder="qwen2-vl", | |
min_pixels=256*28*28, | |
max_pixels=256*28*28 | |
) | |
# 加载 connector 和 embedder 到 CPU | |
class Qwen2Connector(nn.Module): | |
def __init__(self, input_dim=3584, output_dim=4096): | |
super().__init__() | |
self.linear = nn.Linear(input_dim, output_dim) | |
def forward(self, x): | |
return self.linear(x) | |
connector = Qwen2Connector().to(dtype).cpu() | |
connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt") | |
connector_state = torch.load(connector_path, map_location='cpu') | |
connector_state = {k.replace('module.', ''): v.to(dtype) for k, v in connector_state.items()} | |
connector.load_state_dict(connector_state) | |
t5_context_embedder = nn.Linear(4096, 3072).to(dtype).cpu() | |
t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt") | |
t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu') | |
t5_embedder_state = {k: v.to(dtype) for k, v in t5_embedder_state.items()} | |
t5_context_embedder.load_state_dict(t5_embedder_state) | |
# 创建pipeline (先用CPU上的模型) | |
pipeline = FluxPipeline( | |
transformer=transformer, | |
scheduler=scheduler, | |
vae=vae, | |
text_encoder=text_encoder, | |
tokenizer=tokenizer, | |
) | |
# 设置所有模型为eval模式 | |
for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl, | |
connector, t5_context_embedder]: | |
model.requires_grad_(False) | |
model.eval() | |
# Aspect ratio options | |
ASPECT_RATIOS = { | |
"1:1": (1024, 1024), | |
"16:9": (1344, 768), | |
"9:16": (768, 1344), | |
"2.4:1": (1536, 640), | |
"3:4": (896, 1152), | |
"4:3": (1152, 896), | |
} | |
def process_image(image): | |
"""Process image with Qwen2VL model""" | |
try: | |
# 将 Qwen2VL 相关模型移到 GPU | |
logger.info("Moving Qwen2VL models to GPU...") | |
qwen2vl.to(device) | |
connector.to(device) | |
message = [ | |
{ | |
"role": "user", | |
"content": [ | |
{"type": "image", "image": image}, | |
{"type": "text", "text": "Describe this image."}, | |
] | |
} | |
] | |
text = qwen2vl_processor.apply_chat_template( | |
message, | |
tokenize=False, | |
add_generation_prompt=True | |
) | |
with torch.no_grad(): | |
inputs = qwen2vl_processor( | |
text=[text], | |
images=[image], | |
padding=True, | |
return_tensors="pt" | |
).to(device) | |
output_hidden_state, image_token_mask, image_grid_thw = qwen2vl(**inputs) | |
image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1)) | |
image_hidden_state = connector(image_hidden_state) | |
# 保存结果到 CPU | |
result = (image_hidden_state.cpu(), image_grid_thw) | |
# 将模型移回 CPU 并清理显存 | |
logger.info("Moving Qwen2VL models back to CPU...") | |
qwen2vl.cpu() | |
connector.cpu() | |
torch.cuda.empty_cache() | |
return result | |
except Exception as e: | |
logger.error(f"Error in process_image: {str(e)}") | |
raise | |
def compute_t5_text_embeddings(prompt): | |
"""Compute T5 embeddings for text prompt""" | |
if prompt == "": | |
return None | |
text_inputs = tokenizer_two( | |
prompt, | |
padding="max_length", | |
max_length=256, | |
truncation=True, | |
return_tensors="pt" | |
).to(device) | |
prompt_embeds = text_encoder_two(text_inputs.input_ids)[0] | |
# 将 t5_context_embedder 移到 GPU | |
t5_context_embedder.to(device) | |
prompt_embeds = t5_context_embedder(prompt_embeds) | |
# 将 t5_context_embedder 移回 CPU | |
t5_context_embedder.cpu() | |
return prompt_embeds | |
def compute_text_embeddings(prompt=""): | |
"""Compute text embeddings for the prompt""" | |
with torch.no_grad(): | |
text_inputs = tokenizer( | |
prompt, | |
padding="max_length", | |
max_length=77, | |
truncation=True, | |
return_tensors="pt" | |
).to(device) | |
prompt_embeds = text_encoder( | |
text_inputs.input_ids, | |
output_hidden_states=False | |
) | |
return prompt_embeds.pooler_output | |
# 使用ZeroGPU装饰器 | |
def generate_images(input_image, prompt="", guidance_scale=3.5, | |
num_inference_steps=28, num_images=1, seed=None, aspect_ratio="1:1"): | |
"""Generate images using the pipeline""" | |
try: | |
logger.info(f"Starting generation with prompt: {prompt}") | |
if input_image is None: | |
raise ValueError("No input image provided") | |
if seed is not None: | |
torch.manual_seed(seed) | |
logger.info(f"Set random seed to: {seed}") | |
# Process image with Qwen2VL | |
qwen2_hidden_state, image_grid_thw = process_image(input_image) | |
# Compute text embeddings | |
pooled_prompt_embeds = compute_text_embeddings(prompt) | |
t5_prompt_embeds = compute_t5_text_embeddings(prompt) | |
# Get dimensions | |
width, height = ASPECT_RATIOS[aspect_ratio] | |
logger.info(f"Using dimensions: {width}x{height}") | |
# Generate images | |
try: | |
logger.info("Starting image generation...") | |
# 将 Transformer 和 VAE 移到 GPU | |
logger.info("Moving Transformer and VAE to GPU...") | |
transformer.to(device) | |
vae.to(device) | |
# 更新 pipeline 中的模型引用 | |
pipeline.transformer = transformer | |
pipeline.vae = vae | |
output_images = pipeline( | |
prompt_embeds=qwen2_hidden_state.to(device).repeat(num_images, 1, 1), | |
pooled_prompt_embeds=pooled_prompt_embeds, | |
t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None, | |
num_inference_steps=num_inference_steps, | |
guidance_scale=guidance_scale, | |
height=height, | |
width=width, | |
).images | |
logger.info("Image generation completed") | |
# 将 Transformer 和 VAE 移回 CPU | |
logger.info("Moving models back to CPU...") | |
transformer.cpu() | |
vae.cpu() | |
torch.cuda.empty_cache() | |
return output_images | |
except Exception as e: | |
raise RuntimeError(f"Error generating images: {str(e)}") | |
except Exception as e: | |
logger.error(f"Error during generation: {str(e)}") | |
raise gr.Error(f"Generation failed: {str(e)}") | |
# Create Gradio interface | |
with gr.Blocks( | |
theme=gr.themes.Soft(), | |
css=""" | |
.container { max-width: 1200px; margin: auto; padding: 0 20px; } | |
.header { text-align: center; margin: 20px 0 40px 0; padding: 20px; background: #f7f7f7; border-radius: 12px; } | |
.param-row { padding: 10px 0; } | |
footer { margin-top: 40px; padding: 20px; border-top: 1px solid #eee; } | |
""" | |
) as demo: | |
with gr.Column(elem_classes="container"): | |
gr.Markdown(""" | |
<div class="header"> | |
# 🎨 Qwen2vl-Flux Image Variation Demo | |
Generate creative variations of your images with optional text guidance | |
</div> | |
""") | |
with gr.Row(equal_height=True): | |
with gr.Column(scale=1): | |
input_image = gr.Image( | |
label="Upload Your Image", | |
type="pil", | |
height=384, | |
sources=["upload", "clipboard"] | |
) | |
with gr.Accordion("Advanced Settings", open=False): | |
with gr.Group(): | |
prompt = gr.Textbox( | |
label="Text Prompt (Optional)", | |
placeholder="As Long As Possible...", | |
lines=3 | |
) | |
with gr.Row(elem_classes="param-row"): | |
guidance = gr.Slider( | |
minimum=1, | |
maximum=10, | |
value=3.5, | |
step=0.5, | |
label="Guidance Scale" | |
) | |
steps = gr.Slider( | |
minimum=1, | |
maximum=30, | |
value=28, | |
step=1, | |
label="Sampling Steps" | |
) | |
with gr.Row(elem_classes="param-row"): | |
num_images = gr.Slider( | |
minimum=1, | |
maximum=2, | |
value=1, # 默认改为1 | |
step=1, | |
label="Number of Images" | |
) | |
seed = gr.Number( | |
label="Random Seed", | |
value=None, | |
precision=0 | |
) | |
aspect_ratio = gr.Radio( | |
label="Aspect Ratio", | |
choices=["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"], | |
value="1:1" | |
) | |
submit_btn = gr.Button("🎨 Generate", variant="primary", size="lg") | |
with gr.Column(scale=1): | |
output_gallery = gr.Gallery( | |
label="Generated Variations", | |
columns=2, | |
rows=2, | |
height=700, | |
object_fit="contain", | |
show_label=True, | |
allow_preview=True | |
) | |
submit_btn.click( | |
fn=generate_images, | |
inputs=[ | |
input_image, | |
prompt, | |
guidance, | |
steps, | |
num_images, | |
seed, | |
aspect_ratio | |
], | |
outputs=[output_gallery], | |
show_progress=True | |
) | |
if __name__ == "__main__": | |
demo.launch( | |
server_name="0.0.0.0", | |
server_port=7860, | |
share=False | |
) |