# replacebg / app.py
# Uploaded by Munaf1987 — commit 7fe7d39 ("Update app.py", verified), 2.37 kB.
import gradio as gr
import torch
from diffusers import StableDiffusionImg2ImgPipeline
from torchvision import transforms
from PIL import Image
import io
import base64
import spaces
# Load Ghibli model: the "nitrosocke/Ghibli-Diffusion" checkpoint wrapped in an
# img2img pipeline, with float16 weights loaded from safetensors files, then
# moved to the GPU. Loaded once at import time and shared by every request.
pipe_ghibli = StableDiffusionImg2ImgPipeline.from_pretrained(
    "nitrosocke/Ghibli-Diffusion",
    torch_dtype=torch.float16,
    use_safetensors=True,
)
pipe_ghibli = pipe_ghibli.to("cuda")
# Load AnimeGANv2: fetch the "generator" entry point from the
# bryandlee/animegan2-pytorch torch.hub repo with the face_paint_512_v2
# weights, move it to the GPU and put it in evaluation mode.
animegan = torch.hub.load(
    "bryandlee/animegan2-pytorch:main",
    "generator",
    pretrained="face_paint_512_v2",
)
animegan = animegan.to("cuda")
animegan.eval()
# Base64 conversion helpers
def pil_to_b64(img: Image.Image) -> str:
    """Serialise *img* as PNG and return the bytes as a base64 text string."""
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        encoded_bytes = base64.b64encode(png_buffer.getvalue())
    return encoded_bytes.decode()
def b64_to_pil(b64: str) -> Image.Image:
    """Decode a base64-encoded image into an RGB PIL image.

    Accepts either a bare base64 payload or a data URL such as
    ``data:image/png;base64,iVBOR...``. The ``data:...,`` header is removed
    before decoding, and surrounding whitespace is ignored.

    Args:
        b64: Base64 text of an image file, optionally as a data URL.

    Returns:
        The decoded image, converted to mode ``"RGB"``.

    Raises:
        ValueError: if *b64* is empty or whitespace-only.
        binascii.Error: if the payload is not valid base64 (e.g. bad padding).
        PIL.UnidentifiedImageError: if the decoded bytes are not an image
            format PIL can open.
    """
    payload = b64.strip() if b64 else ""
    if not payload:
        raise ValueError("Empty base64 image string.")
    # Browser/JS clients often send data URLs. b64decode (non-validating mode)
    # would silently drop ':', ';' and ',' but keep the header letters, which
    # corrupts the decoded bytes, so strip the header explicitly.
    if payload[:5].lower() == "data:":
        _, _, payload = payload.partition(",")
    raw = base64.b64decode(payload)
    return Image.open(io.BytesIO(raw)).convert("RGB")
# AnimeGAN processor
def apply_animegan(img: Image.Image, size: int = 512) -> Image.Image:
    """Stylise *img* with the module-level AnimeGANv2 generator.

    The image is converted to RGB and resized to ``size`` x ``size``; the
    aspect ratio is not preserved. It is then scaled into the [-1, 1] range
    the generator works in. The generator output is mapped back from
    [-1, 1] to [0, 1] before conversion to a PIL image.

    Args:
        img: Input image. Any PIL mode is accepted and converted to RGB.
        size: Square side length fed to the generator. Defaults to 512,
            which matches the ``face_paint_512_v2`` weights loaded above.

    Returns:
        The stylised ``size`` x ``size`` image.
    """
    transform = transforms.Compose([
        transforms.Resize((size, size)),
        transforms.ToTensor(),  # float tensor in [0, 1], shape (3, H, W)
    ])
    # RGBA / L / P inputs would give the wrong channel count, so force RGB.
    # The generator expects inputs in [-1, 1] (see face2paint in
    # animegan2-pytorch), not the [0, 1] produced by ToTensor.
    img_tensor = (transform(img.convert("RGB")) * 2 - 1).unsqueeze(0).to("cuda")
    with torch.no_grad():
        out = animegan(img_tensor)[0]
    # Map the generator's [-1, 1] output back to [0, 1]. Clamping the raw
    # output would throw away everything below zero and darken the image.
    out = (out * 0.5 + 0.5).clamp(0, 1).cpu()
    return transforms.ToPILImage()(out)
# Unified image processor
def process_image(
    img: Image.Image,
    effect: str,
    *,
    prompt: str = "ghibli style",
    strength: float = 0.5,
    guidance_scale: float = 7.5,
) -> Image.Image:
    """Apply the selected style *effect* to *img*.

    Args:
        img: Input image. Converted to RGB before processing; ``None``
            (e.g. nothing uploaded in the UI) is rejected.
        effect: ``"ghibli"`` runs the Ghibli img2img pipeline. Any other
            value, including ``None`` when no option is selected, runs
            AnimeGANv2.
        prompt: Text prompt for the Ghibli pipeline.
        strength: img2img strength for the Ghibli pipeline.
        guidance_scale: Classifier-free guidance scale for the Ghibli
            pipeline.

    Returns:
        The stylised image.

    Raises:
        gr.Error: if *img* is ``None``.
    """
    if img is None:
        # gr.Error surfaces this message to the user in the Gradio UI.
        raise gr.Error("Please provide an input image.")
    # Normalise uploads with alpha/palette/greyscale to the 3 channels both
    # models expect.
    img = img.convert("RGB")
    if effect == "ghibli":
        result = pipe_ghibli(
            prompt=prompt,
            image=img,
            strength=strength,
            guidance_scale=guidance_scale,
        )
        return result.images[0]
    # Fallback kept from the original: anything that is not "ghibli" uses AnimeGAN.
    return apply_animegan(img)
@spaces.GPU
def process_base64(b64: str, effect: str) -> str:
    """Base64-in / base64-out wrapper around ``process_image``.

    Decodes *b64* into an RGB image, applies *effect*, and returns the result
    re-encoded as a base64 PNG string. On ZeroGPU hardware, ``spaces.GPU``
    requests a GPU for the duration of this call.
    """
    return pil_to_b64(process_image(b64_to_pil(b64), effect))
# Gradio UI
# On ZeroGPU Spaces a GPU is only attached inside functions wrapped with
# spaces.GPU. process_base64 is already decorated. process_image is not, so
# the Web UI handler is wrapped here. process_image itself stays plain so
# process_base64 can keep calling it directly.
with gr.Blocks() as demo:
    gr.Markdown("# 🎨 Ghibli & AnimeGAN Effects (ZeroGPU Compatible)")

    with gr.Tab("Web UI"):
        inp = gr.Image(type="pil", label="Upload Image")
        eff = gr.Radio(["ghibli", "anime"], label="Select Effect")
        btn = gr.Button("Apply Effect")
        out_img = gr.Image(label="Output Image")
        # api_name is pinned to the previous default (the handler's function
        # name) so existing API clients keep the same endpoint after wrapping.
        btn.click(
            spaces.GPU(process_image),
            [inp, eff],
            out_img,
            api_name="process_image",
        )

    with gr.Tab("Base64 API"):
        b64_in = gr.Textbox(label="Input Image (Base64)", lines=5)
        eff2 = gr.Radio(["ghibli", "anime"], label="Select Effect")
        btn2 = gr.Button("Run API")
        b64_out = gr.Textbox(label="Output Image (Base64)", lines=5)
        btn2.click(process_base64, [b64_in, eff2], b64_out)

demo.launch()