import os
import threading

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms

from process import fiximg
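# ImageNet channel statistics; assumed here to match the normalization the
# TorchScript models were trained with.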
IMAGE_NET_MEAN = [0.485, 0.456, 0.406]
IMAGE_NET_STD = [0.229, 0.224, 0.225]


def predict(input_img, crop_type):
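    """Crop `input_img` according to `crop_type` and return a PIL image."""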
    size = 352
    device = 'cpu'
    transform_to_img = transforms.ToPILImage()
    transform_orig = transforms.Compose([
        transforms.ToTensor()])
    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((size, size)),
        transforms.Normalize(
            mean=IMAGE_NET_MEAN,
            std=IMAGE_NET_STD)])
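    # Gradio passes None for a missing image or dropdown selection. The
    # interface declares a single image output, so return one placeholder
    # image (it asks the user to pick an image and a "Crop Type").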
    if input_img is None or crop_type is None:
        return Image.open('./examples/no-input.jpg')
    orig = transform_orig(input_img).to(device)
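    # Hand the full-resolution tensor to fiximg (defined in process.py) on a
    # background thread so it does not block the inference path below.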
    fiximg_thread = threading.Thread(target=fiximg, name="fiximg", args=(orig,))
    fiximg_thread.start()
    image = transform(input_img)[None, ...]
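    # Fetch the four TorchScript models from the bluelu/s repo on the Hub;
    # hf_hub_download caches files locally, so repeat calls skip the download.
    # The token is read from the S1 environment variable (a Space secret).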
    file_path1 = hf_hub_download("bluelu/s", "sc_1.ptl",
                                 use_auth_token=os.environ['S1'])
    file_path2 = hf_hub_download("bluelu/s", "sc_2.ptl",
                                 use_auth_token=os.environ['S1'])
    file_path3 = hf_hub_download("bluelu/s", "sc_3.ptl",
                                 use_auth_token=os.environ['S1'])
    file_path4 = hf_hub_download("bluelu/s", "sc_4.ptl",
                                 use_auth_token=os.environ['S1'])
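    # sc_1 and sc_2 appear to be two mask models; run both on the resized
    # input, then scale their maps back up to the original resolution.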
    model_1 = torch.jit.load(file_path1)
    model_2 = torch.jit.load(file_path2)
    mask1 = model_1(image)
    mask2 = model_2(image)
    # interpolate replaces the deprecated upsample_bilinear;
    # align_corners=True matches the old function's behavior.
    mask1 = torch.nn.functional.interpolate(
        mask1, size=(orig.shape[1], orig.shape[2]),
        mode='bilinear', align_corners=True)[0]
    mask2 = torch.nn.functional.interpolate(
        mask2, size=(orig.shape[1], orig.shape[2]),
        mode='bilinear', align_corners=True)[0]
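    # Stack the original RGB channels and both masks along the channel axis;
    # this combined tensor is what the cropping models take as input.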
    pp_input = torch.cat((orig, mask2, mask1), dim=0)
    result = orig
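    # sc_3 produces the square crop and sc_4 the centering crop; any other
    # value falls through and returns the original image unchanged.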
    if crop_type == "Square Crop":
        model_pp = torch.jit.load(file_path3)
        result = model_pp(pp_input)
    elif crop_type == "Centering Crop":
        model_pp = torch.jit.load(file_path4)
        result = model_pp(pp_input)
    return transform_to_img(result)


title = "Smart Crop"
description = """Need a photo with the item or person perfectly centered, for marketplaces (Facebook Marketplace, Etsy, eBay...)
or for your social media? <br>
No problem! <br>
✨ Just upload your photo and get the best crop of your image! ✨<br>
To download a crop, right-click it and choose "Save image as".<br>
**Crop Options:**
- Centering Crop (default): a crop that keeps the most important content in the center.
- Square Crop: the best square crop of the image.
"""
gr.Interface(fn=predict,
             inputs=[gr.components.Image(),
                     gr.components.Dropdown(["Centering Crop", "Square Crop"],
                                            label="Crop Options")],
             outputs=[gr.components.Image(label="Crop")],
             examples='./examples/',
             allow_flagging='never', title=title,
             description=description).launch()