File size: 6,637 Bytes
58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc |
|
import streamlit as st
# wide layout
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from PIL import Image
from typing import Union
import numpy as np
from models import make_image_controlnet, make_inpainting
from preprocessing import preprocess_seg_mask, get_image, get_mask
from explanation import make_inpainting_explanation, make_regeneration_explanation, make_segmentation_explanation
def on_upload() -> None:
    """Load the newly uploaded file into session state as the canvas background.

    Used as the ``on_change`` callback of the sidebar file uploader (key
    ``'input_image'``). When a file is present it is decoded with PIL,
    converted to RGB and stored under ``'initial_image'``. State derived from
    the previous image (``'seg'``, ``'unique_colors'``, ``'output_image'``) is
    discarded. The canvas is also flagged for reset, so strokes drawn over the
    previous image are not reused as the mask for the new one.
    """
    uploaded = st.session_state.get('input_image')
    if uploaded is None:
        # Uploader was cleared; keep whatever image is currently loaded.
        return
    try:
        image = Image.open(uploaded).convert('RGB')
    except OSError as err:
        # PIL raises UnidentifiedImageError (an OSError subclass) for files
        # it cannot decode; report it instead of crashing the callback.
        st.error(f"Failed to load uploaded image: {err}")
        return
    st.session_state['initial_image'] = image
    # Same one-shot flag move_image() sets; consumed by check_reset_state().
    st.session_state['reset_canvas'] = True
    for key in ('seg', 'unique_colors', 'output_image'):
        st.session_state.pop(key, None)
def check_reset_state() -> bool:
    """Consume the one-shot ``'reset_canvas'`` flag from session state.

    Returns:
        True if the flag was set (it is cleared as a side effect), otherwise
        False.
    """
    if not st.session_state.get('reset_canvas', False):
        return False
    st.session_state['reset_canvas'] = False
    return True
def move_image(source: Union[str, Image.Image], dest: str, rerun: bool = True, remove_state: bool = True) -> None:
    """Copy an image into session state under ``dest``.

    Args:
        source: Either a PIL image, or the session-state key whose value is
            copied. NumPy arrays found under that key are converted to PIL.
        dest: Session-state key that receives the image. Nothing is written
            if the source resolves to ``None``.
        rerun: When True, trigger a Streamlit rerun afterwards.
        remove_state: When True, flag the canvas for reset and drop the
            derived ``'seg'`` / ``'unique_colors'`` state.
    """
    source_image = source if isinstance(source, Image.Image) else st.session_state.get(source)
    # make_output_image() accepts NumPy outputs, so normalise them here.
    # Truth-testing an ndarray (as `if source_image:` did) raises ValueError,
    # and a non-PIL value in 'initial_image' would not be used as background.
    if isinstance(source_image, np.ndarray):
        source_image = Image.fromarray(source_image)
    if remove_state:
        st.session_state['reset_canvas'] = True
        st.session_state.pop('seg', None)
        st.session_state.pop('unique_colors', None)
    if source_image is not None:
        st.session_state[dest] = source_image
    if rerun:
        # st.experimental_rerun was deprecated in favour of st.rerun and is
        # gone in later Streamlit releases; prefer the new API when present.
        rerun_app = getattr(st, 'rerun', None) or st.experimental_rerun
        rerun_app()
def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Build the keyword arguments passed to ``st_canvas``.

    The background comes from ``st.session_state['initial_image']``. It may be
    a PIL image, a NumPy array, or a file path. Anything that cannot be turned
    into a PIL image results in no background image; an error is shown when
    loading fails.

    Args:
        canvas_color: Colour used for both fill and stroke.
        brush: Stroke width.
        paint_mode: ``st_canvas`` drawing mode (e.g. ``'freedraw'``).
        _reset_state: When True, pass an empty initial drawing so existing
            strokes are cleared; otherwise pass ``None``.

    Returns:
        dict of keyword arguments for ``st_canvas``.
    """
    background_image = st.session_state.get('initial_image', None)
    if isinstance(background_image, str):  # Convert if it's a file path
        try:
            background_image = Image.open(background_image)
        except (OSError, ValueError) as e:
            # FileNotFoundError / UnidentifiedImageError are OSError subclasses.
            st.error(f"Failed to load background image: {e}")
            background_image = None
    elif isinstance(background_image, np.ndarray):
        # Generated outputs may be stored as arrays; st_canvas needs PIL here.
        try:
            background_image = Image.fromarray(background_image)
        except (TypeError, ValueError) as e:
            st.error(f"Failed to load background image: {e}")
            background_image = None
    canvas_dict = {
        'fill_color': canvas_color,
        'stroke_color': canvas_color,
        'background_color': "#FFFFFF",
        'background_image': background_image if isinstance(background_image, Image.Image) else None,
        'stroke_width': brush,
        # An explicit empty drawing clears the canvas; None keeps its content.
        'initial_drawing': {'version': '4.4.0', 'objects': []} if _reset_state else None,
        'update_streamlit': True,
        'height': 512,
        'width': 512,
        'drawing_mode': paint_mode,
        'key': "canvas",
    }
    return canvas_dict
def make_prompt_row():
    """Render the positive and negative prompt text inputs side by side.

    The entered values live in session state under ``'positive_prompt'``
    and ``'negative_prompt'``.
    """
    prompt_fields = (
        ("Positive prompt",
         "a realistic photograph of a room, high resolution",
         'positive_prompt'),
        ("Negative prompt",
         "watermark, banner, logo, contact info, text, deformed, blurry, lowres",
         'negative_prompt'),
    )
    for column, (label, default_text, state_key) in zip(st.columns(2), prompt_fields):
        with column:
            st.text_input(label=label, value=default_text, key=state_key)
def make_sidebar():
    """Render the sidebar uploader and return the editing settings.

    Returns:
        Tuple ``(input_image, generation_mode, brush, color_chooser,
        paint_mode)``. ``input_image`` is the file-uploader value; the other
        entries are currently fixed values.
    """
    with st.sidebar:
        uploaded_file = st.file_uploader(
            "Upload Image",
            type=["png", "jpg"],
            key='input_image',
            on_change=on_upload,
        )
    mode = 'Inpainting'
    drawing_mode = 'freedraw'
    # Freehand strokes use a wide brush; any other drawing mode a thin one.
    stroke_width = {"freedraw": 30}.get(drawing_mode, 5)
    colour = "#000000"
    return uploaded_file, mode, stroke_width, colour, drawing_mode
def make_output_image():
    """Show the latest generated image, or a blank 512x512 white placeholder.

    Also renders a "Move to input image" button that copies
    ``'output_image'`` into ``'initial_image'`` via move_image().
    """
    result = st.session_state.get('output_image', None)
    if result is None:
        result = Image.new('RGB', (512, 512), (255, 255, 255))
    elif isinstance(result, np.ndarray):
        result = Image.fromarray(result)
    st.write("#### Output image")
    st.image(result, width=512)
    if st.button("Move to input image"):
        move_image('output_image', 'initial_image', remove_state=True, rerun=True)
def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the drawable input canvas and, in inpainting mode, run generation.

    In "Segmentation" mode only the canvas is drawn. In "Inpainting" mode the
    canvas is drawn over the image from get_image(). When the "Generate
    images" button is pressed, the canvas drawing is turned into a mask with
    get_mask(), and the make_inpainting() result is stored in
    ``st.session_state['output_image']`` (as a PIL image if it came back as
    an ndarray). Errors are reported with st.error() and abort early.
    """
    st.write("#### Input image")
    settings = make_canvas_dict(
        canvas_color=canvas_color,
        paint_mode=paint_mode,
        brush=brush,
        _reset_state=_reset_state,
    )
    if generation_mode == "Segmentation":
        st_canvas(**settings)
        return
    if generation_mode != "Inpainting":
        return

    source = get_image()
    if source is None:
        st.error("Error: Could not load image for inpainting.")
        return
    drawing = st_canvas(**settings)
    if not st.button("Generate images", key='generate_button'):
        return

    mask_data = getattr(drawing, 'image_data', None)
    if mask_data is None:
        st.error("Error: No mask data found on canvas.")
        return
    if not isinstance(mask_data, np.ndarray):
        mask_data = np.array(mask_data)
    mask = get_mask(mask_data)

    with st.spinner(text="Generating new images..."):
        generated = make_inpainting(
            positive_prompt=st.session_state['positive_prompt'],
            image=Image.fromarray(source),
            mask_image=mask,
            negative_prompt=st.session_state['negative_prompt'],
        )
    if isinstance(generated, np.ndarray):
        generated = Image.fromarray(generated)
    st.session_state['output_image'] = generated
def main():
    """Render the Virtual Staging page.

    Shows an upload prompt until ``'initial_image'`` is available. After
    that it shows the prompt row, the editing canvas (left column) and the
    output image (right column).
    """
    st.write("## Virtual Staging")
    _, generation_mode, brush, color_chooser, paint_mode = make_sidebar()
    if st.session_state.get('initial_image') is None:
        st.write("Please upload an image to begin")
        return

    make_prompt_row()
    reset_requested = check_reset_state()
    left_column, right_column = st.columns(2)
    with left_column:
        make_editing_canvas(
            canvas_color=color_chooser,
            brush=brush,
            _reset_state=reset_requested,
            generation_mode=generation_mode,
            paint_mode=paint_mode,
        )
    with right_column:
        make_output_image()
# Script entry point: build the page when this module is run as __main__.
if __name__ == "__main__":
    main()
|