import os
import shutil
import subprocess
import sys
import tempfile
import threading
from pathlib import Path

import gradio as gr
from PIL import Image

RIFE_DIR = Path("rife")  # Local bundled repo
OUTPUT_DIR = RIFE_DIR / "output"
# Serialises runs: every call shares OUTPUT_DIR, so two concurrent runs would
# clobber each other's frames.
LOCK = threading.Lock()


def _collect_frames(output_dir: Path) -> list:
    """Return the contiguous run of ``img<N>.png`` frame paths in *output_dir*.

    Numbering starts at ``img0.png`` when it exists (RIFE's inference_img.py
    writes the first input as frame 0 — TODO confirm against the bundled
    copy), otherwise at ``img1.png``. Collection stops at the first gap.
    """
    index = 0 if (output_dir / "img0.png").exists() else 1
    frames = []
    while True:
        candidate = output_dir / f"img{index}.png"
        if not candidate.exists():
            break
        frames.append(candidate)
        index += 1
    return frames


def _write_gif(frames: list, gif_path: Path, fps) -> None:
    """Encode *frames* (image file paths) as a looping GIF at *fps* frames/s."""
    images = []
    for frame_path in frames:
        # convert() returns a fully loaded copy, so the file can be closed
        # immediately instead of leaking one open handle per frame.
        with Image.open(frame_path) as im:
            images.append(im.convert("RGBA"))
    # Clamp so fps <= 0 can't divide by zero and duration is at least 1 ms.
    duration_ms = max(1, int(1000 / max(1, fps)))
    images[0].save(
        gif_path,
        save_all=True,
        append_images=images[1:],
        optimize=False,
        duration=duration_ms,
        loop=0,
        disposal=2,
    )


def interpolate(a: Image.Image, b: Image.Image, fps: int = 14, exp: int = 2):
    """Interpolate between two images with RIFE and return the path of a GIF.

    Args:
        a: Start image.
        b: End image; resized (bicubic) to ``a``'s size when they differ.
        fps: Playback rate of the resulting GIF.
        exp: Value forwarded to RIFE's ``--exp`` flag, only when it is >= 1.
            Gradio sliders may deliver floats, so numeric values are
            truncated to ``int`` first.

    Returns:
        ``str`` path of ``interpolation.gif`` inside a fresh temp directory.
        The directory is kept so Gradio can serve the file.

    Raises:
        gr.Error: if either image is missing.
        subprocess.CalledProcessError: if the RIFE script exits non-zero.
        RuntimeError: if RIFE produced no frames.
    """
    if a is None or b is None:
        raise gr.Error("Please provide both Image A and Image B.")

    # isinstance(2.0, int) is False, so without this a float from the slider
    # would silently drop --exp.
    exp_value = int(exp) if isinstance(exp, (int, float)) else None

    with LOCK:
        # Normalize inputs
        a = a.convert("RGB")
        b = b.convert("RGB")
        if b.size != a.size:
            b = b.resize(a.size, Image.BICUBIC)

        work_dir = Path(tempfile.mkdtemp(prefix="rife_run_"))
        p1 = work_dir / "a.png"
        p2 = work_dir / "b.png"
        try:
            a.save(p1, "PNG")
            b.save(p2, "PNG")

            # Clean previous outputs so stale frames are never picked up.
            if OUTPUT_DIR.exists():
                shutil.rmtree(OUTPUT_DIR)
            OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

            # The child runs with cwd=RIFE_DIR, so a path relative to *this*
            # process's cwd would be resolved against RIFE_DIR and point at
            # rife/rife/inference_img.py. Pass absolute paths instead, and use
            # the current interpreter so the app's environment is reused.
            script = (RIFE_DIR / "inference_img.py").resolve()
            python = sys.executable or "python3"
            cmd = [python, str(script), "--img", str(p1.resolve()), str(p2.resolve())]
            if exp_value is not None and exp_value >= 1:
                cmd += ["--exp", str(exp_value)]
            subprocess.run(cmd, cwd=str(RIFE_DIR), check=True)

            frames = _collect_frames(OUTPUT_DIR)
            if not frames:
                raise RuntimeError("No frames generated.")

            gif_path = work_dir / "interpolation.gif"
            _write_gif(frames, gif_path, fps)
        except BaseException:
            # Nothing will be returned, so don't leak the temp directory.
            shutil.rmtree(work_dir, ignore_errors=True)
            raise

        # Best-effort cleanup: frames are already in memory, and the input
        # PNGs are no longer needed (the GIF must stay for Gradio to serve).
        shutil.rmtree(OUTPUT_DIR, ignore_errors=True)
        for leftover in (p1, p2):
            try:
                leftover.unlink()
            except OSError:
                pass

        return str(gif_path)


# Gradio UI
TITLE = "🔥 RIFE Interpolation Demo (PyTorch, Local)"

with gr.Blocks(title=TITLE, analytics_enabled=False) as demo:
    gr.Markdown(f"# {TITLE}")
    with gr.Row():
        with gr.Column():
            img_a = gr.Image(type="pil", label="Image A")
            img_b = gr.Image(type="pil", label="Image B")
        with gr.Column():
            fps = gr.Slider(6, 30, value=14, step=1, label="FPS")
            exp = gr.Slider(1, 4, value=2, step=1, label="Interpolation exponent")
            run = gr.Button("Interpolate", variant="primary")
            gif_out = gr.Image(type="filepath", label="Result GIF")
    run.click(interpolate, inputs=[img_a, img_b, fps, exp], outputs=[gif_out])

# Gradio 3.x accepts ``concurrency_count``; Gradio 4+ rejects it (4.x raises
# DeprecationWarning, later versions may raise TypeError), so fall back to the
# replacement parameter. LOCK serialises inference either way.
try:
    demo.queue(concurrency_count=1, max_size=8)
except (TypeError, DeprecationWarning):
    demo.queue(default_concurrency_limit=1, max_size=8)

if __name__ == "__main__":
    demo.launch()