Sartc commited on
Commit
b848dd9
·
verified ·
1 Parent(s): dbb8b38

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -0
app.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from PIL import Image, ImageDraw, ImageFont
3
+ import numpy as np
4
+ import torch
5
+ from data import transform_img
6
+ from inference import load_model, predict
7
+
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ weights_path = "unet_model.pth"
10
+ model = load_model(weights_path, device)
11
+
12
+ def process_image(image, text, font_size):
13
+ image = image.convert("RGB")
14
+ print(f"image: {image}")
15
+ background_with_text = image.copy()
16
+ draw = ImageDraw.Draw(background_with_text)
17
+ font = ImageFont.truetype("/usr/share/fonts/truetype/freefont/FreeSansBold.ttf", font_size)
18
+ text_position = (50, 50)
19
+ text_color = (0, 0, 0)
20
+ draw.text(text_position, text, fill=text_color, font=font)
21
+
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ weights_path = "unet_model.pth"
24
+ model = load_model(weights_path, device)
25
+ transform = transform_img()
26
+ image_tensor = transform(image).unsqueeze(0)
27
+ mask = predict(model, image_tensor, device)
28
+ mask = mask.squeeze(0)
29
+ mask_binary = (mask > 0.5).astype(np.uint8) * 255
30
+ mask_img = Image.fromarray(mask_binary, mode="L")
31
+ mask_img = mask_img.resize(image.size, resample=Image.NEAREST)
32
+
33
+ original_rgba = image.convert("RGBA")
34
+
35
+ r, g, b, _ = original_rgba.split()
36
+ subject_img = Image.merge("RGBA", (r, g, b, mask_img))
37
+
38
+ background_with_text.paste(subject_img, (0, 0), subject_img)
39
+ return background_with_text
40
+
41
+ interface = gr.Interface(
42
+ fn=process_image,
43
+ inputs=[
44
+ gr.Image(type="pil", label="Upload Image"),
45
+ gr.Textbox(label="Enter Text"),
46
+ gr.Slider(10, 70, value=5, step=5, label="Font Size")
47
+ ],
48
+ outputs=gr.Image(type="pil", label="Output Image"),
49
+ title="Text Behind Image Generator",
50
+ description="Upload an image, enter text, and choose font size to generate the output image."
51
+ )
52
+
53
+ interface.launch()