replacebg / app.py
Munaf1987's picture
Update app.py
c6e0655 verified
raw
history blame
3.39 kB
import gradio as gr
import torch
import base64
import io
from PIL import Image
from diffusers import StableDiffusionPipeline
from safetensors.torch import load_file
from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, clear_cache
import spaces
# Load Base Model and LoRA
# Hub repo id of the base FLUX checkpoint; both the pipeline and the transformer below load from it.
base_model = "black-forest-labs/FLUX.1-dev"
# LoRA weights file (safetensors) attached below; presumably relative to the working directory.
lora_path = "checkpoints/models/Ghibli.safetensors"
# Load the main pipeline
# Build the project's FluxPipeline (src.pipeline) from base_model with float16 weights.
# NOTE(review): FLUX models are commonly run in bfloat16; confirm float16 does not
# produce NaN / all-black outputs on the target GPU.
pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
# Replace the pipeline's transformer with the project's own FluxTransformer2DModel
# (src.transformer_flux), loaded from the checkpoint's "transformer" subfolder.
# Presumably this custom class is what set_single_lora below knows how to patch.
# NOTE(review): from_pretrained above presumably already loaded a stock transformer,
# so the transformer weights are likely read twice at startup; consider passing the
# custom transformer to from_pretrained directly if the pipeline accepts it.
transformer = FluxTransformer2DModel.from_pretrained(base_model, subfolder="transformer", torch_dtype=torch.float16)
pipe.transformer = transformer
# Move the pipeline to the GPU at import time: a CUDA device must be usable when this
# module is imported. (The file imports `spaces`, so this may be a ZeroGPU Space — confirm.)
pipe.to("cuda")
# Load LoRA
# Attach the single LoRA to the custom transformer with weight 1.
# cond_size=512 is the same value passed to pipe(...) at inference time in this file;
# presumably the conditioning-image size, so keep the two in sync.
set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
# Base64 to Image
def decode_base64_payload(base64_str):
    """Decode base64 text to bytes, tolerating a ``data:<mime>;base64,`` prefix.

    Browsers and many HTTP clients send images as data URLs. ``b64decode``
    silently discards characters outside the base64 alphabet. However, ``/``
    in ``image/png`` and the letters of the prefix *are* in the alphabet, so
    an unstripped prefix would corrupt the decoded bytes instead of failing.

    Args:
        base64_str: Plain base64 text, a data URL, or raw base64 ``bytes``.

    Returns:
        The decoded bytes.

    Raises:
        binascii.Error: if the payload has invalid base64 padding.
    """
    if not isinstance(base64_str, str):
        # Bytes input: pass through unchanged, exactly as b64decode accepts it.
        return base64.b64decode(base64_str)
    payload = base64_str.strip()
    # The data-URL scheme name is case-insensitive; the base64 body follows the first comma.
    if payload[:5].lower() == "data:":
        payload = payload.partition(",")[2]
    return base64.b64decode(payload)


def base64_to_image(base64_str):
    """Decode base64 image data into an RGB ``PIL.Image``.

    Args:
        base64_str: Base64 text (or bytes) of an encoded image file. A
            ``data:<mime>;base64,`` URL prefix is accepted and stripped.

    Returns:
        The decoded image, converted to RGB.

    Raises:
        binascii.Error: if the payload has invalid base64 padding.
        PIL.UnidentifiedImageError: if the bytes are not a recognised image.
    """
    image_data = decode_base64_payload(base64_str)
    return Image.open(io.BytesIO(image_data)).convert("RGB")
# Image to Base64
def image_to_base64(image):
    """Serialize *image* as PNG and return those bytes as base64 text.

    Args:
        image: Any object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The base64 encoding of the PNG bytes, as a ``str``.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        encoded_bytes = base64.b64encode(png_buffer.getvalue())
    return encoded_bytes.decode()
# Cartoonizer function
DEFAULT_PROMPT = "Ghibli Studio style, hand-drawn anime illustration"


# On ZeroGPU hardware a GPU is only attached inside @spaces.GPU-decorated calls.
# The decorator is documented as effect-free on other hardware.
@spaces.GPU
def cartoonize_base64(b64_image, prompt=DEFAULT_PROMPT, height=768, width=768, seed=42):
    """Restyle a base64-encoded image with the LoRA-conditioned FLUX pipeline.

    Args:
        b64_image: Source image as base64 text.
        prompt: Text prompt. ``None`` or a blank string falls back to
            ``DEFAULT_PROMPT``, because API callers may leave the textbox empty.
        height: Output height in pixels (coerced with ``int``).
        width: Output width in pixels (coerced with ``int``).
        seed: Seed for the CUDA generator (coerced with ``int``).

    Returns:
        The generated image as base64-encoded PNG text.
    """
    if not prompt or not prompt.strip():
        prompt = DEFAULT_PROMPT
    input_image = base64_to_image(b64_image)
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    try:
        result = pipe(
            prompt=prompt,
            height=int(height),
            width=int(width),
            guidance_scale=3.5,
            num_inference_steps=25,
            generator=generator,
            spatial_images=[input_image],
            cond_size=512,
        ).images[0]
    finally:
        # clear_cache presumably resets per-request state cached on the transformer.
        # Run it even when generation raises, so nothing from a failed request
        # survives into the next one.
        clear_cache(pipe.transformer)
    return image_to_base64(result)
# Gradio UI function
def ui_cartoonize(image, prompt, height, width, seed):
    """Gradio click handler: cartoonize a PIL image uploaded through the UI.

    Routes through ``cartoonize_base64`` so the UI and the base64 API share
    one generation path.

    Args:
        image: The uploaded PIL image, or ``None`` if nothing was uploaded.
        prompt: Text prompt from the UI textbox.
        height: Output height from the slider.
        width: Output width from the slider.
        seed: Seed from the number box.

    Returns:
        The cartoonized image as an RGB PIL image.

    Raises:
        gr.Error: if no image was uploaded, so the user sees a readable message
            instead of an ``AttributeError`` on ``None``.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")
    cartoon_b64 = cartoonize_base64(image_to_base64(image), prompt, height, width, seed)
    return base64_to_image(cartoon_b64)
# Gradio App
with gr.Blocks() as demo:
    gr.Markdown("# 🎨 Ghibli Style Cartoonizer using EasyControl")
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, hand-drawn anime illustration")
            height = gr.Slider(512, 1024, step=64, value=768, label="Height")
            width = gr.Slider(512, 1024, step=64, value=768, label="Width")
            seed = gr.Number(label="Seed", value=42)
            generate_btn = gr.Button("Generate Ghibli Image")
        with gr.Column():
            output_image = gr.Image(label="Cartoonized Output")
    # Both GPU events share one concurrency group (default limit 1): they use the
    # same global `pipe`, whose cache is cleared after every generation.
    generate_btn.click(
        fn=ui_cartoonize,
        inputs=[input_image, prompt, height, width, seed],
        outputs=output_image,
        concurrency_id="gpu",
    )

    # Gradio API: Accept base64, return base64.
    # A gr.Interface built outside this `with` block is a separate app that is
    # never launched, so its endpoint would not be served. Registering hidden
    # components and an event on `demo` exposes it as /predict.
    with gr.Column(visible=False):
        api_b64_input = gr.Text(label="Base64 Image Input")
        api_prompt = gr.Text(label="Prompt")
        api_height = gr.Number(label="Height", value=768)
        api_width = gr.Number(label="Width", value=768)
        api_seed = gr.Number(label="Seed", value=42)
        api_b64_output = gr.Text(label="Base64 Cartoon Output")
        api_btn = gr.Button("predict")
    api_btn.click(
        fn=cartoonize_base64,
        inputs=[api_b64_input, api_prompt, api_height, api_width, api_seed],
        outputs=api_b64_output,
        api_name="predict",
        concurrency_id="gpu",
    )

demo.launch()