# Abhilasha: fix (commit 782218d)
import streamlit as st
# Use the wide page layout. This call is placed above the remaining imports,
# presumably because Streamlit has required set_page_config to be the first
# Streamlit command run on a page -- confirm before moving it.
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from PIL import Image
from typing import Union
import numpy as np
from models import make_image_controlnet, make_inpainting
from preprocessing import preprocess_seg_mask, get_image, get_mask
from explanation import make_inpainting_explanation, make_regeneration_explanation, make_segmentation_explanation
def on_upload() -> None:
    """Store a freshly uploaded file as the working image.

    Registered as the ``on_change`` callback of the sidebar uploader. When
    the ``input_image`` session entry holds a file, it is opened, converted
    to RGB and saved under ``initial_image``; the ``seg``, ``unique_colors``
    and ``output_image`` entries are then dropped. Nothing happens when no
    file is set.
    """
    uploaded_file = st.session_state.get('input_image')
    if uploaded_file is None:
        return

    st.session_state['initial_image'] = Image.open(uploaded_file).convert('RGB')

    # Entries produced for the previous image are discarded.
    for stale_key in ('seg', 'unique_colors', 'output_image'):
        st.session_state.pop(stale_key, None)
def check_reset_state() -> bool:
    """Report whether a canvas reset is pending and clear the request.

    Returns:
        True when the ``reset_canvas`` session flag was set (the flag is
        switched back to False before returning), otherwise False.
    """
    reset_requested = bool(st.session_state.get('reset_canvas', False))
    if reset_requested:
        st.session_state['reset_canvas'] = False
    return reset_requested
def move_image(source: Union[str, Image.Image], dest: str, rerun: bool = True, remove_state: bool = True) -> None:
    """Copy an image into the session state under ``dest``.

    Args:
        source: Either the image itself or the session-state key under
            which the image is stored.
        dest: Session-state key that receives the image.
        rerun: When True, trigger a Streamlit rerun afterwards.
        remove_state: When True, set the ``reset_canvas`` flag and drop the
            ``seg`` and ``unique_colors`` session entries.
    """
    source_image = source if isinstance(source, Image.Image) else st.session_state.get(source)

    if remove_state:
        st.session_state['reset_canvas'] = True
        for stale_key in ('seg', 'unique_colors'):
            st.session_state.pop(stale_key, None)

    # Compare against None rather than relying on truthiness: a NumPy array
    # held in the session state would raise ValueError when evaluated as a
    # boolean.
    if source_image is not None:
        st.session_state[dest] = source_image

    if rerun:
        # st.experimental_rerun is the deprecated name of st.rerun; prefer
        # the current API and fall back for older Streamlit releases.
        trigger_rerun = getattr(st, 'rerun', None) or st.experimental_rerun
        trigger_rerun()
def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Assemble the keyword arguments for the drawing canvas.

    Args:
        canvas_color: Colour used for both the stroke and the fill.
        brush: Stroke width.
        paint_mode: Drawing mode handed to the canvas.
        _reset_state: When truthy, an empty initial drawing is supplied;
            otherwise ``initial_drawing`` is None.

    Returns:
        Dict of canvas settings for a 512x512 canvas keyed ``"canvas"``.
        The background is the ``initial_image`` session entry when it is
        (or can be loaded as) a PIL image, else None.
    """
    backdrop = st.session_state.get('initial_image', None)
    if isinstance(backdrop, str):
        # A string is treated as a file path and opened as an image.
        try:
            backdrop = Image.open(backdrop)
        except Exception as e:
            st.error(f"Failed to load background image: {e}")
            backdrop = None
    if not isinstance(backdrop, Image.Image):
        backdrop = None

    if _reset_state:
        starting_drawing = {'version': '4.4.0', 'objects': []}
    else:
        starting_drawing = None

    return dict(
        fill_color=canvas_color,
        stroke_color=canvas_color,
        background_color="#FFFFFF",
        background_image=backdrop,
        stroke_width=brush,
        initial_drawing=starting_drawing,
        update_streamlit=True,
        height=512,
        width=512,
        drawing_mode=paint_mode,
        key="canvas",
    )
def make_prompt_row():
    """Render the positive and negative prompt inputs side by side.

    Each text input is stored in the session state under its own key
    (``positive_prompt`` / ``negative_prompt``) and starts with a default
    value.
    """
    # (label, default value, session-state key), one entry per column.
    prompt_fields = (
        ("Positive prompt",
         "a realistic photograph of a room, high resolution",
         'positive_prompt'),
        ("Negative prompt",
         "watermark, banner, logo, contact info, text, deformed, blurry, lowres",
         'negative_prompt'),
    )
    for column, (label, default_text, state_key) in zip(st.columns(2), prompt_fields):
        with column:
            st.text_input(label=label, value=default_text, key=state_key)
def make_sidebar():
    """Render the sidebar uploader and return the editing settings.

    The generation mode, paint mode and colour are fixed values; only the
    uploaded file comes from the user.

    Returns:
        Tuple ``(input_image, generation_mode, brush, color_chooser,
        paint_mode)``.
    """
    generation_mode = 'Inpainting'
    paint_mode = 'freedraw'
    color_chooser = "#000000"

    # Free drawing uses a wide brush; any other mode a thin one.
    if paint_mode == "freedraw":
        brush = 30
    else:
        brush = 5

    with st.sidebar:
        input_image = st.file_uploader(
            "Upload Image",
            type=["png", "jpg"],
            key='input_image',
            on_change=on_upload,
        )

    return input_image, generation_mode, brush, color_chooser, paint_mode
def make_output_image():
    """Show the generated image and a button to reuse it as the input.

    The ``output_image`` session entry is displayed at width 512. A NumPy
    array is converted to a PIL image first, and a white 512x512
    placeholder is shown when no output exists. Pressing the button moves
    the output into ``initial_image`` via ``move_image``.
    """
    result = st.session_state.get('output_image', None)
    if result is None:
        result = Image.new('RGB', (512, 512), (255, 255, 255))
    elif isinstance(result, np.ndarray):
        result = Image.fromarray(result)

    st.write("#### Output image")
    st.image(result, width=512)

    move_clicked = st.button("Move to input image")
    if move_clicked:
        move_image('output_image', 'initial_image', remove_state=True, rerun=True)
def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the input canvas and, for inpainting, run the generation.

    Args:
        canvas_color: Colour used for drawing on the canvas.
        brush: Stroke width.
        _reset_state: Forwarded to ``make_canvas_dict``.
        generation_mode: ``"Segmentation"`` only draws the canvas;
            ``"Inpainting"`` also offers the generate button. Any other
            value renders just the heading.
        paint_mode: Drawing mode of the canvas.

    In inpainting mode, pressing "Generate images" builds a mask from the
    canvas data, calls ``make_inpainting`` with the prompts stored in the
    session state, and saves the result under ``output_image``.
    """
    st.write("#### Input image")
    canvas_settings = make_canvas_dict(
        canvas_color=canvas_color,
        brush=brush,
        paint_mode=paint_mode,
        _reset_state=_reset_state,
    )

    if generation_mode == "Segmentation":
        st_canvas(**canvas_settings)
        return
    if generation_mode != "Inpainting":
        return

    source_pixels = get_image()
    if source_pixels is None:
        st.error("Error: Could not load image for inpainting.")
        return

    canvas = st_canvas(**canvas_settings)
    if not st.button("Generate images", key='generate_button'):
        return

    drawn = getattr(canvas, 'image_data', None)
    if drawn is None:
        st.error("Error: No mask data found on canvas.")
        return
    if not isinstance(drawn, np.ndarray):
        drawn = np.array(drawn)
    mask = get_mask(drawn)

    with st.spinner(text="Generating new images..."):
        generated = make_inpainting(
            positive_prompt=st.session_state['positive_prompt'],
            image=Image.fromarray(source_pixels),
            mask_image=mask,
            negative_prompt=st.session_state['negative_prompt'],
        )

    # The output is always stored as a PIL image.
    if isinstance(generated, np.ndarray):
        generated = Image.fromarray(generated)
    st.session_state['output_image'] = generated
def main():
    """Lay out the Virtual Staging page.

    Renders the title and sidebar. Until an image is present under
    ``initial_image`` in the session state, only an upload hint is shown.
    Otherwise the prompt row is drawn, followed by two columns: the editing
    canvas on the left and the output image on the right.
    """
    st.write("## Virtual Staging")
    _uploaded, generation_mode, brush, color_chooser, paint_mode = make_sidebar()

    if st.session_state.get('initial_image') is None:
        st.write("Please upload an image to begin")
        return

    make_prompt_row()
    reset_pending = check_reset_state()

    canvas_column, output_column = st.columns(2)
    with canvas_column:
        make_editing_canvas(
            canvas_color=color_chooser,
            brush=brush,
            _reset_state=reset_pending,
            generation_mode=generation_mode,
            paint_mode=paint_mode,
        )
    with output_column:
        make_output_image()


if __name__ == "__main__":
    main()