Munaf1987 committed on
Commit
0d40aa7
·
verified ·
1 Parent(s): 8ff49cd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -55
app.py CHANGED
@@ -3,75 +3,56 @@ import torch
3
  from diffusers import StableDiffusionImg2ImgPipeline
4
  from torchvision import transforms
5
  from PIL import Image
6
- import io
7
- import base64
8
- import spaces
9
 
10
- # Load Ghibli model
11
  pipe_ghibli = StableDiffusionImg2ImgPipeline.from_pretrained(
12
  "nitrosocke/Ghibli-Diffusion", torch_dtype=torch.float16
13
  ).to("cuda")
14
 
15
- # Load CartoonGAN model
16
- cartoon_model = torch.hub.load('AK391/CartoonGAN', 'cartoon_gan', pretrained=True).to("cuda").eval()
 
 
 
17
 
18
- # Base64 utilities
19
- def pil_to_base64(img: Image.Image) -> str:
20
- buffer = io.BytesIO()
21
- img.save(buffer, format="PNG")
22
- return base64.b64encode(buffer.getvalue()).decode()
23
-
24
- def base64_to_pil(b64: str) -> Image.Image:
25
- image_data = base64.b64decode(b64)
26
- return Image.open(io.BytesIO(image_data)).convert("RGB")
27
 
28
  # CartoonGAN processor
29
- def apply_cartoon_gan(img: Image.Image) -> Image.Image:
30
- transform = transforms.Compose([
31
- transforms.Resize((256, 256)),
32
- transforms.ToTensor()
33
- ])
34
- img_tensor = transform(img).unsqueeze(0).to("cuda")
35
- with torch.no_grad():
36
- output = cartoon_model(img_tensor)[0].clamp(0, 1).cpu()
37
- output_pil = transforms.ToPILImage()(output)
38
- return output_pil
39
-
40
- # Unified effect processor
41
- def process_image(input_image: Image.Image, effect: str) -> Image.Image:
42
- if effect == "ghibli":
43
- output_image = pipe_ghibli(prompt="ghibli style", image=input_image, strength=0.5, guidance_scale=7.5).images[0]
44
- else:
45
- output_image = apply_cartoon_gan(input_image)
46
- return output_image
47
 
48
  @spaces.GPU
49
- def process_base64(input_b64: str, effect: str) -> str:
50
- input_image = base64_to_pil(input_b64)
51
- output_image = process_image(input_image, effect)
52
- return pil_to_base64(output_image)
53
 
54
  # Gradio UI
55
  with gr.Blocks() as demo:
56
- gr.Markdown("# 🎨 Ghibli & CartoonGAN Effects")
57
-
58
  with gr.Tab("Web UI"):
59
- with gr.Row():
60
- image_input = gr.Image(type="pil", label="Upload Image")
61
- effect_selector = gr.Radio(["ghibli", "cartoon"], label="Select Effect")
62
- with gr.Row():
63
- apply_button = gr.Button("Apply Effect")
64
- with gr.Row():
65
- image_output = gr.Image(label="Processed Image")
66
-
67
- apply_button.click(process_image, [image_input, effect_selector], image_output)
68
-
69
  with gr.Tab("Base64 API"):
70
- base64_input = gr.Textbox(label="Input Image (Base64)", lines=5)
71
- effect_choice = gr.Radio(["ghibli", "cartoon"], label="Select Effect")
72
- api_button = gr.Button("Run API")
73
- base64_output = gr.Textbox(label="Output Image (Base64)", lines=5)
74
-
75
- api_button.click(process_base64, [base64_input, effect_choice], base64_output)
76
 
77
  demo.launch()
 
3
  from diffusers import StableDiffusionImg2ImgPipeline
4
  from torchvision import transforms
5
  from PIL import Image
6
+ import io, base64, spaces
 
 
7
 
8
+ # Ghibli model
9
  pipe_ghibli = StableDiffusionImg2ImgPipeline.from_pretrained(
10
  "nitrosocke/Ghibli-Diffusion", torch_dtype=torch.float16
11
  ).to("cuda")
12
 
13
+ # CartoonGAN model via torch.hub
14
+ cartoon_model = torch.hub.load(
15
+ 'znxlwm/pytorch-CartoonGAN', 'CartoonGAN',
16
+ pretrained=True, trust_repo=True
17
+ ).to("cuda").eval()
18
 
19
+ # Helpers: PIL ↔ Base64
20
+ def pil_to_b64(img): buf=io.BytesIO(); img.save(buf,"PNG"); return base64.b64encode(buf.getvalue()).decode()
21
+ def b64_to_pil(b): return Image.open(io.BytesIO(base64.b64decode(b))).convert("RGB")
 
 
 
 
 
 
22
 
23
  # CartoonGAN processor
24
+ def apply_cartoon(img):
25
+ t=transforms.Compose([transforms.Resize((256,256)), transforms.ToTensor()])
26
+ x=t(img).unsqueeze(0).to("cuda")
27
+ with torch.no_grad(): y=cartoon_model(x)[0].clamp(0,1).cpu()
28
+ return transforms.ToPILImage()(y)
29
+
30
+ # Unified image processor
31
+ def process_image(img, effect):
32
+ if effect=="ghibli":
33
+ return pipe_ghibli(prompt="ghibli style", image=img, strength=0.5, guidance_scale=7.5).images[0]
34
+ return apply_cartoon(img)
 
 
 
 
 
 
 
35
 
36
  @spaces.GPU
37
+ def process_base64(b64, effect):
38
+ img = b64_to_pil(b64)
39
+ out = process_image(img, effect)
40
+ return pil_to_b64(out)
41
 
42
  # Gradio UI
43
  with gr.Blocks() as demo:
44
+ gr.Markdown("# 🎨 Ghibli & CartoonGAN Effects (ZeroGPU)")
 
45
  with gr.Tab("Web UI"):
46
+ inp = gr.Image(type="pil")
47
+ eff = gr.Radio(["ghibli","cartoon"], label="Effect")
48
+ btn = gr.Button("Apply")
49
+ out_img = gr.Image()
50
+ btn.click(process_image, [inp, eff], out_img)
 
 
 
 
 
51
  with gr.Tab("Base64 API"):
52
+ in_b64 = gr.Textbox(lines=5)
53
+ eff2 = gr.Radio(["ghibli","cartoon"], label="Effect")
54
+ btn2 = gr.Button("Run API")
55
+ out_b64 = gr.Textbox(lines=5)
56
+ btn2.click(process_base64, [in_b64, eff2], out_b64)
 
57
 
58
  demo.launch()