import os
import shutil
import subprocess
import sys
from pathlib import Path

import gradio as gr
from PIL import Image
import spaces

# -----------------------------------------------------------------------------
# CONFIGURE THESE PATHS TO MATCH YOUR PROJECT STRUCTURE
# -----------------------------------------------------------------------------
INPUT_DIR = "samples"
OUTPUT_DIR = "inference_results/coz_vlmprompt"


# -----------------------------------------------------------------------------
# HELPER FUNCTIONS TO RUN INFERENCE AND RETURN THE OUTPUT IMAGES
# -----------------------------------------------------------------------------
def _clear_directory(directory):
    """Create *directory* if needed, then delete everything inside it.

    Best effort: an entry that cannot be removed is reported and skipped so
    that one undeletable file does not abort the whole request.
    """
    os.makedirs(directory, exist_ok=True)
    for name in os.listdir(directory):
        full_path = os.path.join(directory, name)
        try:
            if os.path.isfile(full_path) or os.path.islink(full_path):
                os.remove(full_path)
            elif os.path.isdir(full_path):
                shutil.rmtree(full_path)
        except OSError as e:
            print(f"Warning: could not delete {full_path}: {e}")


def _load_output_images(paths):
    """Open every file in *paths* as an RGB PIL image, preserving order.

    Returns the list of images, or None (after printing the reason) if any
    file is missing or cannot be decoded.
    """
    images = []
    for fp in paths:
        if not fp.is_file():
            print(f"Warning: expected file not found: {fp}")
            return None
        try:
            # The context manager closes the file handle; convert() returns a
            # fully loaded, independent copy of the pixel data.
            with Image.open(fp) as im:
                images.append(im.convert("RGB"))
        except Exception as e:
            print(f"Error opening {fp}: {e}")
            return None
    return images


@spaces.GPU()
def run_with_upload(uploaded_image_path, upscale_option):
    """Run Chain-of-Zoom inference on one uploaded image.

    Steps:
      1) Clear INPUT_DIR (creating it if necessary).
      2) Convert the upload to RGB and save it as ``INPUT_DIR/input.png``.
      3) Turn ``upscale_option`` ("1x", "2x", "4x") into "1", "2" or "4" and
         run ``inference_coz.py`` with ``--upscale <value>``.
      4) Load ``1.png``..``4.png`` from ``OUTPUT_DIR/per-sample/input``.

    Args:
        uploaded_image_path: Local path to the uploaded file (the UI uses
            ``gr.Image(type="filepath")``), or None if nothing was uploaded.
        upscale_option: Upscale choice string such as "2x".

    Returns:
        A list of four RGB PIL images, or None on any failure. The reason for
        a failure is printed to stdout.
    """
    if uploaded_image_path is None:
        return None
    if not upscale_option:
        print("Error: no upscale factor selected.")
        return None

    # 1) Start from an empty input directory so only this upload is processed.
    _clear_directory(INPUT_DIR)

    # 2) Re-encode through PIL so any format it can read (JPEG, BMP, TIFF, ...)
    #    reaches the inference script as a 3-channel PNG.
    try:
        with Image.open(uploaded_image_path) as src:
            pil_img = src.convert("RGB")
    except Exception as e:
        print(f"Error: could not open uploaded image: {e}")
        return None

    save_path = Path(INPUT_DIR) / "input.png"
    try:
        pil_img.save(save_path, format="PNG")
    except Exception as e:
        print(f"Error: could not save as PNG: {e}")
        return None

    # Remove outputs left by a previous request. Otherwise a run that fails to
    # write its results would silently return stale images.
    per_sample_dir = Path(OUTPUT_DIR) / "per-sample" / "input"
    expected_files = [per_sample_dir / f"{i}.png" for i in range(1, 5)]
    for stale in expected_files:
        try:
            stale.unlink(missing_ok=True)
        except OSError as e:
            print(f"Warning: could not delete {stale}: {e}")

    # 3) Run inference; this blocks until the script exits.
    upscale_value = upscale_option.replace("x", "")  # e.g. "2x" -> "2"
    cmd = [
        # Use the interpreter running this app so the subprocess sees the same
        # environment; a bare "python" may resolve to another interpreter on
        # PATH, or to none at all.
        sys.executable, "inference_coz.py",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", upscale_value,
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]
    try:
        subprocess.run(cmd, check=True)
    except subprocess.CalledProcessError as err:
        print("Inference failed:", err)
        return None

    # 4) Hand the numbered outputs to the gallery in order.
    return _load_output_images(expected_files)


# -----------------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE
# -----------------------------------------------------------------------------
css = """
#col-container {
    margin: 0 auto;
    max-width: 1024px;
}
"""

with gr.Blocks(css=css) as demo:
    gr.HTML(
        """

Chain-of-Zoom

Extreme Super-Resolution via Scale Autoregression and Preference Alignment


"""
    )
    with gr.Column(elem_id="col-container"):
        with gr.Row():
            with gr.Column():
                # 1) Image upload. type="filepath" makes the callback
                #    (run_with_upload) receive a local path to the file.
                upload_image = gr.Image(
                    label="Upload your input image",
                    type="filepath",
                )
                # 2) Upscale factor forwarded to inference_coz.py.
                #    The label is kept for accessibility even though it is hidden.
                upscale_radio = gr.Radio(
                    choices=["1x", "2x", "4x"],
                    value="2x",
                    label="Upscale factor",
                    show_label=False,
                )
                # 3) Button that launches inference.
                run_button = gr.Button("Chain-of-Zoom it")

            # 4) 2x2 gallery for the images returned by run_with_upload.
            output_gallery = gr.Gallery(
                label="Inference Results",
                show_label=True,
                elem_id="gallery",
                columns=[2],
                rows=[2],
            )

    # Clicking the button calls run_with_upload(image_path, upscale_choice)
    # and routes its return value into the gallery.
    run_button.click(
        fn=run_with_upload,
        inputs=[upload_image, upscale_radio],
        outputs=output_gallery,
    )

# -----------------------------------------------------------------------------
# START THE GRADIO SERVER
# -----------------------------------------------------------------------------
# Guarded so that importing this module (e.g. from tooling or tests) does not
# start a server as a side effect.
if __name__ == "__main__":
    demo.launch(share=True)