|
import gradio as gr |
|
import torch |
|
import base64 |
|
import io |
|
from PIL import Image |
|
from diffusers import StableDiffusionPipeline |
|
from safetensors.torch import load_file |
|
from src.pipeline import FluxPipeline |
|
from src.transformer_flux import FluxTransformer2DModel |
|
from src.lora_helper import set_single_lora, clear_cache |
|
import spaces |
|
|
|
|
|
# Identifier of the base FLUX checkpoint and path of the LoRA weights file
# that is applied on top of it.
base_model = "black-forest-labs/FLUX.1-dev"
lora_path = "checkpoints/models/Ghibli.safetensors"

# Build the pipeline in float16, then replace its transformer with the
# project's FluxTransformer2DModel loaded from the same checkpoint's
# "transformer" subfolder, and move the whole pipeline to the GPU.
pipe = FluxPipeline.from_pretrained(
    base_model,
    torch_dtype=torch.float16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    base_model,
    subfolder="transformer",
    torch_dtype=torch.float16,
)
pipe.transformer = transformer
pipe.to("cuda")

# Attach the LoRA once at import time. cond_size=512 is the same value that
# cartoonize_base64 passes to the pipeline at inference time.
set_single_lora(
    pipe.transformer,
    lora_path,
    lora_weights=[1],
    cond_size=512,
)
|
|
|
|
|
def base64_to_image(base64_str):
    """Decode a base64-encoded image file into an RGB PIL image.

    Args:
        base64_str: Base64 text (or bytes-like data) of an encoded image file.

    Returns:
        The image opened with ``Image.open`` and converted to ``"RGB"`` mode.

    Raises:
        binascii.Error: Propagated from ``base64.b64decode`` when the input
            is not decodable base64 (for example, incorrect padding).
    """
    raw_bytes = base64.b64decode(base64_str)
    stream = io.BytesIO(raw_bytes)
    decoded = Image.open(stream)
    return decoded.convert("RGB")
|
|
|
|
|
def image_to_base64(image):
    """Serialize an image as PNG and return the bytes as base64 text.

    Args:
        image: Object exposing a PIL-style ``save(fp, format=...)`` method.

    Returns:
        The PNG data encoded as a base64 ``str``.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    return base64.b64encode(png_bytes).decode()
|
|
|
|
|
def cartoonize_base64(
    b64_image,
    prompt="Ghibli Studio style, hand-drawn anime illustration",
    height=768,
    width=768,
    seed=42,
):
    """Run the pipeline on a base64-encoded image and return base64 PNG text.

    Args:
        b64_image: Base64-encoded input image; decoded to RGB via
            ``base64_to_image``.
        prompt: Text prompt forwarded to the pipeline.
        height: Output height; coerced with ``int()`` before use.
        width: Output width; coerced with ``int()`` before use.
        seed: Seed for the CUDA random generator; coerced with ``int()``.

    Returns:
        The first generated image, PNG-encoded and returned as a base64 str.

    Raises:
        Any exception raised while decoding the input or running the
        pipeline is propagated; the transformer cache is cleared first.
    """
    input_image = base64_to_image(b64_image)

    # Numeric inputs may arrive as floats (e.g. from gr.Number), hence int().
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    try:
        result = pipe(
            prompt=prompt,
            height=int(height),
            width=int(width),
            guidance_scale=3.5,
            num_inference_steps=25,
            generator=generator,
            spatial_images=[input_image],
            cond_size=512,
        ).images[0]
    finally:
        # Run the cleanup even when generation fails, so a failed request
        # does not skip clear_cache before the next call.
        clear_cache(pipe.transformer)

    return image_to_base64(result)
|
|
|
|
|
def ui_cartoonize(image, prompt, height, width, seed):
    """Gradio click handler: cartoonize an uploaded PIL image.

    The image is routed through ``cartoonize_base64`` so the UI and the
    base64 entry point share the same generation path.

    Args:
        image: PIL image from the upload widget, or ``None`` when the user
            has not uploaded anything.
        prompt: Text prompt forwarded to ``cartoonize_base64``.
        height: Output height forwarded to ``cartoonize_base64``.
        width: Output width forwarded to ``cartoonize_base64``.
        seed: Seed forwarded to ``cartoonize_base64``.

    Returns:
        The generated image as an RGB PIL image.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Surface a readable message in the UI instead of an AttributeError
        # from calling .save() on None.
        raise gr.Error("Please upload an image first.")

    # Reuse the shared helper rather than duplicating the PNG/base64 encoding.
    b64_image = image_to_base64(image)
    cartoon_b64 = cartoonize_base64(b64_image, prompt, height, width, seed)
    return base64_to_image(cartoon_b64)
|
|
|
|
|
# Interactive page: image upload and generation settings on the left,
# generated result on the right.
with gr.Blocks() as demo:
    gr.Markdown("# 🎨 Ghibli Style Cartoonizer using EasyControl")

    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            prompt = gr.Textbox(
                label="Prompt",
                value="Ghibli Studio style, hand-drawn anime illustration",
            )
            height = gr.Slider(512, 1024, step=64, value=768, label="Height")
            width = gr.Slider(512, 1024, step=64, value=768, label="Width")
            seed = gr.Number(label="Seed", value=42)
            generate_btn = gr.Button("Generate Ghibli Image")
        with gr.Column():
            output_image = gr.Image(label="Cartoonized Output")

    generate_btn.click(
        fn=ui_cartoonize,
        inputs=[input_image, prompt, height, width, seed],
        outputs=output_image,
    )

    # Base64-in / base64-out entry point. It is created inside the Blocks
    # context so that its "predict" endpoint belongs to `demo` and is served
    # by demo.launch(); an Interface built outside the context is never
    # launched.
    gr.Interface(
        fn=cartoonize_base64,
        inputs=[
            gr.Text(label="Base64 Image Input"),
            gr.Text(label="Prompt"),
            gr.Number(label="Height", value=768),
            gr.Number(label="Width", value=768),
            gr.Number(label="Seed", value=42),
        ],
        outputs=gr.Text(label="Base64 Cartoon Output"),
        api_name="predict",
    )

demo.launch()
|
|