import gradio as gr
import torch
from PIL import Image
import os
from transformers import CLIPTokenizer, CLIPTextModel, AutoProcessor, T5EncoderModel, T5TokenizerFast
from diffusers import AutoencoderKL, FlowMatchEulerDiscreteScheduler
from flux.transformer_flux import FluxTransformer2DModel
from flux.pipeline_flux_chameleon import FluxPipeline
import torch.nn as nn
import math
import logging
import sys
from huggingface_hub import snapshot_download
from qwen2_vl.modeling_qwen2_vl import Qwen2VLSimplifiedModel
import spaces
# Set up logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[logging.StreamHandler(sys.stdout)]
)
logger = logging.getLogger(__name__)
MODEL_ID = "Djrango/Qwen2vl-Flux"
MODEL_CACHE_DIR = "model_cache"
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.bfloat16
# Pre-download the model weights if they are not cached yet
if not os.path.exists(MODEL_CACHE_DIR):
logger.info("Starting model download...")
try:
snapshot_download(
repo_id=MODEL_ID,
local_dir=MODEL_CACHE_DIR,
local_dir_use_symlinks=False
)
logger.info("Model download completed successfully")
except Exception as e:
logger.error(f"Error downloading models: {str(e)}")
raise
# Load the small models directly onto the GPU
logger.info("Loading small models to GPU...")
tokenizer = CLIPTokenizer.from_pretrained(os.path.join(MODEL_CACHE_DIR, "flux/tokenizer"))
text_encoder = CLIPTextModel.from_pretrained(
os.path.join(MODEL_CACHE_DIR, "flux/text_encoder")
).to(dtype).to(device)
text_encoder_two = T5EncoderModel.from_pretrained(
os.path.join(MODEL_CACHE_DIR, "flux/text_encoder_2")
).to(dtype).to(device)
tokenizer_two = T5TokenizerFast.from_pretrained(
os.path.join(MODEL_CACHE_DIR, "flux/tokenizer_2"))
# Load the large models on the CPU first; they are moved to the GPU only while needed
logger.info("Loading large models to CPU...")
vae = AutoencoderKL.from_pretrained(
os.path.join(MODEL_CACHE_DIR, "flux/vae")
).to(dtype).cpu()
transformer = FluxTransformer2DModel.from_pretrained(
os.path.join(MODEL_CACHE_DIR, "flux/transformer")
).to(dtype).cpu()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained(
os.path.join(MODEL_CACHE_DIR, "flux/scheduler"),
shift=1
)
qwen2vl = Qwen2VLSimplifiedModel.from_pretrained(
os.path.join(MODEL_CACHE_DIR, "qwen2-vl")
).to(dtype).cpu()
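# Setting min_pixels == max_pixels pins every image to the same resized area
# (256 * 28 * 28 px), so the vision encoder emits a roughly fixed visual-token budget.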
qwen2vl_processor = AutoProcessor.from_pretrained(
MODEL_ID,
subfolder="qwen2-vl",
min_pixels=256*28*28,
max_pixels=256*28*28
)
# Load the connector and embedder on the CPU
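# The connector is a single linear layer mapping Qwen2-VL hidden states (3584-dim)
# into a 4096-dim space, matching the width of the T5 text embeddings.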
class Qwen2Connector(nn.Module):
def __init__(self, input_dim=3584, output_dim=4096):
super().__init__()
self.linear = nn.Linear(input_dim, output_dim)
def forward(self, x):
return self.linear(x)
connector = Qwen2Connector().to(dtype).cpu()
connector_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/connector.pt")
connector_state = torch.load(connector_path, map_location='cpu')
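# The checkpoint keys carry a 'module.' prefix (presumably saved from a
# DataParallel/DDP-wrapped module), so strip it before loading.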
connector_state = {k.replace('module.', ''): v.to(dtype) for k, v in connector_state.items()}
connector.load_state_dict(connector_state)
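# t5_context_embedder projects the 4096-dim T5 encoder outputs down to 3072 dims
# (the Flux transformer's inner width) before they are passed as t5_prompt_embeds.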
t5_context_embedder = nn.Linear(4096, 3072).to(dtype).cpu()
t5_embedder_path = os.path.join(MODEL_CACHE_DIR, "qwen2-vl/t5_embedder.pt")
t5_embedder_state = torch.load(t5_embedder_path, map_location='cpu')
t5_embedder_state = {k: v.to(dtype) for k, v in t5_embedder_state.items()}
t5_context_embedder.load_state_dict(t5_embedder_state)
# Create the pipeline (using the CPU-resident models for now)
pipeline = FluxPipeline(
transformer=transformer,
scheduler=scheduler,
vae=vae,
text_encoder=text_encoder,
tokenizer=tokenizer,
)
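# Note that text_encoder_2 / tokenizer_2 are not handed to the pipeline: the T5
# branch is run manually (compute_t5_text_embeddings) and supplied as
# t5_prompt_embeds at call time.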
# Freeze all models and put them in eval mode
for model in [text_encoder, text_encoder_two, vae, transformer, qwen2vl,
connector, t5_context_embedder]:
model.requires_grad_(False)
model.eval()
# Aspect ratio options
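# All presets keep the pixel count close to 1 MP and use side lengths divisible by 64.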
ASPECT_RATIOS = {
"1:1": (1024, 1024),
"16:9": (1344, 768),
"9:16": (768, 1344),
"2.4:1": (1536, 640),
"3:4": (896, 1152),
"4:3": (1152, 896),
}
def process_image(image):
"""Process image with Qwen2VL model"""
try:
        # Move the Qwen2VL-related models to the GPU
logger.info("Moving Qwen2VL models to GPU...")
qwen2vl.to(device)
connector.to(device)
message = [
{
"role": "user",
"content": [
{"type": "image", "image": image},
{"type": "text", "text": "Describe this image."},
]
}
]
text = qwen2vl_processor.apply_chat_template(
message,
tokenize=False,
add_generation_prompt=True
)
with torch.no_grad():
inputs = qwen2vl_processor(
text=[text],
images=[image],
padding=True,
return_tensors="pt"
).to(device)
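            # The simplified Qwen2-VL forward returns hidden states, a mask over
            # image-token positions, and the image grid shape; keep only the image
            # tokens and project them into the prompt-embedding space via the connector.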
output_hidden_state, image_token_mask, image_grid_thw = qwen2vl(**inputs)
image_hidden_state = output_hidden_state[image_token_mask].view(1, -1, output_hidden_state.size(-1))
image_hidden_state = connector(image_hidden_state)
        # Move the result to the CPU
result = (image_hidden_state.cpu(), image_grid_thw)
        # Move the models back to the CPU and free GPU memory
logger.info("Moving Qwen2VL models back to CPU...")
qwen2vl.cpu()
connector.cpu()
torch.cuda.empty_cache()
return result
except Exception as e:
logger.error(f"Error in process_image: {str(e)}")
raise
def compute_t5_text_embeddings(prompt):
"""Compute T5 embeddings for text prompt"""
if prompt == "":
return None
text_inputs = tokenizer_two(
prompt,
padding="max_length",
max_length=256,
truncation=True,
return_tensors="pt"
).to(device)
prompt_embeds = text_encoder_two(text_inputs.input_ids)[0]
    # Move t5_context_embedder to the GPU
t5_context_embedder.to(device)
prompt_embeds = t5_context_embedder(prompt_embeds)
    # Move t5_context_embedder back to the CPU
t5_context_embedder.cpu()
return prompt_embeds
def compute_text_embeddings(prompt=""):
"""Compute text embeddings for the prompt"""
with torch.no_grad():
text_inputs = tokenizer(
prompt,
padding="max_length",
max_length=77,
truncation=True,
return_tensors="pt"
).to(device)
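        # Flux conditions on CLIP only through its pooled embedding, so just the
        # pooler_output is returned here.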
prompt_embeds = text_encoder(
text_inputs.input_ids,
output_hidden_states=False
)
return prompt_embeds.pooler_output
@spaces.GPU(duration=120)  # ZeroGPU: a GPU is attached only for the duration of this call (up to 120 s)
def generate_images(input_image, prompt="", guidance_scale=3.5,
num_inference_steps=28, num_images=1, seed=None, aspect_ratio="1:1"):
"""Generate images using the pipeline"""
try:
logger.info(f"Starting generation with prompt: {prompt}")
if input_image is None:
raise ValueError("No input image provided")
if seed is not None:
torch.manual_seed(seed)
logger.info(f"Set random seed to: {seed}")
# Process image with Qwen2VL
qwen2_hidden_state, image_grid_thw = process_image(input_image)
# Compute text embeddings
pooled_prompt_embeds = compute_text_embeddings(prompt)
t5_prompt_embeds = compute_t5_text_embeddings(prompt)
# Get dimensions
width, height = ASPECT_RATIOS[aspect_ratio]
logger.info(f"Using dimensions: {width}x{height}")
# Generate images
try:
logger.info("Starting image generation...")
            # Move the Transformer and VAE to the GPU
logger.info("Moving Transformer and VAE to GPU...")
transformer.to(device)
vae.to(device)
            # Update the model references held by the pipeline
pipeline.transformer = transformer
pipeline.vae = vae
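            # The Qwen2-VL image embedding (and the optional T5 prompt embedding) is
            # tiled along the batch dimension so every requested image shares the same conditioning.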
output_images = pipeline(
prompt_embeds=qwen2_hidden_state.to(device).repeat(num_images, 1, 1),
pooled_prompt_embeds=pooled_prompt_embeds,
t5_prompt_embeds=t5_prompt_embeds.repeat(num_images, 1, 1) if t5_prompt_embeds is not None else None,
num_inference_steps=num_inference_steps,
guidance_scale=guidance_scale,
height=height,
width=width,
).images
logger.info("Image generation completed")
            # Move the Transformer and VAE back to the CPU
logger.info("Moving models back to CPU...")
transformer.cpu()
vae.cpu()
torch.cuda.empty_cache()
return output_images
except Exception as e:
raise RuntimeError(f"Error generating images: {str(e)}")
except Exception as e:
logger.error(f"Error during generation: {str(e)}")
raise gr.Error(f"Generation failed: {str(e)}")
# Create Gradio interface
with gr.Blocks(
theme=gr.themes.Soft(),
css="""
.container { max-width: 1200px; margin: auto; padding: 0 20px; }
.header { text-align: center; margin: 20px 0 40px 0; padding: 20px; background: #f7f7f7; border-radius: 12px; }
.param-row { padding: 10px 0; }
footer { margin-top: 40px; padding: 20px; border-top: 1px solid #eee; }
"""
) as demo:
with gr.Column(elem_classes="container"):
gr.Markdown("""
<div class="header">
# 🎨 Qwen2vl-Flux Image Variation Demo
Generate creative variations of your images with optional text guidance
</div>
""")
with gr.Row(equal_height=True):
with gr.Column(scale=1):
input_image = gr.Image(
label="Upload Your Image",
type="pil",
height=384,
sources=["upload", "clipboard"]
)
with gr.Accordion("Advanced Settings", open=False):
with gr.Group():
prompt = gr.Textbox(
label="Text Prompt (Optional)",
placeholder="As Long As Possible...",
lines=3
)
with gr.Row(elem_classes="param-row"):
guidance = gr.Slider(
minimum=1,
maximum=10,
value=3.5,
step=0.5,
label="Guidance Scale"
)
steps = gr.Slider(
minimum=1,
maximum=30,
value=28,
step=1,
label="Sampling Steps"
)
with gr.Row(elem_classes="param-row"):
num_images = gr.Slider(
minimum=1,
maximum=2,
                                value=1,  # default is a single image
step=1,
label="Number of Images"
)
seed = gr.Number(
label="Random Seed",
value=None,
precision=0
)
aspect_ratio = gr.Radio(
label="Aspect Ratio",
choices=["1:1", "16:9", "9:16", "2.4:1", "3:4", "4:3"],
value="1:1"
)
submit_btn = gr.Button("🎨 Generate", variant="primary", size="lg")
with gr.Column(scale=1):
output_gallery = gr.Gallery(
label="Generated Variations",
columns=2,
rows=2,
height=700,
object_fit="contain",
show_label=True,
allow_preview=True
)
submit_btn.click(
fn=generate_images,
inputs=[
input_image,
prompt,
guidance,
steps,
num_images,
seed,
aspect_ratio
],
outputs=[output_gallery],
show_progress=True
)
if __name__ == "__main__":
demo.launch(
server_name="0.0.0.0",
server_port=7860,
share=False
)