# Hosting-page residue captured together with the file (not program code):
# awacke1's picture
# Update app.py
# 87411a9
# raw
# history blame
# 5.04 kB
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
# Normalisation layer used by every conv stage in this file; ResidualBlock and
# Generator both look this alias up when they are constructed.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 conv stages wrapped in an identity skip.

    Both convolutions map ``in_features`` channels to ``in_features``
    channels, and pad-1 + kernel-3 keeps the spatial size, so the block
    output can be added directly to its input.
    """

    def __init__(self, in_features):
        super().__init__()

        def padded_conv():
            # pad -> conv -> norm; the conv itself uses no padding.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                norm_layer(in_features),
            ]

        layers = padded_conv()
        layers.append(nn.ReLU(inplace=True))
        layers.extend(padded_conv())
        # Positions inside the Sequential (convs at index 1 and 5) define the
        # parameter names, so the layer order must not change.
        self.conv_block = nn.Sequential(*layers)

    def forward(self, x):
        """Return ``x`` plus the output of the conv stack applied to ``x``."""
        residual = self.conv_block(x)
        return x + residual
class Generator(nn.Module):
    """Convolutional image-to-image network built from five sequential stages.

    The stages are stored as ``model0`` .. ``model4``; parameter names in a
    state dict are derived from these attribute names and from each layer's
    position inside its ``nn.Sequential``, so neither may change.

    Args:
        input_nc: Number of channels of the input tensor.
        output_nc: Number of channels of the output tensor.
        n_residual_blocks: How many ``ResidualBlock`` modules form ``model2``.
        sigmoid: When true, a ``nn.Sigmoid`` is appended to the last stage.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()
        base_width = 64

        # model0: reflection-padded 7x7 conv from input_nc to 64 channels.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, base_width, 7),
            norm_layer(base_width),
            nn.ReLU(inplace=True),
        )

        # model1: two stride-2 convs, each doubling the channel count
        # (64 -> 128 -> 256).
        width = base_width
        down_layers = []
        for _ in range(2):
            doubled = width * 2
            down_layers.extend([
                nn.Conv2d(width, doubled, 3, stride=2, padding=1),
                norm_layer(doubled),
                nn.ReLU(inplace=True),
            ])
            width = doubled
        self.model1 = nn.Sequential(*down_layers)

        # model2: residual blocks at the widest channel count.
        self.model2 = nn.Sequential(
            *[ResidualBlock(width) for _ in range(n_residual_blocks)]
        )

        # model3: two stride-2 transposed convs, each halving the channel
        # count (256 -> 128 -> 64).
        up_layers = []
        for _ in range(2):
            halved = width // 2
            up_layers.extend([
                nn.ConvTranspose2d(width, halved, 3, stride=2, padding=1,
                                   output_padding=1),
                norm_layer(halved),
                nn.ReLU(inplace=True),
            ])
            width = halved
        self.model3 = nn.Sequential(*up_layers)

        # model4: reflection-padded 7x7 conv down to output_nc channels, with
        # an optional sigmoid on top.
        head = [nn.ReflectionPad2d(3), nn.Conv2d(base_width, output_nc, 7)]
        if sigmoid:
            head.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head)

    def forward(self, x, cond=None):
        """Run ``x`` through ``model0`` .. ``model4`` in order.

        ``cond`` is accepted for signature compatibility but is not used.
        """
        out = x
        for stage in (self.model0, self.model1, self.model2,
                      self.model3, self.model4):
            out = stage(out)
        return out
def _load_generator(weights_path):
    """Build a ``Generator(3, 1, 3)``, load its weights on CPU, set eval mode.

    NOTE(review): depending on the installed torch version, ``torch.load``
    may unpickle arbitrary objects, so only trusted checkpoint files should
    be placed next to this script.
    """
    net = Generator(3, 1, 3)
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.eval()
    return net


# predict() runs model2 for the 'Simple Lines' choice and model1 otherwise.
model1 = _load_generator('model.pth')
model2 = _load_generator('model2.pth')
def predict(input_img, ver):
    """Turn the image at ``input_img`` into a line drawing.

    Args:
        input_img: Path (or file object) that ``PIL.Image.open`` can read.
        ver: Style choice. ``'Simple Lines'`` selects ``model2``; any other
            value selects ``model1``.

    Returns:
        The PIL image that ``transforms.ToPILImage`` builds from the first
        (and only) item of the generator's output batch.
    """
    transform = transforms.Compose([
        transforms.Resize(256, Image.BICUBIC),
        transforms.ToTensor(),
    ])

    # The context manager releases the file handle once the pixels have been
    # turned into a tensor. convert('RGB') guarantees three channels: both
    # generators are built as Generator(3, 1, 3), so greyscale, palette or
    # RGBA inputs would otherwise be rejected by the first convolution.
    with Image.open(input_img) as source:
        batch = transform(source.convert('RGB')).unsqueeze(0)

    generator = model2 if ver == 'Simple Lines' else model1

    # Inference only: no autograd graph is recorded, so no detach() is needed.
    with torch.no_grad():
        drawing = generator(batch)[0]

    return transforms.ToPILImage()(drawing)
# Labels offered by the radio input; predict() compares against 'Simple Lines'.
_COMPLEX = 'Complex Lines'
_SIMPLE = 'Simple Lines'

title = "Art Style Line Drawings - Foods and Nutrition"
description = "Art Style Line Drawings 🥐🥑🥒🥓🥔🥕🥖🥗🥘🥙🥚🥛🥜🥝🥞🥟🥠🥡🥢🥣🥤🥥🥦🥧🥨🥩🥪🥫🥬🥭🥮🥯"

# Each example row is [image file, style choice], matching the order of the
# interface inputs.
examples = [
    ['Cucumbers.jpg', _COMPLEX],
    ['Croissant.jpg', _COMPLEX],
    ['Avocado.jpg', _COMPLEX],
    ['nuTJCub5mbDmN0v8oL6Y_15.625x.jpg', _COMPLEX],
    ['ouSFtOAajjvbhB9QYh8v_20x.jpg', _COMPLEX],
    ['Baubles-standard-scale-20_00x-gigapixel.jpeg', _SIMPLE],
    ['Mannerism Klee 3-gigapixel-art-scale-6_00x.png', _SIMPLE],
    ['Cubism Klimt-gigapixel-art-scale-6_00x.png', _COMPLEX],
    ['Magic Realism Dore rainbow 5-gigapixel-art-scale-6_00x.png', _SIMPLE],
    ['Neoclassicism Klimt 4-gigapixel-art-scale-6_00x.png', _COMPLEX],
    ['Realism Basquiat 1-gigapixel-art-scale-6_00x.png', _SIMPLE],
    ['Surrealism Aaron Wacker-gigapixel-art-scale-6_00x.png', _COMPLEX],
    ['FuturismEgon3-gigapixel-art-scale-6_00x.png', _SIMPLE],
]

# NOTE(review): this stylesheet is defined but not passed to gr.Interface
# below, so it currently has no effect -- confirm whether that is intended.
css = ".output-image, .input-image {height: 40rem !important; width: 100% !important;}"

image_input = gr.inputs.Image(type='filepath')
style_input = gr.inputs.Radio(
    [_COMPLEX, _SIMPLE],
    type="value",
    default=_SIMPLE,
    label='version',
)
drawing_output = gr.outputs.Image(type="pil")

iface = gr.Interface(
    fn=predict,
    inputs=[image_input, style_input],
    outputs=drawing_output,
    title=title,
    description=description,
    examples=examples,
)
iface.launch()