Spaces:
Runtime error
Runtime error
import functools
import os
import threading
from datetime import datetime

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision import transforms
from torchvision.utils import save_image

from process import fiximg
# Per-channel (R, G, B) ImageNet normalization statistics, consumed by
# `transforms.Normalize` when preparing the model input in `predict`.
IMAGE_NET_MEAN = [
    0.485,  # R
    0.456,  # G
    0.406,  # B
]
IMAGE_NET_STD = [
    0.229,  # R
    0.224,  # G
    0.225,  # B
]
# Hugging Face model repo holding the TorchScript models used by `predict`.
_HF_REPO_ID = "bluelu/s"

# Post-processing (cropping) model file for each supported dropdown choice.
_CROP_MODEL_FILES = {
    "Square Crop": "sc_3.ptl",
    "Centering Crop": "sc_4.ptl",
}


@functools.lru_cache(maxsize=None)
def _load_model(filename, device='cpu'):
    """Download ``filename`` from the model repo and load it as TorchScript.

    Cached per ``(filename, device)`` so each model is fetched and
    deserialized once per process instead of on every request. The auth
    token is read from the ``S1`` environment variable (``KeyError`` if unset;
    failures are not cached, so a later call retries).
    """
    path = hf_hub_download(_HF_REPO_ID, filename,
                           use_auth_token=os.environ['S1'])
    return torch.jit.load(path, map_location=device)


def _resize_mask(mask, height, width):
    """Bilinearly resize a batched mask to ``(height, width)`` and drop the batch dim.

    Equivalent to the deprecated ``upsample_bilinear`` (which is documented as
    ``interpolate(..., mode='bilinear', align_corners=True)``).
    """
    resized = torch.nn.functional.interpolate(
        mask, size=(height, width), mode='bilinear', align_corners=True)
    return resized[0]


def predict(input, crop_type):
    """Return a smart crop of ``input`` according to ``crop_type``.

    Args:
        input: Image from the Gradio ``Image`` component, in any form
            ``transforms.ToTensor`` accepts (presumably an HxWxC uint8 array;
            confirm against the component's ``type``), or ``None`` when no
            image was supplied.
        crop_type: ``"Centering Crop"`` or ``"Square Crop"`` from the
            dropdown, or ``None``. Any other value yields the uncropped image.

    Returns:
        A single PIL image, matching the interface's single ``Image`` output.
        If either argument is ``None``, the placeholder image
        ``./examples/no-input.jpg`` is returned instead.
    """
    size = 352
    device = 'cpu'

    if input is None or crop_type is None:
        # The interface has exactly one output component, so return only the
        # placeholder image (a 4-tuple here did not match the outputs).
        # Copy inside `with` so PIL's lazily-held file handle gets closed.
        with Image.open('./examples/no-input.jpg') as placeholder:
            return placeholder.copy()

    to_tensor = transforms.ToTensor()
    model_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((size, size)),
        transforms.Normalize(mean=IMAGE_NET_MEAN, std=IMAGE_NET_STD),
    ])

    # Full-resolution (C, H, W) tensor; its H and W are the mask target size.
    orig = to_tensor(input).to(device)

    # Background task from the local `process` module; its behavior is not
    # visible here. NOTE(review): the thread is never joined, and it shares
    # the `orig` tensor used below; confirm `fiximg` does not mutate it.
    threading.Thread(target=fiximg, name="Downloader", args=(orig,)).start()

    # Model input: resized, normalized, with a leading batch dimension.
    batch = model_transform(input)[None, ...].to(device)
    height, width = orig.shape[1], orig.shape[2]

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        mask2 = _load_model("sc_2.ptl", device)(batch)
        mask1 = _load_model("sc_1.ptl", device)(batch)
        mask1 = _resize_mask(mask1, height, width)
        mask2 = _resize_mask(mask2, height, width)
        # Concatenate along dim 0: image channels, then both masks.
        stacked = torch.cat((orig, mask2, mask1), dim=0)

        crop_file = _CROP_MODEL_FILES.get(crop_type)
        # Unrecognized crop types fall back to the uncropped image.
        if crop_file is None:
            result = orig
        else:
            result = _load_model(crop_file, device)(stacked)

    return transforms.ToPILImage()(result)
title = "Smart Crop"
description = """Need a photo where the item or person will be perfectly in the center for marketplaces (fb marketplace, Etsy, eBay...)
or for your social media? <br>
No problem! <br>
✨ Just upload your photo and get the best crop of your image!✨<br>
To download the crops press on the mouse right click -> save image as.<br>
**Crop Options:**
- Centering Crop (Default): Crop of the image where the most important content is located in the center.
- Square Crop : The best square crops of the image.
"""

# Build the UI. The dropdown uses `gr.components.Dropdown` (the `gr.inputs`
# namespace was deprecated in Gradio 3 and removed in Gradio 4), matching the
# `gr.components.Image` already used, and preselects "Centering Crop", which
# the description above advertises as the default option.
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.components.Image(),
        gr.components.Dropdown(
            ["Centering Crop", "Square Crop"],
            value="Centering Crop",
            label="Crop Options",
        ),
    ],
    outputs=[gr.components.Image(label="Crop")],
    examples='./examples/',
    allow_flagging='never',
    title=title,
    description=description,
)
demo.launch()