File size: 6,727 Bytes
58d0adc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 |
import streamlit as st
# wide layout
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from PIL import Image
from typing import Union
import random
import numpy as np
import os
import time
from models import make_image_controlnet, make_inpainting
from preprocessing import preprocess_seg_mask, get_image, get_mask
from explanation import make_inpainting_explanation, make_regeneration_explanation, make_segmentation_explanation
def on_upload() -> None:
    """Load the newly uploaded file into session state as the canvas image.

    Registered as the ``on_change`` callback of the sidebar file uploader.
    When a file is present it is decoded with PIL, converted to RGB and stored
    under ``'initial_image'``. State derived from the previous image
    (``'seg'``, ``'unique_colors'``, ``'output_image'``) is discarded, and the
    ``'reset_canvas'`` flag is raised so strokes drawn over the old image are
    not carried over as the mask for the new one.
    """
    if 'input_image' not in st.session_state or st.session_state['input_image'] is None:
        # Uploader was cleared (or never used): keep whatever image is loaded.
        return
    image = Image.open(st.session_state['input_image']).convert('RGB')
    st.session_state['initial_image'] = image
    # Drop everything that was computed from the previous image.
    for stale_key in ('seg', 'unique_colors', 'output_image'):
        if stale_key in st.session_state:
            del st.session_state[stale_key]
    # Same flag move_image() sets when it replaces the input image; it is
    # consumed by check_reset_state() on the next run to clear the canvas.
    st.session_state['reset_canvas'] = True
def make_image_row(image_0, image_1):
    """Display two images side by side, each stretched to its column width."""
    for column, picture in zip(st.columns(2), (image_0, image_1)):
        with column:
            st.image(picture, use_column_width=True)
def check_reset_state() -> bool:
    """Consume the one-shot ``'reset_canvas'`` flag.

    Returns:
        True if the flag was present and truthy in session state, else False.
        Either way the flag is left set to False afterwards.
    """
    session = st.session_state
    needs_reset = 'reset_canvas' in session and bool(session['reset_canvas'])
    session['reset_canvas'] = False
    return needs_reset
def move_image(source: Union[str, Image.Image],
               dest: str,
               rerun: bool = True,
               remove_state: bool = True) -> None:
    """Copy an image into ``st.session_state[dest]``.

    Args:
        source: A PIL image, or the session-state key holding the image to copy.
        dest: Session-state key that receives the image.
        rerun: When True, trigger a Streamlit rerun after the copy.
        remove_state: When True, raise the ``'reset_canvas'`` flag and discard
            the ``'seg'`` / ``'unique_colors'`` entries.

    A missing ``source`` key fails on lookup, before any state is modified.
    """
    # Resolve the source first so a bad key fails before anything is mutated.
    source_image = source if isinstance(source, Image.Image) else st.session_state[source]
    if remove_state:
        st.session_state['reset_canvas'] = True
        for stale_key in ('seg', 'unique_colors'):
            if stale_key in st.session_state:
                del st.session_state[stale_key]
    st.session_state[dest] = source_image
    if rerun:
        # st.rerun superseded st.experimental_rerun, which newer Streamlit
        # releases no longer provide; fall back to it only on older versions.
        rerun_app = getattr(st, 'rerun', None) or st.experimental_rerun
        rerun_app()
def on_change_radio() -> None:
    """Radio-button callback: request a canvas reset on the next script run."""
    session = st.session_state
    session['reset_canvas'] = True
def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Build the keyword arguments passed to ``st_canvas``.

    The same colour is used for fill and stroke. The canvas is a fixed
    512x512 widget keyed ``"canvas"``, drawn over ``'initial_image'`` from
    session state when one is loaded. When ``_reset_state`` is truthy, an
    empty drawing (no objects) is supplied as ``initial_drawing``; otherwise
    ``initial_drawing`` is None.
    """
    session = st.session_state
    background = session['initial_image'] if 'initial_image' in session else None
    empty_drawing = {'version': '4.4.0', 'objects': []}
    return dict(
        fill_color=canvas_color,
        stroke_color=canvas_color,
        background_color="#FFFFFF",
        background_image=background,
        stroke_width=brush,
        initial_drawing=empty_drawing if _reset_state else None,
        update_streamlit=True,
        height=512,
        width=512,
        drawing_mode=paint_mode,
        key="canvas",
    )
def make_prompt_row():
    """Render the positive and negative prompt inputs in two columns.

    The widget values are stored in session state under ``'positive_prompt'``
    and ``'negative_prompt'``.
    """
    prompt_specs = (
        ("Positive prompt",
         "a realistic photograph of a room, high resolution",
         'positive_prompt'),
        ("Negative prompt",
         "watermark, banner, logo, watermark, contactinfo, text, deformed, blurry, blur, lowres",
         'negative_prompt'),
    )
    for column, (label, default_text, state_key) in zip(st.columns(2), prompt_specs):
        with column:
            st.text_input(label=label, value=default_text, key=state_key)
def make_sidebar():
    """Render the sidebar image uploader and return the editing settings.

    Returns:
        ``(input_image, generation_mode, brush, color_chooser, paint_mode)``.
        ``input_image`` is the uploader's current value. The rest are fixed:
        ``'Inpainting'`` mode, a brush width of 30, black ``"#000000"``
        colour and ``'freedraw'`` paint mode.
    """
    with st.sidebar:
        uploaded = st.file_uploader("Upload Image",
                                    type=["png", "jpg"],
                                    key='input_image',
                                    on_change=on_upload)
    paint_mode = 'freedraw'
    # Free-hand strokes use a wide brush; any other drawing mode a thin one.
    brush_width = 5
    if paint_mode == "freedraw":
        brush_width = 30
    return uploaded, 'Inpainting', brush_width, "#000000", paint_mode
def make_output_image():
    """Render the generated image panel and the "Move to input image" button.

    The stored ``'output_image'`` may be a NumPy array or a PIL image. Either
    is shown resized to 512x512; other stored values are passed to
    ``st.image`` unchanged. When no output exists (missing key or None), a
    white 512x512 placeholder is shown. In that case clicking the button
    shows a warning instead of failing on the missing session key.
    """
    session = st.session_state
    output_image = session['output_image'] if 'output_image' in session else None
    has_output = output_image is not None
    if isinstance(output_image, np.ndarray):
        output_image = Image.fromarray(output_image)
    if isinstance(output_image, Image.Image):
        output_image = output_image.resize((512, 512))
    elif output_image is None:
        output_image = Image.new('RGB', (512, 512), (255, 255, 255))
    st.write("#### Output image")
    st.image(output_image, width=512)
    if st.button("Move to input image"):
        if has_output:
            move_image('output_image', 'initial_image', remove_state=True, rerun=True)
        else:
            # Nothing generated yet: move_image would fail looking up the key.
            st.warning("Generate an image first.")
def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the drawable input canvas and, in inpainting mode, the generate button.

    Args:
        canvas_color: Fill/stroke colour of the brush.
        brush: Stroke width.
        _reset_state: Truthy to start the canvas from an empty drawing.
        generation_mode: ``"Segmentation"`` renders the canvas only.
            ``"Inpainting"`` also renders a "Generate images" button. Any
            other value renders nothing below the heading.
        paint_mode: Drawing mode forwarded to ``st_canvas``.

    On "Generate images", the canvas pixels are turned into a mask by
    ``get_mask``. The mask, the current image and the prompt texts from
    session state are passed to ``make_inpainting``. Its result is stored
    under ``'output_image'``, converted to a PIL image if an array is
    returned.
    """
    st.write("#### Input image")
    canvas_dict = make_canvas_dict(
        canvas_color=canvas_color,
        paint_mode=paint_mode,
        brush=brush,
        _reset_state=_reset_state,
    )
    if generation_mode == "Segmentation":
        st_canvas(**canvas_dict)
        return
    if generation_mode != "Inpainting":
        return

    image = get_image()
    canvas = st_canvas(**canvas_dict)
    if not st.button("Generate images", key='generate_button'):
        return

    canvas_mask = canvas.image_data
    if canvas_mask is None:
        # The canvas component has not reported pixel data yet, so there is
        # no mask to build; np.array(None) would reach get_mask otherwise.
        st.warning("Draw on the input image to mark the area to inpaint, then try again.")
        return
    if not isinstance(canvas_mask, np.ndarray):
        canvas_mask = np.array(canvas_mask)
    mask = get_mask(canvas_mask)
    with st.spinner(text="Generating new images"):
        print("Making image")
        # NOTE(review): get_image's return type is not visible here; arrays are
        # converted as before, and an already-built PIL image is used directly.
        init_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
        result_image = make_inpainting(
            positive_prompt=st.session_state['positive_prompt'],
            image=init_image,
            mask_image=mask,
            negative_prompt=st.session_state['negative_prompt'],
        )
        if isinstance(result_image, np.ndarray):
            result_image = Image.fromarray(result_image)
        st.session_state['output_image'] = result_image
def main():
    """Render the Virtual Staging page.

    The sidebar is always drawn. The prompt row, the input canvas and the
    output panel only appear once ``'initial_image'`` holds a non-None image.
    """
    st.write("## Virtual Staging")
    _, generation_mode, brush, color_chooser, paint_mode = make_sidebar()
    session = st.session_state
    has_image = 'initial_image' in session and session['initial_image'] is not None
    if not has_image:
        st.write("Please upload an image to begin")
        return
    make_prompt_row()
    reset_requested = check_reset_state()
    left_column, right_column = st.columns(2)
    with left_column:
        make_editing_canvas(
            canvas_color=color_chooser,
            brush=brush,
            _reset_state=reset_requested,
            generation_mode=generation_mode,
            paint_mode=paint_mode,
        )
    with right_column:
        make_output_image()
# Script entry point: build the page when this file runs as the main module.
if __name__ == "__main__":
    main()
|