|
import streamlit as st |
|
|
|
# Wide layout gives the two side-by-side columns created in main() (input
# canvas and output image) more horizontal room.
# NOTE(review): Streamlit expects set_page_config to run before other Streamlit
# commands, which is presumably why it sits among the imports — keep it first.
st.set_page_config(layout="wide")
|
|
|
from streamlit_drawable_canvas import st_canvas |
|
from PIL import Image |
|
from typing import Union |
|
import random |
|
import numpy as np |
|
import os |
|
import time |
|
|
|
from models import make_image_controlnet, make_inpainting |
|
from preprocessing import preprocess_seg_mask, get_image, get_mask |
|
from explanation import make_inpainting_explanation, make_regeneration_explanation, make_segmentation_explanation |
|
|
|
|
|
|
|
def on_upload() -> None:
    """Load the freshly uploaded file as the canvas background image.

    Registered as the ``on_change`` callback of the sidebar file uploader
    (widget key ``'input_image'``). When a file is present it is decoded with
    PIL, converted to RGB and stored in ``st.session_state['initial_image']``.
    State derived from the previous image (``'seg'``, ``'unique_colors'``,
    ``'output_image'``) is discarded. The canvas is also flagged for a reset,
    so strokes drawn on the previous image are not reused as the mask for the
    new one.

    Clearing the uploader (value ``None``) leaves the current image and all
    state untouched.
    """
    uploaded = st.session_state.get('input_image')
    if uploaded is None:
        return

    st.session_state['initial_image'] = Image.open(uploaded).convert('RGB')

    # Drop everything that was computed from the previous image.
    for stale_key in ('seg', 'unique_colors', 'output_image'):
        st.session_state.pop(stale_key, None)

    # Same flag move_image() sets: check_reset_state() consumes it on the next
    # run so the canvas starts from an empty drawing.
    st.session_state['reset_canvas'] = True
|
|
|
def make_image_row(image_0, image_1):
    """Render two images next to each other, each filling its half-width column."""
    for column, picture in zip(st.columns(2), (image_0, image_1)):
        with column:
            st.image(picture, use_column_width=True)
|
|
|
def check_reset_state() -> bool:
    """Consume the one-shot ``'reset_canvas'`` flag.

    Returns ``True`` when the flag is present and truthy in session state,
    ``False`` otherwise. In both cases the flag is cleared afterwards, so a
    requested reset is reported to exactly one script run.
    """
    was_requested = bool(st.session_state.get('reset_canvas', False))
    st.session_state['reset_canvas'] = False
    return was_requested
|
|
|
|
|
def move_image(source: Union[str, Image.Image],
               dest: str,
               rerun: bool = True,
               remove_state: bool = True) -> None:
    """Copy an image into ``st.session_state[dest]`` and optionally rerun.

    Args:
        source: Either a PIL image, or the session-state key holding the image
            to copy. The key is looked up with ``st.session_state[source]``,
            so a missing key raises.
        dest: Session-state key that receives the image.
        rerun: If true, trigger a Streamlit rerun after the copy so the page is
            redrawn with the new image.
        remove_state: If true, flag the canvas for reset and drop the
            ``'seg'`` / ``'unique_colors'`` state derived from the old image.
    """
    source_image = source if isinstance(source, Image.Image) else st.session_state[source]

    if remove_state:
        st.session_state['reset_canvas'] = True
        for stale_key in ('seg', 'unique_colors'):
            st.session_state.pop(stale_key, None)

    st.session_state[dest] = source_image

    if rerun:
        # st.experimental_rerun was deprecated in favour of st.rerun and later
        # removed; use the new API when available, fall back on old Streamlit.
        rerun_app = getattr(st, 'rerun', None) or st.experimental_rerun
        rerun_app()
|
|
|
|
|
def on_change_radio() -> None:
    """Radio-button ``on_change`` callback: request a canvas reset on the next run."""
    st.session_state.reset_canvas = True
|
|
|
|
|
def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Assemble the keyword arguments passed to ``st_canvas``.

    Args:
        canvas_color: Colour used for both stroke and fill.
        brush: Stroke width.
        paint_mode: Passed through as ``drawing_mode``.
        _reset_state: When true, an empty drawing is passed as
            ``initial_drawing`` to reset the canvas; otherwise ``None``.

    Returns:
        dict of keyword arguments for a 512x512 canvas with key ``"canvas"``,
        using ``st.session_state['initial_image']`` (or ``None``) as background.
    """
    empty_drawing = {'version': '4.4.0', 'objects': []}
    return dict(
        fill_color=canvas_color,
        stroke_color=canvas_color,
        background_color="#FFFFFF",
        background_image=st.session_state.get('initial_image'),
        stroke_width=brush,
        initial_drawing=empty_drawing if _reset_state else None,
        update_streamlit=True,
        height=512,
        width=512,
        drawing_mode=paint_mode,
        key="canvas",
    )
|
|
|
|
|
def make_prompt_row():
    """Render the positive and negative prompt text inputs side by side.

    The widget keys store the values in ``st.session_state['positive_prompt']``
    and ``st.session_state['negative_prompt']``.
    """
    left, right = st.columns(2)
    prompt_fields = (
        (left, "Positive prompt",
         "a realistic photograph of a room, high resolution",
         'positive_prompt'),
        (right, "Negative prompt",
         "watermark, banner, logo, watermark, contactinfo, text, deformed, blurry, blur, lowres",
         'negative_prompt'),
    )
    for column, label, default, widget_key in prompt_fields:
        with column:
            st.text_input(label=label, value=default, key=widget_key)
|
|
|
def make_sidebar():
    """Render the sidebar and return the current editing settings.

    Only the image uploader is interactive; the other settings are fixed in
    this version of the app.

    Returns:
        tuple ``(input_image, generation_mode, brush, color_chooser, paint_mode)``:
        the uploader's value (``None`` until a file is chosen), the generation
        mode (always ``'Inpainting'``), the brush width, the stroke colour and
        the canvas drawing mode (always ``'freedraw'``).
    """
    with st.sidebar:
        input_image = st.file_uploader("Upload Image",
                                       # ".jpeg" is listed explicitly because older
                                       # Streamlit versions do not alias it to ".jpg".
                                       type=["png", "jpg", "jpeg"],
                                       key='input_image',
                                       on_change=on_upload)

    generation_mode = 'Inpainting'
    paint_mode = 'freedraw'
    # Wide strokes for free-drawing masks, thin strokes for any other mode.
    brush = 30 if paint_mode == "freedraw" else 5
    color_chooser = "#000000"
    return input_image, generation_mode, brush, color_chooser, paint_mode
|
|
|
|
|
|
|
def make_output_image():
    """Render the output column: the latest generated image and a "move" button.

    The image in ``st.session_state['output_image']`` (a numpy array or PIL
    image) is shown resized to 512x512. Before anything has been generated, a
    white 512x512 placeholder is shown instead. The "Move to input image"
    button copies the generated image into ``'initial_image'`` and reruns. If
    there is nothing to move yet, it shows a warning instead.
    """
    has_output = 'output_image' in st.session_state
    if has_output:
        output_image = st.session_state['output_image']
        if isinstance(output_image, np.ndarray):
            output_image = Image.fromarray(output_image)
        if isinstance(output_image, Image.Image):
            output_image = output_image.resize((512, 512))
    else:
        output_image = Image.new('RGB', (512, 512), (255, 255, 255))

    st.write("#### Output image")
    st.image(output_image, width=512)

    if st.button("Move to input image"):
        if has_output:
            move_image('output_image', 'initial_image', remove_state=True, rerun=True)
        else:
            # Without this guard move_image() would index the missing
            # 'output_image' key and crash the app.
            st.warning("Generate an image before moving it to the input.")
|
|
|
def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the input canvas and, on request, run inpainting.

    Args:
        canvas_color: Brush colour for the canvas.
        brush: Brush width.
        _reset_state: When true the canvas drawing is reset on this run.
        generation_mode: Generation mode; only ``"Inpainting"`` can generate
            images here. Other modes just show the canvas.
        paint_mode: Canvas drawing mode.

    When "Generate images" is clicked, the canvas contents are turned into a
    mask with ``get_mask``, and ``make_inpainting`` is called with the prompts
    from session state. The result is stored (as a PIL image when it came back
    as a numpy array) in ``st.session_state['output_image']``.
    """
    st.write("#### Input image")
    canvas_dict = make_canvas_dict(
        canvas_color=canvas_color,
        paint_mode=paint_mode,
        brush=brush,
        _reset_state=_reset_state,
    )

    is_inpainting = generation_mode == "Inpainting"
    # As before, get_image() is only called in inpainting mode and ahead of
    # rendering the canvas.
    image = get_image() if is_inpainting else None
    # Render the canvas once for every mode, so `canvas` is always bound.
    canvas = st_canvas(**canvas_dict)

    if not st.button("Generate images", key='generate_button'):
        return

    if not is_inpainting:
        # Previously this path crashed with a NameError on an unbound name.
        st.error(f"Image generation is not available in {generation_mode} mode.")
        return

    canvas_mask = canvas.image_data
    if canvas_mask is None:
        # The canvas component has not reported any image data yet.
        st.warning("The canvas is not ready yet, please try again.")
        return
    if not isinstance(canvas_mask, np.ndarray):
        canvas_mask = np.array(canvas_mask)
    mask = get_mask(canvas_mask)

    with st.spinner(text="Generating new images"):
        print("Making image")
        # NOTE(review): assumes get_image() returns an array-like accepted by
        # Image.fromarray — confirm against preprocessing.get_image.
        result_image = make_inpainting(
            positive_prompt=st.session_state['positive_prompt'],
            image=Image.fromarray(image),
            mask_image=mask,
            negative_prompt=st.session_state['negative_prompt'],
        )

    if isinstance(result_image, np.ndarray):
        result_image = Image.fromarray(result_image)
    st.session_state['output_image'] = result_image
|
|
|
def main():
    """Build the Virtual Staging page.

    Renders the sidebar. Until an image has been uploaded, it only shows a
    hint. After that it shows the prompt row, plus two columns: the editing
    canvas on the left and the output image on the right.
    """
    st.write("## Virtual Staging")

    # The uploader's return value is not needed here: on_upload() has already
    # stored the decoded picture under 'initial_image'.
    _, generation_mode, brush, color_chooser, paint_mode = make_sidebar()

    if st.session_state.get('initial_image') is None:
        st.write("Please upload an image to begin")
        return

    make_prompt_row()
    reset_requested = check_reset_state()

    canvas_column, output_column = st.columns(2)
    with canvas_column:
        make_editing_canvas(
            canvas_color=color_chooser,
            brush=brush,
            _reset_state=reset_requested,
            generation_mode=generation_mode,
            paint_mode=paint_mode,
        )
    with output_column:
        make_output_image()
|
|
|
# Build the page when this file is executed as the main script.
# NOTE(review): `streamlit run` executes the script with __name__ == "__main__",
# so this guard still fires under Streamlit — confirm if the runner changes.
if __name__ == "__main__":
    main()
|
|
|
|