# neural-style / app.py
# dogeplusplus — commit b260443: "No caching."
# (2.53 kB; page links: raw, history, blame)
import os
import cv2
import numpy as np
import streamlit as st
import tensorflow as tf
def _cache_resource(func):
    """Cache ``func``'s return value as a shared resource across Streamlit reruns.

    Prefers ``st.cache_resource`` when the installed Streamlit provides it and
    falls back to the deprecated ``st.cache`` otherwise. In the fallback,
    ``allow_output_mutation=True`` stops ``st.cache`` from re-hashing the
    returned object on every rerun to detect mutation. For a large model that
    hashing is slow and can fail on objects Streamlit cannot hash.
    """
    cache_resource = getattr(st, "cache_resource", None)
    if cache_resource is not None:
        return cache_resource(func)
    return st.cache(allow_output_mutation=True, suppress_st_warning=True)(func)


@_cache_resource
def load_model(path="style/1"):
    """Load the style-transfer model saved at ``path`` with Keras.

    The result is cached, so the model is loaded once instead of on every
    rerun of the Streamlit script.

    Args:
        path: Location of the saved model; defaults to ``"style/1"``.

    Returns:
        Whatever ``tf.keras.models.load_model`` returns for ``path``.
    """
    return tf.keras.models.load_model(path)
def apply_style_transfer(model, content, style, resize=None):
    """Stylize ``content`` with the look of ``style`` and return a uint8 image.

    Args:
        model: Callable invoked as ``model(content_batch, style_batch)``. Both
            arguments are float32 tensors with a leading batch axis of size 1
            (e.g. the object returned by ``load_model``).
        content: Content image convertible to a NumPy array. Pixel values are
            assumed to lie in [0, 255], since they are divided by 255.
        style: Style image, same conventions as ``content``.
        resize: Falsy to keep both images at their own size. ``True`` (or any
            other truthy non-sequence) resizes both to 512x512, as before.
            A ``(width, height)`` tuple/list resizes both to that size;
            ``cv2.resize`` takes ``dsize`` as width then height.

    Returns:
        ``np.ndarray`` of dtype uint8: the first element of the model output,
        scaled by 255 and clipped to [0, 255].
    """
    content = np.array(content, dtype=np.float32) / 255.
    style = np.array(style, dtype=np.float32) / 255.
    if resize:
        if isinstance(resize, (tuple, list)):
            size = tuple(int(v) for v in resize)
        else:
            size = (512, 512)
        content = cv2.resize(content, size)
        style = cv2.resize(style, size)
    outputs = model(tf.constant(content[np.newaxis, ...]), tf.constant(style[np.newaxis, ...]))
    # NOTE(review): indexing [0] assumes the model returns either a batched
    # image tensor or a list whose first entry is the result. Confirm this
    # against the saved model's signature.
    stylized = np.asarray(outputs[0] * 255, dtype=np.float32)
    # Saturate before casting. Converting float values outside [0, 255] to
    # uint8 is undefined in NumPy (in practice it usually wraps around), so a
    # model value slightly above 1.0 would otherwise become near-black.
    return np.clip(stylized, 0, 255).astype(np.uint8)
def main():
    """Render the Streamlit neural style-transfer page.

    Shows the title and description and sets up two image columns. Reads a
    content image and a style image from the sidebar: the style is an
    uploaded file if present, otherwise the example selected from the
    template directory. Once both images are available, it displays the
    stylized result.
    """
    template_dir = 'assets/template_styles'
    model = load_model()

    st.title("Neural Style-Transfer App")
    st.write("`neural-style` is a pre-trained model from Tensorflow-Hub that allows you to apply styles to images and create pretty art. This app allows you to upload your own content or style images to create some funky effects. We provide some example styles which you can use.")
    col1, col2 = st.columns(2)

    content_file = st.sidebar.file_uploader('Upload Image', type=['jpg', 'jpeg', 'png'])
    style_file = st.sidebar.file_uploader('Upload Style', type=['jpg', 'jpeg', 'png'])
    # os.listdir order is arbitrary, so sort it to keep the dropdown stable.
    # With a missing directory the selectbox gets no options and returns None,
    # which the style branch below already handles.
    template_names = sorted(os.listdir(template_dir)) if os.path.isdir(template_dir) else []
    style_options = st.sidebar.selectbox(label='Example Styles', options=template_names)

    col1.subheader('Content Image')
    col2.subheader('Style Transfer')
    st.sidebar.subheader('Style Image')
    show_image = col1.empty()
    show_style = col2.empty()

    style = None
    content = None
    if content_file:
        content = content_file.getvalue()
        show_image.image(content, use_column_width=True)

    # An uploaded style takes precedence over the example selection.
    if style_file:
        style = style_file.getvalue()
        st.sidebar.image(style, use_column_width=True)
    elif style_options is not None:
        with open(os.path.join(template_dir, style_options), 'rb') as f:
            style = f.read()
        st.sidebar.image(style, use_column_width=True)

    if content is not None and style is not None:
        # decode_image keeps the file's own channel count by default (4 for
        # RGBA PNGs, 1 for grayscale) and may return a frame stack for
        # animations. The model presumably expects single RGB images, so
        # force 3 channels and a single 3-D frame.
        content_image = tf.io.decode_image(content, channels=3, expand_animations=False)
        style_image = tf.image.resize(
            tf.io.decode_image(style, channels=3, expand_animations=False), (256, 256))
        with st.spinner('Generating style transfer...'):
            style_transfer = apply_style_transfer(model, content_image, style_image)
        show_style.image(style_transfer, use_column_width=True)
# `streamlit run app.py` executes this file as __main__; importing it elsewhere
# does not render the page.
if __name__ == "__main__":
    main()