File size: 3,197 Bytes
8918c24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import functools
import os
import threading
from datetime import datetime

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

from process import fiximg


# Standard ImageNet per-channel normalization statistics (R, G, B order),
# applied to the model input in `predict` via transforms.Normalize.
IMAGE_NET_MEAN = [
    0.485,  # R
    0.456,  # G
    0.406,  # B
]
IMAGE_NET_STD = [
    0.229,  # R
    0.224,  # G
    0.225,  # B
]


# Side length (pixels) of the square tensor fed to the two mask models.
_MODEL_INPUT_SIZE = 352
# Placeholder image returned when no input image was supplied.
_NO_INPUT_IMAGE_PATH = './examples/no-input.jpg'
# Crop type used when the dropdown has no selection; the UI description
# advertises "Centering Crop" as the default.
_DEFAULT_CROP_TYPE = "Centering Crop"
# Hub file of the post-processing model used for each supported crop type.
_CROP_MODEL_FILES = {
    "Square Crop": "sc_3.ptl",
    "Centering Crop": "sc_4.ptl",
}


@functools.lru_cache(maxsize=None)
def _load_model(filename):
    """Download *filename* from the ``bluelu/s`` Hub repo and load it with TorchScript.

    Cached so each model is fetched and deserialized once per process instead
    of on every request. The key space is the fixed set of model filenames.
    Raises KeyError if the ``S1`` environment variable (auth token) is unset;
    failures are not cached, so a later call retries.
    """
    path = hf_hub_download("bluelu/s", filename,
                           use_auth_token=os.environ['S1'])
    return torch.jit.load(path)


def _resize_mask(mask, height, width):
    """Bilinearly resize a batched mask to (height, width) and drop the batch dim."""
    # Equivalent to the deprecated F.upsample_bilinear, which is documented as
    # interpolate(..., mode='bilinear', align_corners=True).
    return torch.nn.functional.interpolate(
        mask, size=(height, width), mode="bilinear", align_corners=True)[0]


def predict(input, crop_type):
    """Return a smart crop of *input* for the selected crop type.

    Args:
        input: image from the Gradio ``Image`` component (presumably an
            HxWxC uint8 numpy array, Gradio's default — confirm), or None
            when nothing was uploaded.
        crop_type: "Square Crop" or "Centering Crop". None falls back to
            "Centering Crop"; any other value returns the uncropped image.

    Returns:
        A PIL image: the crop, or a placeholder image when *input* is None.
    """
    if input is None:
        # The interface has exactly one Image output, so return one image.
        # Copy so the file handle can be closed before returning.
        with Image.open(_NO_INPUT_IMAGE_PATH) as placeholder:
            return placeholder.copy()
    if crop_type is None:
        crop_type = _DEFAULT_CROP_TYPE

    orig = transforms.ToTensor()(input)

    # Hand the full-resolution tensor to process.fiximg on a background thread.
    # Its purpose is not visible from this file — NOTE(review): confirm intent.
    threading.Thread(target=fiximg, name="Downloader", args=(orig,)).start()

    model_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((_MODEL_INPUT_SIZE, _MODEL_INPUT_SIZE)),
        transforms.Normalize(mean=IMAGE_NET_MEAN, std=IMAGE_NET_STD),
    ])
    image = model_transform(input)[None, ...]  # add a batch dimension

    height, width = orig.shape[1], orig.shape[2]
    with torch.no_grad():  # inference only; skip autograd bookkeeping
        mask2 = _resize_mask(_load_model("sc_2.ptl")(image), height, width)
        mask1 = _resize_mask(_load_model("sc_1.ptl")(image), height, width)
        # Original image stacked with both masks along the channel axis is
        # the input of the post-processing (cropping) model.
        stacked = torch.cat((orig, mask2, mask1), dim=0)

        result = orig
        crop_model_file = _CROP_MODEL_FILES.get(crop_type)
        if crop_model_file is not None:
            result = _load_model(crop_model_file)(stacked)

    return transforms.ToPILImage()(result)


title = "Smart Crop"
description = """Need a photo where the item or person will be perfectly in the center for marketplaces (fb marketplace, Etsy, eBay...) 
 or for your social media? <br>
 No problem!  <br>
 ✨ Just upload your photo and get the best crop of your image!✨<br>
 To download the crop, right-click it and choose "Save image as".<br>

 **Crop Options:**
  - Centering Crop (Default): Crop of the image where the most important content is located in the center.
  - Square Crop: The best square crop of the image.
"""

# One image plus a crop-type dropdown in, one cropped image out.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.components.Image(),
        # Use the same components API as the Image input (gr.inputs is
        # deprecated) and preselect the option the description calls Default.
        gr.components.Dropdown(["Centering Crop", "Square Crop"],
                               value="Centering Crop",
                               label="Crop Options"),
    ],
    outputs=[gr.components.Image(label="Crop")],
    examples='./examples/',
    allow_flagging='never',
    title=title,
    description=description,
)
demo.launch()