import gradio as gr
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from torchvision import transforms
from PIL import Image
import io
import base64
import spaces
from functools import lru_cache

# Base64 utilities
def pil_to_b64(img: Image.Image) -> str:
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return base64.b64encode(buf.getvalue()).decode()


def b64_to_pil(b64: str) -> Image.Image:
    return Image.open(io.BytesIO(base64.b64decode(b64))).convert("RGB")
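
# Illustrative round trip through the helpers above (not executed here): encoding a
# PIL image and decoding it back yields an RGB copy of the same size.
#
#   sample = Image.new("RGB", (64, 64), "red")
#   assert b64_to_pil(pil_to_b64(sample)).size == (64, 64)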


# ✅ Cached Model Loaders (ZeroGPU Safe)
@lru_cache(maxsize=2)
def load_ghibli_model():
    pipe = StableDiffusionImg2ImgPipeline.from_pretrained(
        "nitrosocke/Ghibli-Diffusion",
        torch_dtype=torch.float16,
        use_safetensors=True
    ).to("cuda")
    return pipe


@lru_cache(maxsize=2)
def load_animegan_model():
    model = torch.hub.load(
        "bryandlee/animegan2-pytorch:main",
        "generator",
        pretrained="face_paint_512_v2"
    ).to("cuda").eval()
    return model
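
# Because of lru_cache, repeated calls return the same cached objects, so each model's
# weights are loaded at most once per worker process (illustrative check, not executed
# here; assumes a GPU is available for the first load):
#
#   assert load_ghibli_model() is load_ghibli_model()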


# ✅ Image Processing (Gradio Image Upload)
@spaces.GPU
def process_image(img: Image.Image, effect: str) -> Image.Image:
    img = img.convert("RGB")  # guard against RGBA or grayscale uploads
    if effect == "ghibli":
        pipe = load_ghibli_model()
        out_img = pipe(prompt="ghibli style", image=img, strength=0.5, guidance_scale=7.5).images[0]
    else:
        animegan = load_animegan_model()
        transform = transforms.Compose([
            transforms.Resize((512, 512)),
            transforms.ToTensor()
        ])
        # The face_paint_512_v2 generator works on tensors in the [-1, 1] range
        # (as in the upstream face2paint helper), so scale the input up and the
        # output back down before converting to a PIL image.
        img_tensor = transform(img).unsqueeze(0).to("cuda") * 2 - 1
        with torch.no_grad():
            out = (animegan(img_tensor)[0] * 0.5 + 0.5).clamp(0, 1).cpu()
        out_img = transforms.ToPILImage()(out)
    return out_img
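
# Example invocation from Python (illustrative only; on ZeroGPU the @spaces.GPU
# wrapper requests a GPU for the duration of the call):
#
#   styled = process_image(Image.open("photo.jpg"), "ghibli")
#   styled.save("ghibli.png")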


# ✅ Base64 API Processing
@spaces.GPU
def process_base64(b64: str, effect: str) -> str:
    img = b64_to_pil(b64)
    out_img = process_image(img, effect)
    return pil_to_b64(out_img)
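
# Illustrative client-side call for the Base64 endpoint (not executed here). The Space
# name and api_name are assumptions, based on this repo's path and on Gradio's default
# of naming endpoints after the bound function:
#
#   from gradio_client import Client
#   client = Client("Munaf1987/replacebg")
#   out_b64 = client.predict(pil_to_b64(Image.open("photo.jpg")), "anime",
#                            api_name="/process_base64")
#   b64_to_pil(out_b64).save("anime.png")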


# ✅ Gradio UI
with gr.Blocks() as demo:
    gr.Markdown("# 🎨 Ghibli & AnimeGAN Effects (ZeroGPU Compatible)")

    # Image Upload Tab
    with gr.Tab("Web UI"):
        img_input = gr.Image(type="pil", label="Upload Image")
        effect_choice = gr.Radio(["ghibli", "anime"], label="Select Effect")
        process_btn = gr.Button("Apply Effect")
        img_output = gr.Image(label="Processed Image")
        process_btn.click(process_image, [img_input, effect_choice], img_output)

    # Base64 API Tab
    with gr.Tab("Base64 API"):
        b64_input = gr.Textbox(label="Input Image (Base64)", lines=5)
        effect_choice_b64 = gr.Radio(["ghibli", "anime"], label="Select Effect")
        process_btn_b64 = gr.Button("Run API")
        b64_output = gr.Textbox(label="Output Image (Base64)", lines=5)
        process_btn_b64.click(process_base64, [b64_input, effect_choice_b64], b64_output)


demo.launch()