File size: 3,851 Bytes
eb8f9d9
daf7bc7
eb8f9d9
daf7bc7
eb8f9d9
 
daf7bc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8f6a226
eb8f9d9
daf7bc7
 
 
 
eb8f9d9
 
daf7bc7
 
 
 
eb8f9d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daf7bc7
 
eb8f9d9
 
 
 
 
 
 
 
 
 
 
 
 
 
f161731
eb8f9d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
daf7bc7
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
import os
import cv2
import time
import numpy as np
import streamlit as st
import tensorflow as tf
from threading import Thread


class WebcamVideoStream:
    """Continuously grab frames from a ``cv2.VideoCapture`` on a background thread.

    The reader thread keeps ``self.frame`` overwritten with the latest frame so
    that ``read()`` returns immediately instead of blocking on camera I/O.
    """

    def __init__(self, src=0):
        """Open capture source *src* and grab one frame synchronously.

        The first ``read()`` result is stored in ``self.grabbed`` /
        ``self.frame`` so ``read()`` has something to return (and callers can
        check ``grabbed``) before ``start()`` is called.
        """
        self.stream = cv2.VideoCapture(src)
        self.grabbed, self.frame = self.stream.read()
        self.stopped = False
        self._thread = None  # set by start(); None means never started

    def start(self):
        """Launch the reader thread and return ``self`` for chaining."""
        # daemon=True: a forgotten stop() must not keep the process alive
        # after the main thread exits.
        self._thread = Thread(target=self.update, args=(), daemon=True)
        self._thread.start()
        return self

    def update(self):
        """Reader-thread loop: refresh the latest frame until ``stop()`` is called.

        The capture device is released here, on the reader thread, once the
        loop exits. This avoids releasing it from another thread while a
        ``read()`` is still in progress.
        """
        try:
            while not self.stopped:
                self.grabbed, self.frame = self.stream.read()
        finally:
            self.stream.release()

    def read(self):
        """Return the most recently grabbed frame."""
        return self.frame

    def stop(self):
        """Stop the reader thread and release the capture device.

        Blocks until the reader thread (if started) has exited and released
        the device, so a new capture on the same source can be opened
        immediately afterwards.
        """
        self.stopped = True
        if self._thread is None:
            # Reader never started, so nothing else will release the device.
            self.stream.release()
        else:
            self._thread.join()


def style_transfer_direct(image, style_image, model, resize=None):
    """Stylize a single BGR frame with *style_image* using *model*.

    Args:
        image: BGR frame (as delivered by OpenCV) with values in 0-255; it is
            converted to RGB and scaled to [0, 1] before inference.
        style_image: Style tensor, e.g. from ``get_style_dictionary`` or
            ``get_custom_style``.
        model: Loaded style-transfer model, called as ``model(content, style)``.
        resize: Falsy to keep the frame's size; ``True`` (or any other truthy
            non-sequence) to resize to 512x512 as before; or a
            ``(width, height)`` tuple/list to resize to that size.

    Returns:
        NumPy array of the first model output with the batch axis removed.
    """
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    image = np.array(image, dtype=np.float32)
    if resize:
        # cv2.resize takes dsize as (width, height) integers.
        if isinstance(resize, (tuple, list)):
            size = tuple(int(v) for v in resize)
        else:
            size = (512, 512)
        image = cv2.resize(image, size)
    # Add a batch axis and scale 0-255 -> 0-1.
    image = tf.convert_to_tensor(image[np.newaxis, ...] / 255.)
    resp = model(image, style_image)
    stylized_image = resp[0]
    return stylized_image.numpy()[0]


def get_style_dictionary(style_dir='assets/template_styles'):
    """Load every bundled style image in *style_dir*.

    Args:
        style_dir: Directory containing the template style images.

    Returns:
        dict mapping file name -> float32 tensor of shape (1, 256, 256, 3)
        scaled to [0, 1]. Keys are inserted in sorted file-name order, so any
        UI listing them (e.g. a selectbox) has a stable order across runs.
    """
    # Formats tf.io.decode_image can handle; anything else (e.g. .DS_Store,
    # .gitkeep) would make decode_image raise.
    image_exts = ('.bmp', '.gif', '.jpeg', '.jpg', '.png')
    style_lookup = {}
    for file_name in sorted(os.listdir(style_dir)):
        path = os.path.join(style_dir, file_name)
        if not (os.path.isfile(path) and file_name.lower().endswith(image_exts)):
            continue
        with open(path, 'rb') as f:
            style = f.read()
        # channels=3 drops any alpha channel; expand_animations=False makes
        # GIFs decode to a single 3-D frame instead of a 4-D stack.
        style_image = tf.io.decode_image(style, channels=3, expand_animations=False)
        style_image = np.array(style_image[np.newaxis, ...], dtype=np.float32) / 255.
        style_image = tf.image.resize(style_image, (256, 256))
        style_lookup[file_name] = tf.convert_to_tensor(style_image)
    return style_lookup

@st.cache(persist=True)
def get_custom_style(image_bytes):
    """Decode uploaded style-image bytes into a style tensor.

    Args:
        image_bytes: Raw encoded image bytes (e.g. ``file_uploader`` contents).

    Returns:
        float32 tensor of shape (1, 256, 256, 3) scaled to [0, 1].
    """
    # PNG uploads frequently carry an alpha channel; force 3 channels so the
    # style tensor is RGB like the content frames fed to the model.
    style_image = tf.io.decode_image(image_bytes, channels=3, expand_animations=False)
    style_image = np.array(style_image[np.newaxis, ...], dtype=np.float32) / 255.
    style_image = tf.image.resize(style_image, (256, 256))
    return tf.convert_to_tensor(style_image)


@st.cache(allow_output_mutation=True)
def _load_model(model_path):
    """Load the TensorFlow SavedModel at *model_path*, cached across reruns.

    Streamlit re-executes the script on every widget interaction, so without
    caching the model would be reloaded from disk each time.
    ``allow_output_mutation=True`` tells ``st.cache`` not to hash the returned
    model object.
    """
    return tf.saved_model.load(model_path)


def main():
    """Streamlit app: show the webcam feed, optionally style-transferred.

    Sidebar controls enable the webcam and style transfer and choose the style
    image: an uploaded file takes precedence over the selected bundled
    template. While the webcam is enabled, frames are drawn in a loop that only
    ends on a camera failure or when the script run is interrupted (Streamlit
    stops the running script on a rerun). The ``finally`` clause then stops the
    capture thread.
    """
    model = _load_model('style/1')
    st.title("Neural Style-Transfer Webcam")
    st.subheader('Style Transfer')
    webcam_flag = st.sidebar.checkbox('Enable Webcam', value=False)
    style_flag = st.sidebar.checkbox('Enable Style Transfer', value=False)
    with open('assets/bonk.png', 'rb') as f:
        default_bytes = f.read()
    placeholder_image = st.image(default_bytes)

    style_dictionary = get_style_dictionary()
    style_options = st.sidebar.selectbox(label='Example Styles', options=list(style_dictionary.keys()))
    custom_style = st.sidebar.file_uploader('Upload Style:', type=['.jpg', '.jpeg', '.png'])
    frame_rate = st.text('Frames per second: 0')

    if not webcam_flag:
        return

    video_capture = WebcamVideoStream(0)
    try:
        # The constructor performs one synchronous read; a failed read means
        # the camera could not be opened, so stop before the frame loop.
        if not video_capture.grabbed:
            st.error('Could not read from the webcam.')
            return
        video_capture.start()

        if custom_style is not None:
            style_image = get_custom_style(custom_style.getvalue())
        else:
            style_image = style_dictionary[style_options]

        st.sidebar.subheader("Style Image:")
        st.sidebar.image(np.array(style_image.numpy()[0] * 255., dtype=np.uint8), use_column_width=True)

        start = time.perf_counter()
        total_frames = 0
        # webcam_flag is never reassigned here: this loop runs until a frame
        # read fails or Streamlit interrupts the script.
        while webcam_flag:
            content_image = video_capture.read()
            if content_image is None:
                # A None frame means the capture read failed (e.g. camera lost).
                st.error('Lost the webcam feed.')
                break
            if style_flag:
                transfer = style_transfer_direct(content_image, style_image, model)
                placeholder_image.image(transfer)
            else:
                placeholder_image.image(content_image, channels='BGR')
            total_frames += 1
            elapsed = time.perf_counter() - start
            if elapsed > 0:  # avoid ZeroDivisionError on a zero-length interval
                frame_rate.text(f'Frames per second: {total_frames / elapsed}')
    finally:
        video_capture.stop()


if "__main__" == __name__:
    # Run the app only when this file is executed as a script, not when it is imported.
    main()