File size: 6,637 Bytes
58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc 782218d 58d0adc |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 |
import streamlit as st
# Use Streamlit's wide page layout for the two-column (input | output) view.
# NOTE(review): Streamlit documents set_page_config as needing to run before other
# st.* commands — presumably why it sits among the imports; keep it first.
st.set_page_config(layout="wide")
from streamlit_drawable_canvas import st_canvas
from PIL import Image
from typing import Union
import numpy as np
from models import make_image_controlnet, make_inpainting
from preprocessing import preprocess_seg_mask, get_image, get_mask
from explanation import make_inpainting_explanation, make_regeneration_explanation, make_segmentation_explanation
def on_upload() -> None:
    """Load a newly uploaded file into session state as the canvas background.

    Registered as the ``on_change`` callback of the sidebar file uploader
    (session key ``'input_image'``). When a file is present it is decoded to
    RGB and stored under ``'initial_image'``. State derived from the previous
    image (``'seg'``, ``'unique_colors'``, ``'output_image'``) is discarded, and
    the canvas is flagged for reset via ``'reset_canvas'`` (the same flag
    ``move_image`` sets when it swaps the input image).

    If the file cannot be decoded, an error is shown and the previous state is
    left untouched.
    """
    uploaded = st.session_state.get('input_image')
    if uploaded is None:
        # The uploader was cleared; keep whatever image is already loaded.
        return
    try:
        image = Image.open(uploaded).convert('RGB')
    except OSError as err:
        # Pillow signals undecodable/truncated files with OSError subclasses
        # (e.g. UnidentifiedImageError); report instead of crashing the callback.
        st.error(f"Could not read the uploaded image: {err}")
        return
    st.session_state['initial_image'] = image
    # Strokes drawn on the previous image must not become the mask for the new one.
    st.session_state['reset_canvas'] = True
    for stale_key in ('seg', 'unique_colors', 'output_image'):
        st.session_state.pop(stale_key, None)
def check_reset_state() -> bool:
    """Consume the one-shot ``'reset_canvas'`` flag from session state.

    Returns:
        True if the flag was set (it is cleared back to False as a side
        effect); False otherwise, in which case session state is not written.
    """
    flag_raised = bool(st.session_state.get('reset_canvas', False))
    if flag_raised:
        st.session_state['reset_canvas'] = False
    return flag_raised
def move_image(source: Union[str, Image.Image], dest: str, rerun: bool = True, remove_state: bool = True) -> None:
    """Copy an image into session state under ``dest``.

    Args:
        source: Either the image itself (PIL image or numpy array) or the
            session-state key under which the image is stored.
        dest: Session-state key to store the image under.
        rerun: When True, trigger a script rerun afterwards so the UI
            reflects the new state.
        remove_state: When True, flag the canvas for reset and drop the
            derived ``'seg'`` / ``'unique_colors'`` entries.

    Nothing is stored when the source resolves to None (e.g. a missing key).
    Numpy arrays are converted to PIL images, matching how the rest of the
    app stores images.
    """
    source_image = st.session_state.get(source) if isinstance(source, str) else source
    if remove_state:
        st.session_state['reset_canvas'] = True
        st.session_state.pop('seg', None)
        st.session_state.pop('unique_colors', None)
    # Explicit None test: truth-testing a numpy array raises ValueError.
    if source_image is not None:
        if isinstance(source_image, np.ndarray):
            source_image = Image.fromarray(source_image)
        st.session_state[dest] = source_image
    if rerun:
        # st.rerun superseded st.experimental_rerun; fall back on older Streamlit.
        rerun_app = getattr(st, 'rerun', None) or st.experimental_rerun
        rerun_app()
def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
    """Assemble the keyword arguments passed to ``st_canvas``.

    Args:
        canvas_color: Color used for both the fill and the stroke.
        brush: Stroke width.
        paint_mode: Forwarded as ``drawing_mode``.
        _reset_state: When truthy, ``initial_drawing`` is an empty drawing
            (``'objects': []``); otherwise it is None.

    Returns:
        dict of ``st_canvas`` keyword arguments for a 512x512 canvas with a
        white background and key ``"canvas"``. The background image is taken
        from ``st.session_state['initial_image']`` (a string there is treated
        as a file path and opened); anything that is not a PIL image becomes None.
    """
    bg_image = st.session_state.get('initial_image', None)
    if isinstance(bg_image, str):
        try:
            bg_image = Image.open(bg_image)
        except Exception as e:
            st.error(f"Failed to load background image: {e}")
            bg_image = None
    if not isinstance(bg_image, Image.Image):
        bg_image = None

    if _reset_state:
        starting_drawing = {'version': '4.4.0', 'objects': []}
    else:
        starting_drawing = None

    return dict(
        fill_color=canvas_color,
        stroke_color=canvas_color,
        background_color="#FFFFFF",
        background_image=bg_image,
        stroke_width=brush,
        initial_drawing=starting_drawing,
        update_streamlit=True,
        height=512,
        width=512,
        drawing_mode=paint_mode,
        key="canvas",
    )
def make_prompt_row():
    """Render side-by-side text inputs for the positive and negative prompts.

    The values are stored in session state under ``'positive_prompt'`` and
    ``'negative_prompt'``.
    """
    prompt_fields = (
        ("Positive prompt",
         "a realistic photograph of a room, high resolution",
         'positive_prompt'),
        ("Negative prompt",
         "watermark, banner, logo, contact info, text, deformed, blurry, lowres",
         'negative_prompt'),
    )
    for column, (label, default_text, state_key) in zip(st.columns(2), prompt_fields):
        with column:
            st.text_input(label=label, value=default_text, key=state_key)
def make_sidebar():
    """Render the sidebar image uploader and return the editing settings.

    Returns:
        Tuple ``(input_image, generation_mode, brush, color_chooser, paint_mode)``:
        the uploader's current value, followed by the fixed settings used by
        the canvas ('Inpainting', brush width 30 for 'freedraw', black color,
        'freedraw' mode).
    """
    with st.sidebar:
        uploaded_file = st.file_uploader(
            "Upload Image",
            type=["png", "jpg"],
            key='input_image',
            on_change=on_upload,
        )
    mode_of_generation = 'Inpainting'
    drawing_mode = 'freedraw'
    # Freehand drawing gets a thick brush; any other mode a thinner one.
    if drawing_mode == "freedraw":
        stroke_width = 30
    else:
        stroke_width = 5
    stroke_colour = "#000000"
    return uploaded_file, mode_of_generation, stroke_width, stroke_colour, drawing_mode
def make_output_image():
    """Show the latest generated image and offer to reuse it as the input.

    Reads ``st.session_state['output_image']`` (a PIL image or a numpy array,
    which is converted). When no output exists yet, a blank white 512x512
    placeholder is displayed instead.

    The "Move to input image" button copies the real output into
    ``'initial_image'`` via ``move_image`` (with state removal and a rerun).
    """
    output_image = st.session_state.get('output_image', None)
    if isinstance(output_image, np.ndarray):
        output_image = Image.fromarray(output_image)
    has_output = output_image is not None
    shown_image = output_image if has_output else Image.new('RGB', (512, 512), (255, 255, 255))

    st.write("#### Output image")
    st.image(shown_image, width=512)
    if st.button("Move to input image"):
        if has_output:
            # Pass the converted PIL image rather than the session key, so the
            # input canvas (which only accepts PIL images) gets a usable image
            # even when the output was stored as an array.
            move_image(output_image, 'initial_image', remove_state=True, rerun=True)
        else:
            # Without this guard the click would still flag the canvas for
            # reset, discarding the user's drawing for nothing.
            st.warning("There is no output image to move yet.")
def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
    """Render the drawable input canvas and, in inpainting mode, the generate button.

    Args:
        canvas_color: Fill/stroke color for the canvas.
        brush: Stroke width.
        _reset_state: Forwarded to ``make_canvas_dict`` to start from an empty drawing.
        generation_mode: "Segmentation" shows the canvas only; "Inpainting"
            also offers a "Generate images" button. Other values render nothing
            beyond the heading.
        paint_mode: Canvas drawing mode.
    """
    st.write("#### Input image")
    canvas_dict = make_canvas_dict(
        canvas_color=canvas_color,
        paint_mode=paint_mode,
        brush=brush,
        _reset_state=_reset_state,
    )
    if generation_mode == "Segmentation":
        st_canvas(**canvas_dict)
    elif generation_mode == "Inpainting":
        image = get_image()
        if image is None:
            st.error("Error: Could not load image for inpainting.")
            return
        canvas = st_canvas(**canvas_dict)
        if st.button("Generate images", key='generate_button'):
            _run_inpainting(canvas, image)


def _run_inpainting(canvas, image):
    """Build a mask from the canvas drawing, run inpainting, and store the result.

    The result is written to ``st.session_state['output_image']`` as a PIL image
    (numpy results are converted). Errors about missing canvas data are shown
    with ``st.error`` and abort the run.
    """
    canvas_mask = getattr(canvas, 'image_data', None)
    if canvas_mask is None:
        st.error("Error: No mask data found on canvas.")
        return
    if not isinstance(canvas_mask, np.ndarray):
        canvas_mask = np.array(canvas_mask)
    mask = get_mask(canvas_mask)
    # get_image() is opaque from here; the original code fed its result to
    # Image.fromarray, so it is presumably an array — accept a PIL image too.
    base_image = image if isinstance(image, Image.Image) else Image.fromarray(image)
    with st.spinner(text="Generating new images..."):
        result_image = make_inpainting(
            positive_prompt=st.session_state['positive_prompt'],
            image=base_image,
            mask_image=mask,
            negative_prompt=st.session_state['negative_prompt'],
        )
    if isinstance(result_image, np.ndarray):
        result_image = Image.fromarray(result_image)
    st.session_state['output_image'] = result_image
def main():
    """Render the Virtual Staging page.

    Shows an upload prompt until ``'initial_image'`` is in session state;
    afterwards renders the prompt inputs, the editing canvas (left column)
    and the output image (right column).
    """
    st.write("## Virtual Staging")
    _uploaded, generation_mode, brush, color_chooser, paint_mode = make_sidebar()
    if st.session_state.get('initial_image') is None:
        st.write("Please upload an image to begin")
        return

    make_prompt_row()
    reset_requested = check_reset_state()
    canvas_col, output_col = st.columns(2)
    with canvas_col:
        make_editing_canvas(
            canvas_color=color_chooser,
            brush=brush,
            _reset_state=reset_requested,
            generation_mode=generation_mode,
            paint_mode=paint_mode,
        )
    with output_col:
        make_output_image()


if __name__ == "__main__":
    main()
|