# Spaces: Running
import base64
import io

import gradio as gr
import spaces
import torch
from diffusers import StableDiffusionPipeline
from PIL import Image
from safetensors.torch import load_file

from src.pipeline import FluxPipeline
from src.transformer_flux import FluxTransformer2DModel
from src.lora_helper import set_single_lora, clear_cache
# Identifiers for the base checkpoint and the LoRA weights file applied on top.
base_model = "black-forest-labs/FLUX.1-dev"
lora_path = "checkpoints/models/Ghibli.safetensors"

# Build the pipeline once at import time; `cartoonize_base64` reuses this
# module-level `pipe` for every request.
# NOTE(review): both loads use float16 — confirm this is the intended dtype
# for this checkpoint rather than bfloat16.
pipe = FluxPipeline.from_pretrained(
    base_model,
    torch_dtype=torch.float16,
)

# Load the project's transformer class from the checkpoint's "transformer"
# subfolder and install it on the pipeline before moving everything to CUDA.
transformer = FluxTransformer2DModel.from_pretrained(
    base_model,
    subfolder="transformer",
    torch_dtype=torch.float16,
)
pipe.transformer = transformer
pipe.to("cuda")

# Attach the single LoRA to the pipeline's transformer (weight 1,
# cond_size=512, matching the cond_size passed at inference time).
set_single_lora(
    pipe.transformer,
    lora_path,
    lora_weights=[1],
    cond_size=512,
)
# Base64 to Image
def base64_to_image(base64_str):
    """Decode a base64-encoded image into an RGB ``PIL.Image``.

    Accepts either a bare base64 payload or a data URI of the form
    ``data:<mediatype>;base64,<payload>``; in the latter case the header is
    removed before decoding.

    Args:
        base64_str: Base64 text (``str``, optionally a data URI) or a
            bytes-like object holding base64 data.

    Returns:
        The decoded image, converted to RGB mode.

    Raises:
        binascii.Error: If the payload is incorrectly padded.
        PIL.UnidentifiedImageError: If the decoded bytes are not an image
            Pillow can open.
    """
    if isinstance(base64_str, str):
        header, comma, payload = base64_str.partition(",")
        # "," is not part of the base64 alphabet, so a comma after a "data:"
        # prefix can only be the end of a data-URI header. The header must be
        # dropped explicitly: b64decode's default (non-validating) mode would
        # otherwise treat its letters as payload and corrupt the image bytes.
        if comma and header.lstrip().startswith("data:"):
            base64_str = payload
    image_data = base64.b64decode(base64_str)
    return Image.open(io.BytesIO(image_data)).convert("RGB")
# Image to Base64
def image_to_base64(image):
    """Serialize *image* as PNG and return those bytes as base64 text.

    Args:
        image: Object exposing a PIL-style ``save(fp, format=...)`` method.

    Returns:
        A ``str`` containing the base64 encoding of the PNG bytes.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
# Cartoonizer function
def cartoonize_base64(b64_image, prompt="Ghibli Studio style, hand-drawn anime illustration", height=768, width=768, seed=42):
    """Run the module-level pipeline on a base64 image and return base64 PNG.

    The input is decoded with ``base64_to_image`` and handed to ``pipe`` as
    the single entry of ``spatial_images``. Guidance scale (3.5), step count
    (25) and ``cond_size`` (512) are fixed.

    Args:
        b64_image: Base64-encoded input image.
        prompt: Text prompt forwarded to the pipeline.
        height: Output height; coerced with ``int()``.
        width: Output width; coerced with ``int()``.
        seed: Seed for the CUDA ``torch.Generator``; coerced with ``int()``.

    Returns:
        The first generated image, PNG-encoded and returned as base64 text.

    NOTE(review): ``spaces`` is imported at the top of the file but never
    applied here (e.g. as a decorator) — confirm whether the deployment
    target needs it.
    """
    input_image = base64_to_image(b64_image)
    generator = torch.Generator(device="cuda").manual_seed(int(seed))
    try:
        result = pipe(
            prompt=prompt,
            height=int(height),
            width=int(width),
            guidance_scale=3.5,
            num_inference_steps=25,
            generator=generator,
            spatial_images=[input_image],
            cond_size=512,
        ).images[0]
    finally:
        # Always reset the transformer's cache, even when generation fails,
        # so a failed request cannot leave state behind for the next one.
        # (Presumably this is the per-request conditioning cache — confirm
        # against src.lora_helper.)
        clear_cache(pipe.transformer)
    return image_to_base64(result)
# Gradio UI function
def ui_cartoonize(image, prompt, height, width, seed):
    """Gradio click handler: stylize an uploaded PIL image.

    Encodes the upload with ``image_to_base64``, runs ``cartoonize_base64``
    and decodes the result back into an image for display.

    Args:
        image: Uploaded image from the ``gr.Image(type="pil")`` input, or
            ``None`` when nothing has been uploaded.
        prompt: Text prompt forwarded to ``cartoonize_base64``.
        height: Output height forwarded to ``cartoonize_base64``.
        width: Output width forwarded to ``cartoonize_base64``.
        seed: Random seed forwarded to ``cartoonize_base64``.

    Returns:
        The generated image as an RGB PIL image.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError
        # from calling .save() on None.
        raise gr.Error("Please upload an image before generating.")
    # Reuse the shared encoder rather than repeating the PNG/base64 steps.
    cartoon_b64 = cartoonize_base64(image_to_base64(image), prompt, height, width, seed)
    return base64_to_image(cartoon_b64)
# Gradio App
# Two entry points share one app: an interactive image-in/image-out form and
# a text-only interface that accepts and returns base64 strings.
with gr.Blocks() as demo:
    # Title emoji restored from mis-decoded UTF-8 bytes.
    gr.Markdown("# 🎨 Ghibli Style Cartoonizer using EasyControl")

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, hand-drawn anime illustration")
            height = gr.Slider(512, 1024, step=64, value=768, label="Height")
            width = gr.Slider(512, 1024, step=64, value=768, label="Width")
            seed = gr.Number(label="Seed", value=42)
            generate_btn = gr.Button("Generate Ghibli Image")
        # Right column: generated result.
        with gr.Column():
            output_image = gr.Image(label="Cartoonized Output")

    # Event listeners must be registered inside the Blocks context.
    generate_btn.click(
        fn=ui_cartoonize,
        inputs=[input_image, prompt, height, width, seed],
        outputs=output_image,
    )

    # Gradio API: Accept base64, return base64
    # NOTE(review): kept inside the Blocks context so it is served by
    # `demo.launch()`; confirm this nesting is intended.
    gr.Interface(
        fn=cartoonize_base64,
        inputs=[
            gr.Text(label="Base64 Image Input"),
            gr.Text(label="Prompt"),
            gr.Number(label="Height", value=768),
            gr.Number(label="Width", value=768),
            gr.Number(label="Seed", value=42),
        ],
        outputs=gr.Text(label="Base64 Cartoon Output"),
        api_name="predict",
    )

demo.launch()