import os
import tempfile
import uuid
import torch
from PIL import Image
from torchvision import transforms
from transformers import Qwen2_5_VLForConditionalGeneration, AutoProcessor
from qwen_vl_utils import process_vision_info
from osediff_sd3 import OSEDiff_SD3_TEST, SD3Euler
from peft import PeftModel
# -------------------------------------------------------------------
# Helper: Resize & center-crop to a fixed square
# -------------------------------------------------------------------
def resize_and_center_crop(img: Image.Image, size: int) -> Image.Image:
    w, h = img.size
    scale = size / min(w, h)
    new_w, new_h = int(w * scale), int(h * scale)
    img = img.resize((new_w, new_h), Image.LANCZOS)
    left = (new_w - size) // 2
    top = (new_h - size) // 2
    return img.crop((left, top, left + size, top + size))
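
# Illustrative example (hypothetical sizes, not from the original app): a 1024x768
# input is first scaled so its short side equals `size` (1024x768 -> 682x512),
# then center-cropped to a 512x512 square.
#
#   square = resize_and_center_crop(Image.open("example.png").convert("RGB"), 512)
#   assert square.size == (512, 512)
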
# -------------------------------------------------------------------
# Helper: Generate a single VLM prompt for recursive_multiscale
# -------------------------------------------------------------------
def _generate_vlm_prompt(
    vlm_model: Qwen2_5_VLForConditionalGeneration,
    vlm_processor: AutoProcessor,
    process_vision_info,          # helper that turns `messages` -> image_inputs / video_inputs
    prev_pil: Image.Image,        # <- pass PIL instead of a file path
    zoomed_pil: Image.Image,      # <- pass PIL instead of a file path
    device: str = "cuda",
) -> str:
    """
    Given two PIL.Image inputs:
      - prev_pil:   the "full" image from the previous recursion step.
      - zoomed_pil: the cropped+resized (zoomed) image for this step.
    Returns a single "recursive_multiscale" prompt string.
    """
    # (1) System message
    message_text = (
        "The second image is a zoom-in of the first image. "
        "Based on this knowledge, what is in the second image? "
        "Give me a set of words."
    )

    # (2) Build the two-image "chat" payload.
    #
    # Instead of passing a filename, we pass the actual PIL.Image.
    # `process_vision_info` should know how to turn a message of the form
    # {"type": "image", "image": PIL_IMAGE} into tensors.
    messages = [
        {"role": "system", "content": message_text},
        {
            "role": "user",
            "content": [
                {"type": "image", "image": prev_pil},
                {"type": "image", "image": zoomed_pil},
            ],
        },
    ]

    # (3) Run the "chat" through the VL processor.
    #
    # - `apply_chat_template` builds the tokenized prompt (without running it yet).
    # - `process_vision_info` inspects the same `messages` list and returns
    #   `image_inputs` and `video_inputs` (tensors) for any attached PIL images.
    text = vlm_processor.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    image_inputs, video_inputs = process_vision_info(messages)
    inputs = vlm_processor(
        text=[text],
        images=image_inputs,
        videos=video_inputs,
        padding=True,
        return_tensors="pt",
    ).to(device)

    # (4) Generate and decode, keeping only the newly generated tokens.
    generated = vlm_model.generate(**inputs, max_new_tokens=128)
    trimmed = [
        out_ids[len(in_ids):]
        for in_ids, out_ids in zip(inputs.input_ids, generated)
    ]
    out_text = vlm_processor.batch_decode(
        trimmed, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]
    return out_text.strip()
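
# Illustrative usage sketch (assumes `vlm_model` / `vlm_processor` are loaded as
# below and that `prev_pil` / `zoomed_pil` are 512x512 PIL images): the helper
# returns a short keyword-style description of the zoomed view, which is later
# used as the super-resolution prompt.
#
#   prompt = _generate_vlm_prompt(
#       vlm_model, vlm_processor, process_vision_info,
#       prev_pil=prev_pil, zoomed_pil=zoomed_pil, device="cuda",
#   )
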
VLM_NAME = "Qwen/Qwen2.5-VL-3B-Instruct"
vlm_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    VLM_NAME,
    torch_dtype="auto",
    device_map="auto",  # immediately dispatches layers onto available GPUs
)
vlm_processor = AutoProcessor.from_pretrained(VLM_NAME)

# Attach the fine-tuned VLM LoRA adapter and merge it into the base weights.
vlm_model = PeftModel.from_pretrained(vlm_model, "ckpt/VLM_LoRA/checkpoint-10000")
vlm_model = vlm_model.merge_and_unload()
vlm_model.eval()
device = "cuda"
process_size = 512

LORA_PATH = "ckpt/SR_LoRA/model_20001.pkl"
VAE_PATH = "ckpt/SR_VAE/vae_encoder_20001.pt"
SD3_MODEL = "stabilityai/stable-diffusion-3-medium-diffusers"

# Minimal stand-in for the argparse namespace expected by OSEDiff_SD3_TEST.
class _Args:
    pass

args = _Args()
args.upscale = 4
args.lora_path = LORA_PATH
args.vae_path = VAE_PATH
args.pretrained_model_name_or_path = SD3_MODEL
args.merge_and_unload_lora = False
args.lora_rank = 4
args.vae_decoder_tiled_size = 224
args.vae_encoder_tiled_size = 1024
args.latent_tiled_size = 96
args.latent_tiled_overlap = 32
args.mixed_precision = "fp16"
args.efficient_memory = False
# Load the SD3 pipeline, move every sub-module to the GPU, and freeze it for inference.
sd3 = SD3Euler()
sd3.text_enc_1.to(device)
sd3.text_enc_2.to(device)
sd3.text_enc_3.to(device)
sd3.transformer.to(device, dtype=torch.float32)
sd3.vae.to(device, dtype=torch.float32)
for p in (
    sd3.text_enc_1,
    sd3.text_enc_2,
    sd3.text_enc_3,
    sd3.transformer,
    sd3.vae,
):
    p.requires_grad_(False)

model_test = OSEDiff_SD3_TEST(args, sd3)
# -------------------------------------------------------------------
# Main Function: recursive_multiscale_sr (with multiple centers)
# -------------------------------------------------------------------
def recursive_multiscale_sr(
    input_png_path: str,
    upscale: int,
    rec_num: int = 4,
    centers: list[tuple[float, float]] | None = None,
) -> tuple[list[Image.Image], list[str]]:
| """ | |
| Perform `rec_num` recursive_multiscale super-resolution steps on a single PNG. | |
| - input_png_path: path to a single .png file on disk. | |
| - upscale: integer up-scale factor per recursion (e.g. 4). | |
| - rec_num: how many recursion steps to perform. | |
| - centers: a list of normalized (x, y) tuples in [0, 1], one per recursion step, | |
| indicating where to center the low-res crop for each step. The list | |
| length must equal rec_num. If centers is None, defaults to center=(0.5, 0.5) | |
| for all steps. | |
| Returns a tuple (sr_pil_list, prompt_list), where: | |
| - sr_pil_list: list of PIL.Image outputs [SR1, SR2, …, SR_rec_num] in order. | |
| - prompt_list: list of the VLM prompts generated at each recursion. | |
| """ | |
    ###############################
    # 0. Validate / fill default centers
    ###############################
    if centers is None:
        # Default: use the image center (0.5, 0.5) for every recursion step.
        centers = [(0.5, 0.5) for _ in range(rec_num)]
    else:
        if not isinstance(centers, (list, tuple)) or len(centers) != rec_num:
            raise ValueError(
                f"`centers` must be a list of {rec_num} (x,y) tuples, but got length {len(centers)}."
            )

    unique_id = uuid.uuid4().hex
    prefix = f"recms_{unique_id}_"
    with tempfile.TemporaryDirectory(prefix=prefix) as td:
        # Load the input image and normalize it to the working resolution.
        img0 = Image.open(input_png_path).convert("RGB")
        img0 = resize_and_center_crop(img0, process_size)
        prev_pil = img0.copy()

        sr_pil_list: list[Image.Image] = []
        prompt_list: list[str] = []

        for rec in range(rec_num):
            # (A) Pick the low-res crop window around the requested center,
            #     clamped so it stays inside the image.
            w, h = prev_pil.size  # (512, 512)
            new_w, new_h = w // upscale, h // upscale
            cx_norm, cy_norm = centers[rec]
            cx = int(cx_norm * w)
            cy = int(cy_norm * h)
            half_w, half_h = new_w // 2, new_h // 2
            left = max(0, min(cx - half_w, w - new_w))
            top = max(0, min(cy - half_h, h - new_h))
            right, bottom = left + new_w, top + new_h

            # (B) Crop and upsample back to the working resolution (the "zoomed" input).
            cropped = prev_pil.crop((left, top, right, bottom))
            zoomed_pil = cropped.resize((w, h), Image.BICUBIC)

            # (C) Ask the VLM for a prompt describing the zoomed view.
            prompt_tag = _generate_vlm_prompt(
                vlm_model=vlm_model,
                vlm_processor=vlm_processor,
                process_vision_info=process_vision_info,
                prev_pil=prev_pil,      # <- PIL
                zoomed_pil=zoomed_pil,  # <- PIL
                device=device,
            )

            # (D) Run the one-step SR model on the zoomed crop, scaled to [-1, 1].
            to_tensor = transforms.ToTensor()
            lq = to_tensor(zoomed_pil).unsqueeze(0).to(device)  # (1, 3, 512, 512)
            lq = (lq * 2.0) - 1.0
            with torch.no_grad():
                out_tensor = model_test(lq, prompt=prompt_tag)[0]
                out_tensor = out_tensor.clamp(-1.0, 1.0).cpu()
            out_pil = transforms.ToPILImage()((out_tensor * 0.5) + 0.5)

            # (E) The SR output becomes the "previous" image for the next recursion.
            prev_pil = out_pil

            # (F) Append to results.
            sr_pil_list.append(out_pil)
            prompt_list.append(prompt_tag)

    return sr_pil_list, prompt_list
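

# Illustrative entry point (not part of the original app): run four 4x zoom steps
# on a hypothetical input file, zooming toward the top-left quadrant first and the
# image center afterwards. "input.png" and the chosen centers are placeholders.
if __name__ == "__main__":
    sr_images, prompts = recursive_multiscale_sr(
        "input.png",  # hypothetical path
        upscale=4,
        rec_num=4,
        centers=[(0.25, 0.25), (0.5, 0.5), (0.5, 0.5), (0.5, 0.5)],
    )
    for i, (img, prompt) in enumerate(zip(sr_images, prompts), start=1):
        img.save(f"sr_step_{i}.png")
        print(f"step {i}: {prompt}")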