awacke1's picture
Update app.py
91875a1
raw
history blame
4.49 kB
import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from PIL import Image
import torchvision.transforms as transforms
# Normalisation layer class shared by the network definitions in this file.
norm_layer = nn.InstanceNorm2d


class ResidualBlock(nn.Module):
    """Two reflection-padded 3x3 convolutions wrapped in a skip connection.

    The channel count is unchanged by the block: both convolutions map
    ``in_features`` channels to ``in_features`` channels, and the
    1-pixel reflection padding offsets the 3x3 kernel so the output can
    be added to the input.

    Args:
        in_features: Number of input (and output) channels.
    """

    def __init__(self, in_features):
        super().__init__()

        def padded_conv():
            # pad -> 3x3 conv -> norm; built fresh on every call so the two
            # halves of the block do not share weights.
            return [
                nn.ReflectionPad2d(1),
                nn.Conv2d(in_features, in_features, 3),
                norm_layer(in_features),
            ]

        # Only the first half is followed by a ReLU.
        self.conv_block = nn.Sequential(
            *padded_conv(),
            nn.ReLU(inplace=True),
            *padded_conv(),
        )

    def forward(self, x):
        """Return ``x`` plus the output of the convolutional branch."""
        residual = self.conv_block(x)
        return x + residual
class Generator(nn.Module):
    """Convolutional image-to-image network built from five sequential stages.

    Stages (attribute names are part of the checkpoint layout and must not
    be renamed):

    * ``model0`` - 7x7 convolution from ``input_nc`` to 64 channels.
    * ``model1`` - two stride-2 convolutions, 64 -> 128 -> 256 channels.
    * ``model2`` - ``n_residual_blocks`` residual blocks at 256 channels.
    * ``model3`` - two stride-2 transposed convolutions, 256 -> 128 -> 64.
    * ``model4`` - 7x7 convolution from 64 to ``output_nc`` channels,
      optionally followed by a sigmoid.

    Args:
        input_nc: Number of channels in the input image tensor.
        output_nc: Number of channels in the produced tensor.
        n_residual_blocks: How many ``ResidualBlock`` modules to stack.
        sigmoid: When true, append ``nn.Sigmoid`` as the final layer.
    """

    def __init__(self, input_nc, output_nc, n_residual_blocks=9, sigmoid=True):
        super().__init__()

        # Stage 0: initial convolution block.
        self.model0 = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(input_nc, 64, 7),
            norm_layer(64),
            nn.ReLU(inplace=True),
        )

        # Stage 1: each step doubles the channel count with a stride-2 conv.
        width = 64
        down_layers = []
        for _ in range(2):
            doubled = width * 2
            down_layers.extend((
                nn.Conv2d(width, doubled, 3, stride=2, padding=1),
                norm_layer(doubled),
                nn.ReLU(inplace=True),
            ))
            width = doubled
        self.model1 = nn.Sequential(*down_layers)

        # Stage 2: residual blocks at the widest channel count.
        self.model2 = nn.Sequential(
            *(ResidualBlock(width) for _ in range(n_residual_blocks))
        )

        # Stage 3: mirror of stage 1, halving the channel count each step.
        up_layers = []
        for _ in range(2):
            halved = width // 2
            up_layers.extend((
                nn.ConvTranspose2d(width, halved, 3, stride=2, padding=1, output_padding=1),
                norm_layer(halved),
                nn.ReLU(inplace=True),
            ))
            width = halved
        self.model3 = nn.Sequential(*up_layers)

        # Stage 4: output projection, with an optional sigmoid on top.
        head_layers = [nn.ReflectionPad2d(3), nn.Conv2d(64, output_nc, 7)]
        if sigmoid:
            head_layers.append(nn.Sigmoid())
        self.model4 = nn.Sequential(*head_layers)

    def forward(self, x, cond=None):
        """Run ``x`` through the five stages in order.

        ``cond`` is accepted but not used by this implementation.
        """
        out = x
        for stage in (self.model0, self.model1, self.model2, self.model3, self.model4):
            out = stage(out)
        return out
def _load_generator(weights_path):
    """Build a 3-in / 1-out generator with 3 residual blocks and load weights.

    The checkpoint is mapped onto the CPU and the network is switched to
    evaluation mode before being returned.
    """
    net = Generator(3, 1, 3)
    state = torch.load(weights_path, map_location=torch.device('cpu'))
    net.load_state_dict(state)
    net.eval()
    return net


# predict() uses model2 for 'Simple Lines' and model1 for every other choice.
model1 = _load_generator('model.pth')
model2 = _load_generator('model2.pth')
def predict(input_img, ver):
    """Turn the image at ``input_img`` into a line drawing.

    Args:
        input_img: Path of the image file to convert.
        ver: Which generator to use. ``'Simple Lines'`` selects ``model2``;
            any other value selects ``model1``.

    Returns:
        A PIL image built from the first item of the generator's output
        batch.
    """
    # Both generators are built with 3 input channels, so force RGB here:
    # RGBA, palette or greyscale uploads would otherwise yield a tensor with
    # the wrong channel count and fail inside the first convolution.
    # The context manager also releases the file handle once decoded.
    with Image.open(input_img) as source:
        rgb_img = source.convert('RGB')

    transform = transforms.Compose([transforms.Resize(256, Image.BICUBIC), transforms.ToTensor()])
    # Add a leading batch dimension of size 1.
    batch = torch.unsqueeze(transform(rgb_img), 0)

    net = model2 if ver == 'Simple Lines' else model1
    # Inference only: no autograd graph is needed (so no detach() either).
    with torch.no_grad():
        drawing = net(batch)[0]

    return transforms.ToPILImage()(drawing)
# Text shown by the demo interface.
title = "Art Style Line Drawings - Foods and Nutrition"
description = "Art Style Line Drawings 🥐🥑🥒🥓🥔🥕🥖🥗🥘🥙🥚🥛🥜🥝🥞🥟🥠🥡🥢🥣🥤🥥🥦🥧🥨🥩🥪🥫🥬🥭🥮🥯"
# article = "<p style='text-align: center'></p>"

# Example rows: an image path followed by the version choice to run it with,
# in the same order as the interface inputs below.
examples = [
    ['Baubles-standard-scale-20_00x-gigapixel.jpeg', 'Simple Lines'],
    ['Mannerism Klee 3-gigapixel-art-scale-6_00x.png', 'Simple Lines'],
    ['Cubism Klimt-gigapixel-art-scale-6_00x.png', 'Complex Lines'],
    ['Magic Realism Dore rainbow 5-gigapixel-art-scale-6_00x.png', 'Simple Lines'],
    ['Brain Coral-gigapixel-art-scale-6_00x.jpg', 'Complex Lines'],
    ['Dark Ritual Wisp Loop-art-scale-6_00x-gigapixel.png', 'Simple Lines'],
    ['Dungeons and Dragons Cartoon-art-scale-6_00x-gigapixel.png', 'Complex Lines'],
    ['Fantasy Art 2-art-scale-6_00x-gigapixel.png', 'Simple Lines'],
]

# Inputs are passed to predict() positionally: image path first, then version.
image_input = gr.inputs.Image(type='filepath')
version_input = gr.inputs.Radio(
    ['Complex Lines', 'Simple Lines'],
    type="value",
    default='Simple Lines',
    label='version',
)
drawing_output = gr.outputs.Image(type="pil")

iface = gr.Interface(
    fn=predict,
    inputs=[image_input, version_input],
    outputs=drawing_output,
    title=title,
    description=description,
    examples=examples,
)
iface.launch()