# Spaces: Running on Zero
import os
import shutil
import subprocess
import sys
from pathlib import Path

import gradio as gr
import spaces
from PIL import Image
# -----------------------------------------------------------------------------
# Project paths. Both are relative, so they resolve against the process's
# current working directory; adjust them to match your project layout.
# -----------------------------------------------------------------------------
INPUT_DIR = "samples"  # staging folder handed to inference_coz.py as -i
OUTPUT_DIR = "inference_results/coz_vlmprompt"  # result root handed over as -o
# -----------------------------------------------------------------------------
# HELPER FUNCTION TO RUN INFERENCE AND RETURN THE OUTPUT IMAGE
# -----------------------------------------------------------------------------
def _clear_directory(directory):
    """Create *directory* if missing, then delete everything inside it.

    Best effort: entries that cannot be removed are reported and skipped.
    """
    os.makedirs(directory, exist_ok=True)
    for entry in os.listdir(directory):
        full_path = os.path.join(directory, entry)
        try:
            # Links are checked first so a symlink to a directory is unlinked
            # rather than followed by rmtree.
            if os.path.isfile(full_path) or os.path.islink(full_path):
                os.remove(full_path)
            elif os.path.isdir(full_path):
                shutil.rmtree(full_path)
        except OSError as e:
            print(f"Warning: could not delete {full_path}: {e}")


def _list_pngs(directory):
    """Return paths of the ``.png`` files (any case) directly inside *directory*.

    A missing directory yields an empty list.
    """
    if not os.path.isdir(directory):
        return []
    return [
        os.path.join(directory, entry)
        for entry in os.listdir(directory)
        if entry.lower().endswith(".png")
    ]


def _load_rgb(path):
    """Decode the image at *path* into an in-memory RGB ``PIL.Image``.

    The ``with`` block closes the underlying file handle; ``convert`` returns
    a new image, so the result remains usable after the file is closed.
    """
    with Image.open(path) as img:
        return img.convert("RGB")


def run_with_upload(uploaded_image_path):
    """Run inference_coz.py on one uploaded image and return the result.

    Steps:
      1) Clear out INPUT_DIR (so old samples don't linger).
      2) Re-encode the upload as INPUT_DIR/input.png.
      3) Delete stale PNGs in OUTPUT_DIR/recursive, then run inference_coz.py
         (reads from -i INPUT_DIR, writes under -o OUTPUT_DIR).
      4) Pick the most recently modified PNG in OUTPUT_DIR/recursive.
      5) Return it as a PIL.Image, which Gradio will display.

    Args:
        uploaded_image_path: Local path supplied by the ``gr.Image`` component
            (``type="filepath"``), or ``None`` when nothing was uploaded.

    Returns:
        An RGB ``PIL.Image.Image`` on success, otherwise ``None``. The reason
        for a failure is printed to stdout.

    Note:
        INPUT_DIR and OUTPUT_DIR are shared by every request, so this function
        assumes it is not called concurrently.
    """
    # NOTE(review): `spaces` is imported at the top of the file but never used.
    # If this app runs on ZeroGPU hardware, GPU work is normally wrapped in
    # `@spaces.GPU`. Confirm whether this handler, which launches the GPU job
    # in a subprocess, needs it.

    # Nothing to do without an upload. Check before touching the filesystem.
    if uploaded_image_path is None:
        return None

    # 1) Start from an empty INPUT_DIR so earlier samples are not re-processed.
    _clear_directory(INPUT_DIR)

    # 2) Decode with PIL (handles JPEG, BMP, TIFF, etc.) and store as PNG.
    try:
        pil_img = _load_rgb(uploaded_image_path)
    except Exception as e:  # untrusted upload; PIL raises several unrelated types
        print(f"Error: could not open uploaded image: {e}")
        return None
    save_path = Path(INPUT_DIR) / "input.png"
    try:
        pil_img.save(save_path, format="PNG")
    except (OSError, ValueError) as e:
        print(f"Error: could not save as PNG: {e}")
        return None

    # 3) Results are read back from OUTPUT_DIR/recursive. Remove PNGs left by
    #    earlier requests, so a run that writes nothing cannot surface an old
    #    (possibly another user's) image.
    recursive_dir = os.path.join(OUTPUT_DIR, "recursive")
    for stale_png in _list_pngs(recursive_dir):
        try:
            os.remove(stale_png)
        except OSError as e:
            print(f"Warning: could not delete {stale_png}: {e}")

    # sys.executable runs the script with this app's own interpreter/virtualenv
    # instead of whichever "python" is first on PATH.
    cmd = [
        sys.executable, "inference_coz.py",
        "-i", INPUT_DIR,
        "-o", OUTPUT_DIR,
        "--rec_type", "recursive_multiscale",
        "--prompt_type", "vlm",
        "--upscale", "2",
        "--lora_path", "ckpt/SR_LoRA/model_20001.pkl",
        "--vae_path", "ckpt/SR_VAE/vae_encoder_20001.pt",
        "--pretrained_model_name_or_path", "stabilityai/stable-diffusion-3-medium-diffusers",
        "--ram_ft_path", "ckpt/DAPE/DAPE.pth",
        "--ram_path", "ckpt/RAM/ram_swin_large_14m.pth",
    ]
    try:
        # Blocks until the script exits.
        subprocess.run(cmd, check=True)
    except (subprocess.CalledProcessError, OSError) as err:
        # Non-zero exit status, or the interpreter process could not be started.
        print("Inference failed:", err)
        return None

    # 4) Pick the most recently modified PNG the run produced.
    png_files = _list_pngs(recursive_dir)
    if not png_files:
        return None
    latest_png = max(png_files, key=os.path.getmtime)

    # 5) Gradio displays the returned PIL.Image directly.
    try:
        return _load_rgb(latest_png)
    except Exception as e:
        print(f"Error opening {latest_png}: {e}")
        return None
# -----------------------------------------------------------------------------
# BUILD THE GRADIO INTERFACE
# -----------------------------------------------------------------------------
with gr.Blocks() as demo:
    gr.Markdown("## Upload an image, then click **Run Inference** to process it.")

    # Input: type="filepath" hands run_with_upload a path on local disk
    # rather than decoded pixel data.
    upload_image = gr.Image(type="filepath", label="Upload your input image")

    # Trigger for a (blocking) inference run.
    run_button = gr.Button("Run Inference")

    # Output: run_with_upload returns a PIL.Image (or None), hence type="pil".
    output_image = gr.Image(type="pil", label="Inference Result")

    # Positional order is (fn, inputs, outputs): the uploaded path goes in and
    # the returned image is shown in output_image.
    # NOTE(review): run_with_upload reuses shared INPUT_DIR/OUTPUT_DIR, so it
    # relies on this event not running concurrently. This is believed to be
    # Gradio 4's default concurrency_limit=1; confirm for the installed version.
    run_button.click(run_with_upload, upload_image, output_image)
# -----------------------------------------------------------------------------
# START THE GRADIO SERVER
# -----------------------------------------------------------------------------
# Launch only when executed as a script, so importing this module (tests,
# tooling) builds `demo` without starting a server.
if __name__ == "__main__":
    # share=True requests a public Gradio share link for local runs.
    # NOTE(review): Gradio is expected to ignore share on Hugging Face Spaces;
    # confirm for the deployed version.
    demo.launch(share=True)