Munaf1987 committed on
Commit
c6e0655
·
verified ·
1 Parent(s): ba3051b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +88 -71
app.py CHANGED
@@ -1,85 +1,102 @@
1
  import gradio as gr
2
  import torch
3
- import numpy as np
4
- from diffusers import StableDiffusionXLInpaintPipeline
5
- from PIL import Image, ImageDraw
6
- from transformers import DetrImageProcessor, DetrForObjectDetection
 
 
 
 
7
  import spaces
8
 
9
# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the Stable Diffusion XL Inpainting model: half precision on the GPU,
# full precision on the CPU. The "fp16" weight variant is requested either way.
pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    variant="fp16",
    torch_dtype=torch.float32 if device != "cuda" else torch.float16,
)
pipe = pipe.to(device)

# Load the DETR object detection model together with its image processor.
processor = DetrImageProcessor.from_pretrained("facebook/detr-resnet-50")
detector = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
detector = detector.to(device)
21
-
22
@spaces.GPU
def detect_and_replace_humans(input_image, prompt):
    """Mask every detected person in ``input_image`` and inpaint the masked area.

    People are located with the module-level DETR ``detector``; their bounding
    boxes are painted white on a black mask, and the module-level SDXL
    inpainting ``pipe`` repaints the masked regions using fixed, predefined
    positive/negative prompts.

    Args:
        input_image: PIL image supplied by the UI, or ``None`` when nothing
            was uploaded.
        prompt: Text from the UI. It is accepted for interface compatibility
            but not used: the UI labels it optional and the predefined
            prompts below always drive the generation.

    Returns:
        The first image produced by the inpainting pipeline, or ``None`` when
        no input image was given.

    Raises:
        gr.Error: If no detection is labelled "person".
    """
    # Only a missing image aborts the run. The prompt is optional, so a blank
    # prompt must not silently produce an empty result.
    if input_image is None:
        return None

    inputs = processor(images=input_image, return_tensors="pt").to(device)

    # Detection is inference only; skip autograd bookkeeping.
    with torch.no_grad():
        outputs = detector(**inputs)

    # PIL reports (width, height); post-processing expects (height, width).
    target_sizes = torch.tensor([input_image.size[::-1]]).to(device)
    results = processor.post_process_object_detection(
        outputs, target_sizes=target_sizes, threshold=0.9
    )[0]

    # Black mask with a white rectangle over every detected person.
    mask = Image.new("L", input_image.size, 0)
    draw = ImageDraw.Draw(mask)

    found = False
    for label, box in zip(results["labels"], results["boxes"]):
        if detector.config.id2label[label.item()] != "person":
            continue
        draw.rectangle([int(coord) for coord in box.tolist()], fill=255)
        found = True

    if not found:
        # The output component is an image, so a plain string cannot be
        # returned; surface the message through Gradio's error channel.
        raise gr.Error("No human detected.")

    # Pre-defined positive and negative prompts
    positive_prompt = (
        "Replace the masked humans with imaginary Indian bride and groom wearing traditional Indian wedding attire, "
        "with detailed embroidery, colorful saree and sherwani, realistic faces, natural skin texture, matching pose, "
        "perfect lighting, and the same camera perspective. Keep the background unchanged."
    )

    negative_prompt = (
        "blurry, distorted, deformed, double face, extra limbs, low quality, bad proportions, low resolution, "
        "changed background, multiple faces, duplicate body parts, cartoon, watermark, text"
    )

    # Inpainting process
    output = pipe(
        prompt=positive_prompt,
        negative_prompt=negative_prompt,
        image=input_image,
        mask_image=mask,
        num_inference_steps=40,
        guidance_scale=8.5,
    ).images[0]

    return output
 
 
 
 
 
 
 
71
 
72
# Gradio UI
# NOTE(review): the nesting below is reconstructed from a flattened diff view
# in which indentation was lost - confirm against the real file.
with gr.Blocks() as demo:
    gr.Markdown("## Replace Humans with Imaginary Indian Bride and Groom (Background Preserved)")

    # Uploaded picture and generated result, side by side.
    with gr.Row():
        input_image = gr.Image(label="Input Image", type="pil")
        output_image = gr.Image(label="Output Image", type="pil")

    prompt_text = gr.Textbox(
        label="Prompt (Optional, Predefined Prompt Used)",
        placeholder="You can leave this blank",
    )
    submit = gr.Button("Submit")

    # Run detection + inpainting when the button is pressed.
    submit.click(
        fn=detect_and_replace_humans,
        inputs=[input_image, prompt_text],
        outputs=output_image,
    )

demo.launch()
 
1
  import gradio as gr
2
  import torch
3
+ import base64
4
+ import io
5
+ from PIL import Image
6
+ from diffusers import StableDiffusionPipeline
7
+ from safetensors.torch import load_file
8
+ from src.pipeline import FluxPipeline
9
+ from src.transformer_flux import FluxTransformer2DModel
10
+ from src.lora_helper import set_single_lora, clear_cache
11
  import spaces
12
 
13
# Load Base Model and LoRA
base_model = "black-forest-labs/FLUX.1-dev"
lora_path = "checkpoints/models/Ghibli.safetensors"

# One dtype for every component. bfloat16 is used instead of float16 because
# the Diffusers Flux notes say FP16 inference clips some activations and
# changes the output compared with BF16/FP32.
# NOTE(review): requires a GPU with bfloat16 support - confirm the target hardware.
model_dtype = torch.bfloat16

# Load the main pipeline, then swap in the project's own transformer class
# (the one the LoRA helper below is applied to) before moving to the GPU.
pipe = FluxPipeline.from_pretrained(base_model, torch_dtype=model_dtype)
transformer = FluxTransformer2DModel.from_pretrained(
    base_model,
    subfolder="transformer",
    torch_dtype=model_dtype,
)
pipe.transformer = transformer
pipe.to("cuda")

# Load LoRA
set_single_lora(pipe.transformer, lora_path, lora_weights=[1], cond_size=512)
25
+
26
+ # Base64 to Image
27
def base64_to_image(base64_str):
    """Decode a base64-encoded image into an RGB PIL image.

    Args:
        base64_str: The encoded image. Either a bare base64 payload or a data
            URL of the form ``data:<mime>;base64,<payload>``; surrounding
            whitespace is ignored. Bytes-like payloads are passed to the
            decoder unchanged.

    Returns:
        The decoded image converted to "RGB" mode.

    Raises:
        ValueError: If the input is empty (or a data URL without a payload),
            or is not valid base64 (``binascii.Error`` is a ``ValueError``).
    """
    payload = base64_str
    if isinstance(payload, str):
        payload = payload.strip()
        # Browser clients usually send a data URL; keep only what follows
        # the first comma, which is the actual base64 payload.
        if payload.startswith("data:"):
            payload = payload.partition(",")[2]

    if not payload:
        raise ValueError("Expected a non-empty base64-encoded image.")

    image_data = base64.b64decode(payload)
    return Image.open(io.BytesIO(image_data)).convert("RGB")
30
+
31
+ # Image to Base64
32
def image_to_base64(image):
    """Serialize ``image`` as PNG and return those bytes as base64 text.

    Args:
        image: Object exposing ``save(fileobj, format=...)``, e.g. a PIL image.

    Returns:
        The PNG bytes, base64-encoded and decoded to ``str``.
    """
    with io.BytesIO() as png_buffer:
        image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded = base64.b64encode(png_bytes)
    return encoded.decode()
36
+
37
+ # Cartoonizer function
38
@spaces.GPU
def cartoonize_base64(b64_image, prompt="Ghibli Studio style, hand-drawn anime illustration", height=768, width=768, seed=42):
    """Stylize a base64-encoded image with the module-level pipeline.

    The function is wrapped in ``spaces.GPU`` because it creates a CUDA
    generator and runs the pipeline on the GPU; on ZeroGPU Spaces a GPU is
    only attached while a decorated function is executing.

    Args:
        b64_image: Base64-encoded input image, used as the spatial condition.
        prompt: Text prompt passed to the pipeline.
        height: Output height in pixels; coerced with ``int``.
        width: Output width in pixels; coerced with ``int``.
        seed: Seed for the CUDA random generator; coerced with ``int``.

    Returns:
        The first generated image, PNG-encoded and returned as base64 text.
    """
    input_image = base64_to_image(b64_image)

    # UI number widgets may deliver floats, hence the int() coercions.
    generator = torch.Generator(device="cuda").manual_seed(int(seed))

    try:
        result = pipe(
            prompt=prompt,
            height=int(height),
            width=int(width),
            guidance_scale=3.5,
            num_inference_steps=25,
            generator=generator,
            spatial_images=[input_image],
            cond_size=512,
        ).images[0]
    finally:
        # Run the project's cache cleanup even when generation fails, so a
        # failed request cannot leave state behind for the next one.
        # NOTE(review): presumably clears per-request conditioning caches on
        # the transformer - confirm in src.lora_helper.
        clear_cache(pipe.transformer)

    return image_to_base64(result)
 
 
 
 
 
 
 
 
57
 
58
+ # Gradio UI function
59
def ui_cartoonize(image, prompt, height, width, seed):
    """Gradio callback: stylize an uploaded PIL image and return a PIL image.

    The image is routed through ``cartoonize_base64`` so the UI and the
    base64 API share one generation path.

    Args:
        image: PIL image from the upload widget, or ``None`` if nothing was
            uploaded.
        prompt: Text prompt forwarded to ``cartoonize_base64``.
        height: Requested output height, forwarded unchanged.
        width: Requested output width, forwarded unchanged.
        seed: Random seed, forwarded unchanged.

    Returns:
        The generated image decoded back into an RGB PIL image.

    Raises:
        gr.Error: If no image was uploaded.
    """
    if image is None:
        # Without this guard the click fails with an opaque AttributeError.
        raise gr.Error("Please upload an image before generating.")

    # Reuse the shared encoder instead of repeating the PNG/base64 steps here.
    b64_image = image_to_base64(image)
    cartoon_b64 = cartoonize_base64(b64_image, prompt, height, width, seed)
    return base64_to_image(cartoon_b64)
66
 
67
# Gradio App
# NOTE(review): the nesting below is reconstructed from a flattened diff view
# in which indentation was lost - confirm against the real file.
with gr.Blocks() as demo:
    gr.Markdown("# 🎨 Ghibli Style Cartoonizer using EasyControl")

    with gr.Row():
        # Left column: inputs and the trigger button.
        with gr.Column():
            input_image = gr.Image(type="pil", label="Upload Image")
            prompt = gr.Textbox(label="Prompt", value="Ghibli Studio style, hand-drawn anime illustration")
            height = gr.Slider(512, 1024, step=64, value=768, label="Height")
            width = gr.Slider(512, 1024, step=64, value=768, label="Width")
            seed = gr.Number(label="Seed", value=42)
            generate_btn = gr.Button("Generate Ghibli Image")
        # Right column: generated result.
        with gr.Column():
            output_image = gr.Image(label="Cartoonized Output")

    generate_btn.click(
        fn=ui_cartoonize,
        inputs=[input_image, prompt, height, width, seed],
        outputs=output_image,
    )

    # Gradio API: Accept base64, return base64
    # A standalone gr.Interface that is never launched or rendered is not
    # served by demo.launch(). The endpoint is therefore registered on this
    # app through hidden components, so only `demo` has to be launched.
    api_image = gr.Text(label="Base64 Image Input", visible=False)
    api_prompt = gr.Text(label="Prompt", visible=False)
    api_height = gr.Number(label="Height", value=768, visible=False)
    api_width = gr.Number(label="Width", value=768, visible=False)
    api_seed = gr.Number(label="Seed", value=42, visible=False)
    api_output = gr.Text(label="Base64 Cartoon Output", visible=False)
    api_trigger = gr.Button(visible=False)
    api_trigger.click(
        fn=cartoonize_base64,
        inputs=[api_image, api_prompt, api_height, api_width, api_seed],
        outputs=api_output,
        api_name="predict",
    )

demo.launch()