Munaf1987 committed on
Commit
55ad485
·
verified ·
1 Parent(s): 5b2fddf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +55 -96
app.py CHANGED
@@ -1,102 +1,61 @@
 
1
  import gradio as gr
2
- import torch
3
- import base64
4
- import io
5
  from PIL import Image
6
- from diffusers import StableDiffusionPipeline
7
- from safetensors.torch import load_file
8
- from src.pipeline import FluxPipeline
9
- from src.transformer_flux import FluxTransformer2DModel
10
- from src.lora_helper import set_single_lora, clear_cache
11
- import spaces
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
 
13
- # Load Base Model and LoRA
14
- base_model = "black-forest-labs/FLUX.1-dev"
15
- lora_path = "checkpoints/models/Ghibli.safetensors"
16
-
17
- # Load the main pipeline
18
- pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=torch.float16)
19
- transformer = FluxTransformer2DModel.from_pretrained(base_model, subfolder="transformer", torch_dtype=torch.float16)
20
- pipe.transformer = transformer
21
- pipe.to("cuda")
22
-
23
- # Load LoRA
24
- set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
25
-
26
- # Base64 to Image
27
- def base64_to_image(base64_str):
28
- image_data = base64.b64decode(base64_str)
29
- return Image.open(io.BytesIO(image_data)).convert("RGB")
30
-
31
- # Image to Base64
32
- def image_to_base64(image):
33
- buffered = io.BytesIO()
34
- image.save(buffered, format="PNG")
35
- return base64.b64encode(buffered.getvalue()).decode()
36
-
37
- # Cartoonizer function
38
- def cartoonize_base64(b64_image, prompt="Ghibli Studio style, hand-drawn anime illustration", height=768, width=768, seed=42):
39
- input_image = base64_to_image(b64_image)
40
-
41
- generator = torch.Generator(device="cuda").manual_seed(int(seed))
42
-
43
- result = pipe(
44
- prompt=prompt,
45
- height=int(height),
46
- width=int(width),
47
- guidance_scale=3.5,
48
- num_inference_steps=25,
49
- generator=generator,
50
- spatial_images=[input_image],
51
- cond_size=512
52
- ).images[0]
53
-
54
- clear_cache(pipe.transformer)
55
-
56
- return image_to_base64(result)
57
-
58
- # Gradio UI function
59
- def ui_cartoonize(image, prompt, height, width, seed):
60
- buffered = io.BytesIO()
61
- image.save(buffered, format="PNG")
62
- b64_image = base64.b64encode(buffered.getvalue()).decode()
63
- cartoon_b64 = cartoonize_base64(b64_image, prompt, height, width, seed)
64
- cartoon_image = base64_to_image(cartoon_b64)
65
- return cartoon_image
66
-
67
- # Gradio App
68
  with gr.Blocks() as demo:
69
- gr.Markdown("# 🎨 Ghibli Style Cartoonizer using EasyControl")
70
-
71
- with gr.Row():
72
- with gr.Column():
73
- input_image = gr.Image(type="pil", label="Upload Image")
74
- prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, hand-drawn anime illustration")
75
- height = gr.Slider(512, 1024, step=64, value=768, label="Height")
76
- width = gr.Slider(512, 1024, step=64, value=768, label="Width")
77
- seed = gr.Number(label="Seed", value=42)
78
- generate_btn = gr.Button("Generate Ghibli Image")
79
- with gr.Column():
80
- output_image = gr.Image(label="Cartoonized Output")
81
-
82
- generate_btn.click(
83
- fn=ui_cartoonize,
84
- inputs=[input_image, prompt, height, width, seed],
85
- outputs=output_image
86
- )
87
-
88
- # Gradio API: Accept base64, return base64
89
- gr.Interface(
90
- fn=cartoonize_base64,
91
- inputs=[
92
- gr.Text(label="Base64 Image Input"),
93
- gr.Text(label="Prompt"),
94
- gr.Number(label="Height", value=768),
95
- gr.Number(label="Width", value=768),
96
- gr.Number(label="Seed", value=42)
97
- ],
98
- outputs=gr.Text(label="Base64 Cartoon Output"),
99
- api_name="predict"
100
- )
101
 
102
  demo.launch()
 
1
# app.py
import base64
import io

from diffusers import StableDiffusionImg2ImgPipeline
import gradio as gr
from PIL import Image
import spaces  # needed by the @spaces.GPU decorator (ZeroGPU)
import torch

from vtoonify_model import load_vtoonify  # local cartoonization model loader
7
+
8
# Load models (module-level side effects: both checkpoints are downloaded at
# import time and both pipelines are moved to the GPU — requires CUDA).
pipe_ghibli = StableDiffusionImg2ImgPipeline.from_pretrained(
    "nitrosocke/Ghibli-Diffusion", torch_dtype=torch.float16
).to("cuda")  # Ghibli-style fine-tuned Stable Diffusion img2img checkpoint

# Cartoonization model; load_vtoonify is defined in the local vtoonify_model
# module. NOTE(review): assumes the returned object supports .to("cuda") and
# is callable on a PIL image — confirm against vtoonify_model's implementation.
pipe_vtoonify = load_vtoonify().to("cuda")
14
+
15
# Base64 <-> PIL conversion helpers.
def pil_to_b64(img: Image.Image) -> str:
    """Serialize a PIL image to a base64-encoded PNG string."""
    png_buffer = io.BytesIO()
    img.save(png_buffer, format="PNG")
    encoded = base64.b64encode(png_buffer.getvalue())
    return encoded.decode()
20
+
21
def b64_to_pil(b64: str) -> Image.Image:
    """Decode a base64 string back into an RGB PIL image."""
    raw_bytes = base64.b64decode(b64)
    decoded = Image.open(io.BytesIO(raw_bytes))
    return decoded.convert("RGB")
24
+
25
# Core processor shared by the web UI and the base64 API.
def run_effect(input_b64: str, effect: str) -> dict:
    """Apply the requested effect to a base64-encoded image.

    Args:
        input_b64: Source image as a base64 string.
        effect: "ghibli" routes to Stable Diffusion img2img; any other
            value falls through to the VToonify pipeline.

    Returns:
        dict with key "output_b64" holding the styled image as base64 PNG.
    """
    source = b64_to_pil(input_b64)
    if effect == "ghibli":
        styled = pipe_ghibli(
            prompt="ghibli style",
            image=source,
            strength=0.5,
            guidance_scale=7.5,
        ).images[0]
    else:
        styled = pipe_vtoonify(source)
    return {"output_b64": pil_to_b64(styled)}
33
+
34
# API entry point. NOTE: the former `@gr.utils.decorators.thread_safe()`
# decorator was removed — no such API exists in Gradio, so importing this
# module raised AttributeError before the app could even start.
@spaces.GPU  # enables GPU allocation on ZeroGPU infrastructure
def api_process(input_b64, effect):
    """Base64-in/base64-out wrapper around run_effect for API callers.

    Args:
        input_b64: Input image encoded as base64 (any PIL-readable format).
        effect: "ghibli" for SD img2img; anything else uses VToonify.

    Returns:
        dict with key "output_b64" holding the styled image as base64 PNG.
    """
    return run_effect(input_b64, effect)
38
+
39
def gradio_process(img: Image.Image, effect: str) -> Image.Image:
    """UI-facing wrapper: PIL image in, PIL image out.

    Round-trips through base64 so the web UI and the API exercise the
    exact same run_effect code path.
    """
    encoded = pil_to_b64(img)
    result = run_effect(encoded, effect)
    return b64_to_pil(result["output_b64"])
43
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
# Gradio front-end: one tab for interactive use, one for base64 API calls.
with gr.Blocks() as demo:
    gr.Markdown("# Ghibli & VToonify Effects 🎨")

    with gr.Tab("Web UI"):
        ui_image = gr.Image(type="pil", label="Upload Image")
        ui_effect = gr.Radio(["ghibli", "vtoonify"], label="Effect")
        ui_button = gr.Button("Apply Effect")
        ui_output = gr.Image(label="Result")
        ui_button.click(gradio_process, [ui_image, ui_effect], ui_output)

    with gr.Tab("API (base64)"):
        api_input = gr.Textbox(lines=4, label="Input Image (base64)")
        api_effect = gr.Radio(["ghibli", "vtoonify"], label="Effect")
        api_button = gr.Button("Run API")
        api_output = gr.Textbox(lines=4, label="Output Image (base64)")
        api_button.click(api_process, [api_input, api_effect], api_output)

demo.launch()