| from typing import Optional | |
| import numpy as np | |
| import cv2 | |
| import streamlit as st | |
| from PIL import Image | |
| from sdfile import PIPELINES, generate | |
# Text pre-filled into every "Prompt" text area of the generation form.
DEFAULT_PROMPT = "belted shirt black belted portrait-collar wrap blouse with black prints"
# Default image width/height. NOTE(review): the size sliders currently hard-code
# 512 rather than reading these constants.
DEFAULT_WIDTH, DEFAULT_HEIGHT = 512, 512
# Backward-compatible alias for the originally misspelled constant name.
DEAFULT_WIDTH = DEFAULT_WIDTH
# st.session_state keys under which images are cached between reruns.
OUTPUT_IMAGE_KEY = "output_img"
# NOTE(review): read by image_uploader's fallback but never written in this file.
LOADED_IMAGE_KEY = "loaded_img"
def get_image(key: str) -> Optional[Image.Image]:
    """Look up an image cached in Streamlit session state.

    Args:
        key: Session-state key the image was stored under.

    Returns:
        The stored value, or ``None`` if nothing has been stored under *key*.
    """
    return st.session_state[key] if key in st.session_state else None
def set_image(key: str, img: Image.Image):
    """Store *img* in Streamlit session state under *key*, overwriting any prior value."""
    session = st.session_state
    session[key] = img
def prompt_and_generate_button(prefix, pipeline_name: PIPELINES, **kwargs):
    """Render the shared generation form and run the pipeline on demand.

    The form consists of a prompt, a negative prompt, inference-step and
    guidance-scale sliders, and a CPU-offload checkbox. Every widget key is
    namespaced by *prefix*, so several tabs can each host their own copy.
    When "Generate Image" is clicked, the collected settings plus any extra
    ``**kwargs`` are forwarded to ``generate``. The resulting image is cached
    under ``OUTPUT_IMAGE_KEY`` and displayed.

    Args:
        prefix: Namespace used for all Streamlit widget keys in this form.
        pipeline_name: Pipeline identifier passed straight to ``generate``.
        **kwargs: Pipeline-specific arguments forwarded unchanged to ``generate``.
    """
    def widget_key(suffix):
        return f"{prefix}-{suffix}"

    prompt = st.text_area("Prompt", value=DEFAULT_PROMPT, key=widget_key("prompt"))
    negative_prompt = st.text_area(
        "Negative prompt", value="", key=widget_key("negative_prompt")
    )

    steps_col, guidance_col = st.columns(2)
    with steps_col:
        num_steps = st.slider(
            "Number of inference steps",
            min_value=1,
            max_value=200,
            value=30,
            key=widget_key("inference-steps"),
        )
    with guidance_col:
        guidance = st.slider(
            "Guidance scale",
            min_value=0.0,
            max_value=20.0,
            value=7.5,
            step=0.5,
            key=widget_key("guidance-scale"),
        )

    cpu_offload = st.checkbox(
        "Enable CPU offload if you run out of memory",
        key=widget_key("cpu-offload"),
        value=False,
    )

    # Nothing else to do until the user presses the button.
    if not st.button("Generate Image", key=widget_key("btn")):
        return

    with st.spinner("Generating image ..."):
        generated = generate(
            prompt,
            pipeline_name,
            negative_prompt=negative_prompt,
            num_inference_steps=num_steps,
            guidance_scale=guidance,
            enable_cpu_offload=cpu_offload,
            **kwargs,
        )
        # Cache a copy (presumably so the cached image is independent of the
        # object handed to st.image).
        set_image(OUTPUT_IMAGE_KEY, generated.copy())
    st.image(generated)
def width_and_height_sliders(prefix):
    """Render side-by-side Width and Height sliders and return their values.

    Both sliders span 64-1600 in steps of 16 and start at 512. Their keys are
    ``f"{prefix}-width"`` and ``f"{prefix}-height"``.

    Returns:
        A ``(width, height)`` tuple of the selected values.
    """
    chosen = []
    for column, label in zip(st.columns(2), ("Width", "Height")):
        with column:
            chosen.append(
                st.slider(
                    label,
                    min_value=64,
                    max_value=1600,
                    step=16,
                    value=512,
                    key=f"{prefix}-{label.lower()}",
                )
            )
    width, height = chosen
    return width, height
def image_uploader(prefix):
    """Show a JPG/PNG uploader and return the image to work with.

    Returns:
        The uploaded file opened with PIL if one was provided. Otherwise,
        whatever is cached under ``LOADED_IMAGE_KEY`` (``None`` if unset).
    """
    uploaded = st.file_uploader("Image", ["jpg", "png"], key=f"{prefix}-uploader")
    if not uploaded:
        return get_image(LOADED_IMAGE_KEY)
    picture = Image.open(uploaded)
    print(f"loaded input image of size ({picture.width}, {picture.height})")
    return picture
def sketching() -> Optional[Image.Image]:
    """Turn the uploaded image into a binary, sketch-like line drawing.

    The image comes from ``image_uploader("Controlnet")``. It is converted to
    grayscale, Gaussian-blurred with a 5x5 kernel, and adaptively thresholded.

    Returns:
        The sketch as a single-channel PIL image, or ``None`` when no image is
        available.
    """
    image = image_uploader("Controlnet")
    # Return a single None (not a (None, None) tuple): the caller checks the
    # result's truthiness, and a tuple of Nones is truthy.
    if image is None:
        return None

    # cv2.imread expects a file path, but the uploader yields a PIL image.
    # Convert it to an RGB uint8 array instead. Going through convert("RGB")
    # also normalises RGBA / palette / grayscale uploads.
    rgb = np.array(image.convert("RGB"))
    gray = cv2.cvtColor(rgb, cv2.COLOR_RGB2GRAY)
    blurred = cv2.GaussianBlur(gray, (5, 5), 0)
    # Mean-adaptive threshold: blockSize=11, C=2.
    # The original code referenced the non-existent cv2.THRES_BINARY.
    sketch = cv2.adaptiveThreshold(
        blurred, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 2
    )
    return Image.fromarray(sketch)
def txt2img_tab():
    """Render the text-to-image tab: size sliders followed by the generation form."""
    tab_prefix = "txt2img"
    size = width_and_height_sliders(tab_prefix)
    prompt_and_generate_button(
        tab_prefix,
        "txt2img",
        width=size[0],
        height=size[1],
    )
def sketching_tab():
    """Render the sketch-to-image tab.

    The first column hosts the upload and sketch extraction. Once a sketch
    image is available, the second column shows a conditioning-strength slider
    and the shared generation form. The form is run with the "sketch2img"
    pipeline, the sketch, and the chosen ``controlnet_conditioning_scale``.
    """
    prefix = "sketch2img"
    col1, col2 = st.columns(2)
    with col1:
        sketch_pil = sketching()
    with col2:
        # Require an actual PIL image. A plain truthiness check would also
        # accept placeholders such as a (None, None) tuple and pass it on to
        # the pipeline as if it were a sketch.
        if not isinstance(sketch_pil, Image.Image):
            return
        controlnet_conditioning_scale = st.slider(
            "Strength or dependence on the input sketch",
            min_value=0.0,
            max_value=1.0,
            value=0.5,
            step=0.05,
            key=f"{prefix}-controlnet_conditioning_scale",
        )
        prompt_and_generate_button(
            prefix,
            "sketch2img",
            sketch_pil=sketch_pil,
            controlnet_conditioning_scale=controlnet_conditioning_scale,
        )
def main():
    """Build the page: a wide layout, two generation tabs, and a sidebar preview.

    The sidebar shows the image cached under ``OUTPUT_IMAGE_KEY``, or a
    placeholder message if nothing has been generated yet.
    """
    st.set_page_config(layout="wide")
    st.title("Fashion-SDX: Playground")

    # Tab label -> function rendering that tab (dicts keep insertion order).
    tab_renderers = {
        "Text to image": txt2img_tab,
        "Sketch to image": sketching_tab,
    }
    tabs = st.tabs(list(tab_renderers))
    for tab, render in zip(tabs, tab_renderers.values()):
        with tab:
            render()

    with st.sidebar:
        st.header("Most Recent Output Image")
        latest = get_image(OUTPUT_IMAGE_KEY)
        if not latest:
            st.markdown("no output generated yet")
        else:
            st.image(latest)
if __name__ == "__main__":
    # Script entry point: build the page only when run directly, not on import.
    main()