Abhilasha commited on
Commit
782218d
·
1 Parent(s): 4b28da5
Files changed (1) hide show
  1. app.py +66 -61
app.py CHANGED
@@ -5,75 +5,62 @@ st.set_page_config(layout="wide")
5
  from streamlit_drawable_canvas import st_canvas
6
  from PIL import Image
7
  from typing import Union
8
- import random
9
  import numpy as np
10
- import os
11
- import time
12
 
13
  from models import make_image_controlnet, make_inpainting
14
  from preprocessing import preprocess_seg_mask, get_image, get_mask
15
  from explanation import make_inpainting_explanation, make_regeneration_explanation, make_segmentation_explanation
16
 
17
 
18
-
19
  def on_upload() -> None:
20
  """Upload image to the canvas."""
21
  if 'input_image' in st.session_state and st.session_state['input_image'] is not None:
22
  image = Image.open(st.session_state['input_image']).convert('RGB')
23
  st.session_state['initial_image'] = image
24
- if 'seg' in st.session_state:
25
- del st.session_state['seg']
26
- if 'unique_colors' in st.session_state:
27
- del st.session_state['unique_colors']
28
- if 'output_image' in st.session_state:
29
- del st.session_state['output_image']
30
-
31
- def make_image_row(image_0, image_1):
32
- col_0, col_1 = st.columns(2)
33
- with col_0:
34
- st.image(image_0, use_column_width=True)
35
- with col_1:
36
- st.image(image_1, use_column_width=True)
37
 
38
  def check_reset_state() -> bool:
39
  """Check whether the UI elements need to be reset"""
40
- if ('reset_canvas' in st.session_state and st.session_state['reset_canvas']):
41
  st.session_state['reset_canvas'] = False
42
  return True
43
- st.session_state['reset_canvas'] = False
44
  return False
45
 
46
 
47
- def move_image(source: Union[str, Image.Image],
48
- dest: str,
49
- rerun: bool = True,
50
- remove_state: bool = True) -> None:
51
  """Move image from source to destination."""
52
- source_image = source if isinstance(source, Image.Image) else st.session_state[source]
53
 
54
  if remove_state:
55
  st.session_state['reset_canvas'] = True
56
- if 'seg' in st.session_state:
57
- del st.session_state['seg']
58
- if 'unique_colors' in st.session_state:
59
- del st.session_state['unique_colors']
60
-
61
- st.session_state[dest] = source_image
62
- if rerun:
63
- st.experimental_rerun()
64
 
65
-
66
- def on_change_radio() -> None:
67
- """Reset the UI elements when the radio button is changed."""
68
- st.session_state['reset_canvas'] = True
69
 
70
 
71
  def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
 
 
 
 
 
 
 
 
 
 
72
  canvas_dict = {
73
  'fill_color': canvas_color,
74
  'stroke_color': canvas_color,
75
  'background_color': "#FFFFFF",
76
- 'background_image': st.session_state['initial_image'] if 'initial_image' in st.session_state else None,
77
  'stroke_width': brush,
78
  'initial_drawing': {'version': '4.4.0', 'objects': []} if _reset_state else None,
79
  'update_streamlit': True,
@@ -86,22 +73,25 @@ def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
86
 
87
 
88
  def make_prompt_row():
 
89
  col_0_0, col_0_1 = st.columns(2)
90
  with col_0_0:
91
  st.text_input(label="Positive prompt",
92
- value="a realistic photograph of a room, high resolution",
93
- key='positive_prompt')
94
  with col_0_1:
95
  st.text_input(label="Negative prompt",
96
- value="watermark, banner, logo, watermark, contactinfo, text, deformed, blurry, blur, lowres",
97
- key='negative_prompt')
 
98
 
99
  def make_sidebar():
 
100
  with st.sidebar:
101
  input_image = st.file_uploader("Upload Image",
102
- type=["png", "jpg"],
103
- key='input_image',
104
- on_change=on_upload)
105
  generation_mode = 'Inpainting'
106
  paint_mode = 'freedraw'
107
  brush = 30 if paint_mode == "freedraw" else 5
@@ -109,23 +99,25 @@ def make_sidebar():
109
  return input_image, generation_mode, brush, color_chooser, paint_mode
110
 
111
 
112
-
113
  def make_output_image():
114
- if 'output_image' in st.session_state:
115
- output_image = st.session_state['output_image']
116
- if isinstance(output_image, np.ndarray):
117
- output_image = Image.fromarray(output_image)
118
- if isinstance(output_image, Image.Image):
119
- output_image = output_image.resize((512, 512))
120
- else:
121
  output_image = Image.new('RGB', (512, 512), (255, 255, 255))
122
 
123
  st.write("#### Output image")
124
  st.image(output_image, width=512)
 
125
  if st.button("Move to input image"):
126
  move_image('output_image', 'initial_image', remove_state=True, rerun=True)
127
 
 
128
  def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
 
129
  st.write("#### Input image")
130
  canvas_dict = make_canvas_dict(
131
  canvas_color=canvas_color,
@@ -135,35 +127,49 @@ def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, pain
135
  )
136
 
137
  if generation_mode == "Segmentation":
138
- canvas = st_canvas(**canvas_dict)
 
139
  elif generation_mode == "Inpainting":
140
- image = get_image() # Assuming this function exists in your preprocessing module
 
 
 
 
141
  canvas = st_canvas(**canvas_dict)
142
 
143
  if st.button("Generate images", key='generate_button'):
144
- canvas_mask = canvas.image_data
 
 
 
 
 
145
  if not isinstance(canvas_mask, np.ndarray):
146
  canvas_mask = np.array(canvas_mask)
147
- mask = get_mask(canvas_mask) # Assuming this function exists in your preprocessing module
148
 
149
- with st.spinner(text="Generating new images"):
150
- print("Making image")
 
151
  result_image = make_inpainting(
152
  positive_prompt=st.session_state['positive_prompt'],
153
  image=Image.fromarray(image),
154
  mask_image=mask,
155
  negative_prompt=st.session_state['negative_prompt'],
156
  )
 
157
  if isinstance(result_image, np.ndarray):
158
  result_image = Image.fromarray(result_image)
 
159
  st.session_state['output_image'] = result_image
160
 
 
161
  def main():
 
162
  st.write("## Virtual Staging")
163
 
164
  input_image, generation_mode, brush, color_chooser, paint_mode = make_sidebar()
165
 
166
- if not ('initial_image' in st.session_state and st.session_state['initial_image'] is not None):
167
  st.write("Please upload an image to begin")
168
  else:
169
  make_prompt_row()
@@ -182,7 +188,6 @@ def main():
182
  with col2:
183
  make_output_image()
184
 
 
185
  if __name__ == "__main__":
186
  main()
187
-
188
-
 
5
  from streamlit_drawable_canvas import st_canvas
6
  from PIL import Image
7
  from typing import Union
 
8
  import numpy as np
 
 
9
 
10
  from models import make_image_controlnet, make_inpainting
11
  from preprocessing import preprocess_seg_mask, get_image, get_mask
12
  from explanation import make_inpainting_explanation, make_regeneration_explanation, make_segmentation_explanation
13
 
14
 
 
15
  def on_upload() -> None:
16
  """Upload image to the canvas."""
17
  if 'input_image' in st.session_state and st.session_state['input_image'] is not None:
18
  image = Image.open(st.session_state['input_image']).convert('RGB')
19
  st.session_state['initial_image'] = image
20
+ st.session_state.pop('seg', None)
21
+ st.session_state.pop('unique_colors', None)
22
+ st.session_state.pop('output_image', None)
23
+
 
 
 
 
 
 
 
 
 
24
 
25
  def check_reset_state() -> bool:
26
  """Check whether the UI elements need to be reset"""
27
+ if st.session_state.get('reset_canvas', False):
28
  st.session_state['reset_canvas'] = False
29
  return True
 
30
  return False
31
 
32
 
33
+ def move_image(source: Union[str, Image.Image], dest: str, rerun: bool = True, remove_state: bool = True) -> None:
 
 
 
34
  """Move image from source to destination."""
35
+ source_image = source if isinstance(source, Image.Image) else st.session_state.get(source)
36
 
37
  if remove_state:
38
  st.session_state['reset_canvas'] = True
39
+ st.session_state.pop('seg', None)
40
+ st.session_state.pop('unique_colors', None)
 
 
 
 
 
 
41
 
42
+ if source_image:
43
+ st.session_state[dest] = source_image
44
+ if rerun:
45
+ st.experimental_rerun()
46
 
47
 
48
  def make_canvas_dict(canvas_color, brush, paint_mode, _reset_state):
49
+ """Create a dictionary for the canvas settings."""
50
+ background_image = st.session_state.get('initial_image', None)
51
+
52
+ if isinstance(background_image, str): # Convert if it's a file path
53
+ try:
54
+ background_image = Image.open(background_image)
55
+ except Exception as e:
56
+ st.error(f"Failed to load background image: {e}")
57
+ background_image = None
58
+
59
  canvas_dict = {
60
  'fill_color': canvas_color,
61
  'stroke_color': canvas_color,
62
  'background_color': "#FFFFFF",
63
+ 'background_image': background_image if isinstance(background_image, Image.Image) else None,
64
  'stroke_width': brush,
65
  'initial_drawing': {'version': '4.4.0', 'objects': []} if _reset_state else None,
66
  'update_streamlit': True,
 
73
 
74
 
75
  def make_prompt_row():
76
+ """Create input fields for positive and negative prompts."""
77
  col_0_0, col_0_1 = st.columns(2)
78
  with col_0_0:
79
  st.text_input(label="Positive prompt",
80
+ value="a realistic photograph of a room, high resolution",
81
+ key='positive_prompt')
82
  with col_0_1:
83
  st.text_input(label="Negative prompt",
84
+ value="watermark, banner, logo, contact info, text, deformed, blurry, lowres",
85
+ key='negative_prompt')
86
+
87
 
88
  def make_sidebar():
89
+ """Create the sidebar with upload options and settings."""
90
  with st.sidebar:
91
  input_image = st.file_uploader("Upload Image",
92
+ type=["png", "jpg"],
93
+ key='input_image',
94
+ on_change=on_upload)
95
  generation_mode = 'Inpainting'
96
  paint_mode = 'freedraw'
97
  brush = 30 if paint_mode == "freedraw" else 5
 
99
  return input_image, generation_mode, brush, color_chooser, paint_mode
100
 
101
 
 
102
  def make_output_image():
103
+ """Display the output image."""
104
+ output_image = st.session_state.get('output_image', None)
105
+
106
+ if isinstance(output_image, np.ndarray):
107
+ output_image = Image.fromarray(output_image)
108
+
109
+ if output_image is None:
110
  output_image = Image.new('RGB', (512, 512), (255, 255, 255))
111
 
112
  st.write("#### Output image")
113
  st.image(output_image, width=512)
114
+
115
  if st.button("Move to input image"):
116
  move_image('output_image', 'initial_image', remove_state=True, rerun=True)
117
 
118
+
119
  def make_editing_canvas(canvas_color, brush, _reset_state, generation_mode, paint_mode):
120
+ """Create an editable drawing canvas."""
121
  st.write("#### Input image")
122
  canvas_dict = make_canvas_dict(
123
  canvas_color=canvas_color,
 
127
  )
128
 
129
  if generation_mode == "Segmentation":
130
+ st_canvas(**canvas_dict)
131
+
132
  elif generation_mode == "Inpainting":
133
+ image = get_image()
134
+ if image is None:
135
+ st.error("Error: Could not load image for inpainting.")
136
+ return
137
+
138
  canvas = st_canvas(**canvas_dict)
139
 
140
  if st.button("Generate images", key='generate_button'):
141
+ canvas_mask = getattr(canvas, 'image_data', None)
142
+
143
+ if canvas_mask is None:
144
+ st.error("Error: No mask data found on canvas.")
145
+ return
146
+
147
  if not isinstance(canvas_mask, np.ndarray):
148
  canvas_mask = np.array(canvas_mask)
 
149
 
150
+ mask = get_mask(canvas_mask)
151
+
152
+ with st.spinner(text="Generating new images..."):
153
  result_image = make_inpainting(
154
  positive_prompt=st.session_state['positive_prompt'],
155
  image=Image.fromarray(image),
156
  mask_image=mask,
157
  negative_prompt=st.session_state['negative_prompt'],
158
  )
159
+
160
  if isinstance(result_image, np.ndarray):
161
  result_image = Image.fromarray(result_image)
162
+
163
  st.session_state['output_image'] = result_image
164
 
165
+
166
  def main():
167
+ """Main Streamlit app function."""
168
  st.write("## Virtual Staging")
169
 
170
  input_image, generation_mode, brush, color_chooser, paint_mode = make_sidebar()
171
 
172
+ if 'initial_image' not in st.session_state or st.session_state['initial_image'] is None:
173
  st.write("Please upload an image to begin")
174
  else:
175
  make_prompt_row()
 
188
  with col2:
189
  make_output_image()
190
 
191
+
192
  if __name__ == "__main__":
193
  main()