Last commit not found
import streamlit as st
# Page layout is configured right after importing streamlit and before any
# other import: Streamlit expects st.set_page_config to precede other st.* calls.
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from PIL import Image
from typing import Union
import random
import numpy as np
import os
import time
from models import make_image_controlnet, make_inpainting
from preprocessing import preprocess_seg_mask, get_image, get_mask
from explanation import (
    make_inpainting_explanation,
    make_regeneration_explanation,
    make_segmentation_explanation,
)
def on_upload() -> None:
    """Load the uploaded file into session state as the canvas image.

    Used as the file uploader's ``on_change`` callback. If a file is present it
    is decoded to RGB and stored under ``'initial_image'``. The cached
    ``'seg'``, ``'unique_colors'`` and ``'output_image'`` entries are then
    discarded, because they were derived from the previous image. When the
    uploader holds no file, nothing changes.
    """
    uploaded = st.session_state.get('input_image')
    if uploaded is None:
        return
    st.session_state['initial_image'] = Image.open(uploaded).convert('RGB')
    for stale_key in ('seg', 'unique_colors', 'output_image'):
        if stale_key in st.session_state:
            del st.session_state[stale_key]
def make_image_row(image_0, image_1):
    """Show ``image_0`` and ``image_1`` side by side, each filling its column."""
    for column, picture in zip(st.columns(2), (image_0, image_1)):
        with column:
            st.image(picture, use_column_width=True)
def check_reset_state() -> bool:
    """Consume the one-shot ``'reset_canvas'`` flag from session state.

    Returns:
        True if the flag was present and truthy when called, otherwise False.
        In both cases the flag is left set to False afterwards.
    """
    needs_reset = bool(st.session_state.get('reset_canvas', False))
    st.session_state['reset_canvas'] = False
    return needs_reset
def move_image(source: Union[str, Image.Image],
               dest: str,
               rerun: bool = True,
               remove_state: bool = True) -> None:
    """Copy an image into ``st.session_state[dest]``.

    The source entry is not deleted, so this is a copy rather than a move.

    Args:
        source: Either a PIL image, or the session-state key holding the image.
        dest: Session-state key that receives the image.
        rerun: When True, restart the script run after the copy.
        remove_state: When True, flag the canvas for reset and drop the cached
            ``'seg'`` and ``'unique_colors'`` entries.

    Raises:
        KeyError: if ``source`` is a key that is not present in session state.
    """
    if isinstance(source, Image.Image):
        source_image = source
    else:
        source_image = st.session_state[source]
    if remove_state:
        st.session_state['reset_canvas'] = True
        for stale_key in ('seg', 'unique_colors'):
            if stale_key in st.session_state:
                del st.session_state[stale_key]
    st.session_state[dest] = source_image
    if rerun:
        # st.experimental_rerun is deprecated and removed in newer Streamlit
        # releases; prefer st.rerun when this Streamlit version provides it.
        rerun_fn = getattr(st, "rerun", None) or st.experimental_rerun
        rerun_fn()
def on_change_radio() -> None:
    """Radio-button callback: request a canvas reset on the next script run."""
    st.session_state.reset_canvas = True
def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Build the keyword arguments for ``st_canvas``.

    Args:
        canvas_color: Colour used for both the stroke and the fill.
        brush: Stroke width.
        paint_mode: Drawing mode forwarded as ``drawing_mode``.
        _reset_state: When truthy, pass an empty drawing (no objects) as
            ``initial_drawing``; otherwise pass None.

    Returns:
        dict: keyword arguments for a 512x512 canvas on a white background.
        The background image is the uploaded image from session state, or
        None if there is none.
    """
    empty_drawing = {'version': '4.4.0', 'objects': []}
    return dict(
        fill_color=canvas_color,
        stroke_color=canvas_color,
        background_color="#FFFFFF",
        background_image=st.session_state.get('initial_image'),
        stroke_width=brush,
        initial_drawing=empty_drawing if _reset_state else None,
        update_streamlit=True,
        height=512,
        width=512,
        drawing_mode=paint_mode,
        key="canvas",
    )
def make_prompt_row():
    """Render the positive and negative prompt text inputs side by side.

    The entered values are available as ``st.session_state['positive_prompt']``
    and ``st.session_state['negative_prompt']`` through the widget keys.
    """
    left, right = st.columns(2)
    prompt_fields = (
        (left, "Positive prompt",
         "a realistic photograph of a room, high resolution",
         'positive_prompt'),
        (right, "Negative prompt",
         "watermark, banner, logo, watermark, contactinfo, text, deformed, blurry, blur, lowres",
         'negative_prompt'),
    )
    for column, label, default_text, state_key in prompt_fields:
        with column:
            st.text_input(label=label, value=default_text, key=state_key)
def make_sidebar():
    """Render the sidebar uploader and return the (fixed) editing settings.

    Only the file uploader is interactive. Generation mode, paint mode, brush
    size and colour are constants in this version of the app.

    Returns:
        tuple: ``(input_image, generation_mode, brush, color_chooser, paint_mode)``
    """
    paint_mode = 'freedraw'
    with st.sidebar:
        input_image = st.file_uploader(
            "Upload Image",
            type=["png", "jpg"],
            key='input_image',
            on_change=on_upload,
        )
    # Freehand drawing gets a broad brush; any other mode a thin one.
    brush = 5
    if paint_mode == "freedraw":
        brush = 30
    return input_image, 'Inpainting', brush, "#000000", paint_mode
def make_output_image():
    """Render the output column: the latest generated image and a hand-off button.

    Shows ``st.session_state['output_image']``. A NumPy array is converted to
    a PIL image, and PIL images are resized to 512x512. When nothing has been
    generated yet (key missing or None), a blank white 512x512 placeholder is
    shown instead.

    The "Move to input image" button copies the generated image onto the
    editing canvas. If no image exists yet, a warning is shown; previously the
    button raised KeyError in that case.
    """
    raw_output = st.session_state.get('output_image')
    if raw_output is None:
        output_image = Image.new('RGB', (512, 512), (255, 255, 255))
    else:
        output_image = raw_output
        if isinstance(output_image, np.ndarray):
            output_image = Image.fromarray(output_image)
        if isinstance(output_image, Image.Image):
            output_image = output_image.resize((512, 512))
    st.write("#### Output image")
    st.image(output_image, width=512)
    if st.button("Move to input image"):
        if raw_output is None:
            # The white placeholder is not a real result, and move_image would
            # fail looking up the missing 'output_image' key.
            st.warning("Generate an image before moving it to the input.")
        else:
            move_image('output_image', 'initial_image', remove_state=True, rerun=True)
def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the input canvas and, in inpainting mode, run generation on demand.

    Args:
        canvas_color: Colour used for both the brush stroke and the fill.
        brush: Brush stroke width passed to ``st_canvas``.
        _reset_state: When True the canvas starts from an empty drawing.
        generation_mode: ``"Segmentation"`` shows the canvas only.
            ``"Inpainting"`` also shows a "Generate images" button. Any other
            value renders just the heading.
        paint_mode: ``st_canvas`` drawing mode.

    When "Generate images" is clicked:
      1. The drawn strokes are turned into a mask with ``get_mask``.
      2. ``make_inpainting`` is called with the prompts held in session state.
      3. The result is stored in ``st.session_state['output_image']``,
         converted to a PIL image if the model returned a NumPy array.
    """
    st.write("#### Input image")
    canvas_dict = make_canvas_dict(
        canvas_color=canvas_color,
        paint_mode=paint_mode,
        brush=brush,
        _reset_state=_reset_state,
    )
    if generation_mode == "Segmentation":
        st_canvas(**canvas_dict)
        return
    if generation_mode != "Inpainting":
        return

    image = get_image()
    canvas = st_canvas(**canvas_dict)
    if not st.button("Generate images", key='generate_button'):
        return

    canvas_mask = canvas.image_data
    if canvas_mask is None:
        # st_canvas reports no pixel data until the component has sent its
        # first update; np.array(None) would hand get_mask a 0-d object array.
        st.warning("The canvas has no drawing data yet. Paint over the area to change and try again.")
        return
    if not isinstance(canvas_mask, np.ndarray):
        canvas_mask = np.array(canvas_mask)
    mask = get_mask(canvas_mask)

    # The original code treats get_image()'s result as an array (Image.fromarray);
    # a PIL image is passed through unchanged rather than crashing.
    source_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    with st.spinner(text="Generating new images"):
        print("Making image")
        result_image = make_inpainting(
            positive_prompt=st.session_state['positive_prompt'],
            image=source_image,
            mask_image=mask,
            negative_prompt=st.session_state['negative_prompt'],
        )
    if isinstance(result_image, np.ndarray):
        result_image = Image.fromarray(result_image)
    st.session_state['output_image'] = result_image
def main():
    """Build the Virtual Staging page.

    Until an image has been uploaded, only a prompt to upload is shown.
    After that, the prompt inputs are rendered above two columns: the editing
    canvas on the left and the output image on the right.
    """
    st.write("## Virtual Staging")
    _, generation_mode, brush, color_chooser, paint_mode = make_sidebar()
    if st.session_state.get('initial_image') is None:
        st.write("Please upload an image to begin")
        return

    make_prompt_row()
    _reset_state = check_reset_state()
    canvas_column, output_column = st.columns(2)
    # The canvas column is rendered first, so a freshly generated result is
    # already in session state when the output column is drawn.
    with canvas_column:
        make_editing_canvas(
            canvas_color=color_chooser,
            brush=brush,
            _reset_state=_reset_state,
            generation_mode=generation_mode,
            paint_mode=paint_mode,
        )
    with output_column:
        make_output_image()
if __name__ == "__main__":
    # Build the page when this file is executed as the main script.
    main()