File size: 5,528 Bytes
81d8245
 
8b5fe62
 
 
81d8245
 
2e40cec
8b5fe62
 
 
 
 
 
 
 
 
 
 
 
9d52a0c
8b5fe62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0282d2d
 
 
8b5fe62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49d0ebc
 
176a9a1
 
 
 
 
 
 
 
 
 
 
 
49d0ebc
 
 
 
 
 
8b5fe62
 
 
81d8245
8b5fe62
 
 
 
 
 
 
 
 
 
 
 
 
 
c0b70e5
c2d4647
9866f53
5d30bbb
8b5fe62
176a9a1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5fb2463
176a9a1
b6ea260
176a9a1
 
8b5fe62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49d0ebc
8b5fe62
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
from utils_html import HTML_TEMPLATE
from io import BytesIO
import gradio as gr 
import numpy as np 
import requests
import modal
import PIL

# Remote handle for the product-recommendation Modal class used by the
# "Recommendation" tab. It is resolved once, when this module is executed.
f_gc = modal.Cls.lookup(
    "casa-interior-gc-v2",
    "GetProduct",
)

def casa_ai_run_tab1(image=None, text=None):
    """Restyle an uploaded room photo from a text prompt ("Reimagine" tab).

    Args:
        image: Room photo from the ``gr.Image(type="pil")`` input.
        text: Free-text room description from the ``gr.Textbox`` input.

    Returns:
        The result of the remote ``DesignModel.inference`` call in ``"tab1"``
        mode (shown in a ``gr.Image``), or ``None`` when an input is missing.
    """
    if image is None:
        print('Please provide image of empty room to design')
        return None

    # A blank gr.Textbox hands over "" rather than None, so blank or
    # whitespace-only prompts must count as "missing" too.
    if text is None or not text.strip():
        print('Please provide a text prompt')
        return None

    f = modal.Cls.lookup("casa-interior-hf-v4", "DesignModel")
    return f.inference.remote("tab1", image, text)

def casa_ai_run_tab2(dict=None, text=None):
    """Regenerate the painted-over region of a room photo ("Redesign" tab).

    Args:
        dict: Value of the ``gr.ImageEditor(type="pil")`` input. The keys read
            are ``"background"`` (the uploaded photo) and ``"layers"``, whose
            first entry holds the white brush strokes used as the mask. The
            name shadows the builtin but is kept for existing callers.
        text: Description of what the masked object should become.

    Returns:
        The result of the remote ``DesignModel.inference`` call in ``"tab2"``
        mode, or ``None`` when no image was provided.
    """
    editor_value = dict  # local alias so the builtin-shadowing name is used once

    # The editor value itself, or its background, may be missing when nothing
    # was uploaded. Bail out instead of crashing on `.convert()`.
    background = None if editor_value is None else editor_value.get("background")
    if background is None:
        print('Please provide an image of the room to redesign')
        return None

    image = background.convert("RGB")

    # Use the first paint layer as the mask. A missing layer, or a layer with
    # nothing painted on it, both mean "no mask".
    layers = editor_value.get("layers") or []
    mask = layers[0].convert('L') if layers else None
    if mask is not None and not np.any(np.asarray(mask)):
        mask = None

    if mask is None:
        # Warning only: the request is still sent with mask=None, as before.
        print('Please provide a mask over the object you want to generate again.')

    # NOTE(review): the Reimagine tab uses "casa-interior-hf-v4" while this tab
    # uses "-v3". Confirm the version skew is intentional.
    f_tab2 = modal.Cls.lookup("casa-interior-hf-v3", "DesignModel")
    return f_tab2.inference.remote("tab2", image, text, mask)

def casa_ai_run_tab3(dict=None):
    """Look up products similar to the cropped object ("Recommendation" tab).

    Args:
        dict: Value of the ``gr.ImageEditor(type="numpy")`` input. Only its
            ``"composite"`` entry (an image array) is read. The name shadows
            the builtin but is kept for backward compatibility.

    Returns:
        The result of the remote ``GetProduct.inference`` call (shown in a
        ``gr.Gallery``), or ``None`` when no crop was provided.
    """
    # Treat a missing editor value the same as a missing "composite",
    # instead of raising TypeError/KeyError.
    selected_crop = None if dict is None else dict.get("composite")

    if selected_crop is None:
        print('Please provide cropped object')
        return None

    # `import PIL` at module level does not by itself load the PIL.Image
    # submodule. Import it explicitly rather than relying on another package
    # having loaded it already.
    from PIL import Image

    crop_rgb = Image.fromarray(selected_crop).convert('RGB')
    return f_gc.inference.remote(crop_rgb)

def casa_ai_run_tab_sketch(image=None, room_type=None, room_style=None):
    """Turn a sketch or SketchUp render into a styled room ("Sketch Transform" tab).

    Args:
        image: Sketch image from the ``gr.Image(type="numpy")`` input.
        room_type: Selected room type, e.g. ``"Bedroom"``.
        room_style: Selected interior style, e.g. ``"Modern"``.

    Returns:
        The result of the remote sketch ``DesignModel.inference`` call, or
        ``None`` when any input is missing.
    """
    if image is None:
        print('Please provide a sketch or SketchUp image')
        return None

    if room_type is None:
        print('Please select a room type')
        return None

    if room_style is None:
        print('Please select a room style')
        return None

    # The prompt sent to the model is "<room type>, <style>".
    text = f"{room_type}, {room_style}"
    f = modal.Cls.lookup("casa-interior-hf-v6-sketch", "DesignModel")
    return f.inference.remote(image, text)


# Gradio UI: four tabs, each wiring its input components to one of the
# casa_ai_run_* handlers defined above. Within each `with` context, components
# are laid out in the order they are created, so statement order is layout.
with gr.Blocks() as casa:
    # NOTE(review): `title` and `description` are assigned but never used in
    # this file. Possibly meant for gr.Blocks(title=...) or a header; confirm.
    title = "Casa-AI Demo"
    description = "A Gradio interface to use CasaAI for virtual staging"
    # Page header markup, imported from utils_html.
    gr.HTML(value=HTML_TEMPLATE, show_label=False)

    # Reimagine: restyle an uploaded room photo from a text prompt.
    with gr.Tab("Reimagine"):
        with gr.Row():
            with gr.Column():
                # Order matches casa_ai_run_tab1(image, text); Gradio passes inputs positionally.
                inputs = [
                            gr.Image(sources='upload', type="pil", label="Upload"), 
                            gr.Textbox(label="Room description.")
                        ]
            with gr.Column():
                outputs = [gr.Image(label="Generated room image")]

        
        submit_btn = gr.Button("Generate!")
        # `inputs`/`outputs`/`submit_btn` are rebound in every tab. Each call
        # here receives the list objects current at this point, so later
        # rebinding does not affect this tab's wiring.
        submit_btn.click(casa_ai_run_tab1, inputs=inputs, outputs=outputs)
        # NOTE: per Gradio's docs, cache_examples=True precomputes each
        # example's output by calling `fn`, i.e. remote Modal inference here.
        gr.Examples(examples=[['example_images/image_0.jpg', 'Living room in bohemian style'], 
                              ['example_images/image_1.jpg', 'A minimalist and scandinavian style living room with black leather sofa'], 
                              ['example_images/image_2.jpg', 'Modern bedroom art deco style']], 
                    inputs=inputs, outputs=outputs, fn=casa_ai_run_tab1, cache_examples=True)

    
    # Sketch Transform: turn a sketch / SketchUp render into a styled room.
    with gr.Tab("Sketch Transform"):
        with gr.Row():
            with gr.Column():
                # Order matches casa_ai_run_tab_sketch(image, room_type, room_style).
                inputs = [
                            gr.Image(sources='upload', type="numpy", label="Upload"),
                            gr.Dropdown(["Living Room", "Bedroom", "Kitchen"], label="Room Type", info="Select Room Type"),
                            gr.Dropdown(["Modern", "Minimalist", "Scandinavian"], label="Style", info="Interior Style!"),
                ]
                    
            with gr.Column():
                outputs = [gr.Image(label="Image of sketch transformed")]
                
        submit_btn = gr.Button("Transform!")
        submit_btn.click(casa_ai_run_tab_sketch, inputs=inputs, outputs=outputs)
        gr.Examples(examples=[['example_images/sketch01.jpeg', 'Living Room',  'Modern'], 
                              ['example_images/sketch02.jpeg', 'Bedroom', 'Minimalist'], 
                              ['example_images/sketchup.jpg', 'Kitchen', 'Modern']], 
                    inputs=inputs, outputs=outputs, fn=casa_ai_run_tab_sketch, cache_examples=True)

        
    # Redesign: paint over an object (white brush) and describe its replacement.
    with gr.Tab("Redesign"):
        with gr.Row():
            with gr.Column():
                # The editor value is passed whole to casa_ai_run_tab2, which
                # reads "background" and uses the first entry of "layers" as the mask.
                # NOTE(review): elem_id "image_upload" is also used by the
                # Recommendation editor below (duplicate DOM id). Check whether
                # HTML_TEMPLATE's CSS targets it before renaming either one.
                inputs = [
                            gr.ImageEditor(sources='upload', brush=gr.Brush(colors=["#FFFFFF"]), elem_id="image_upload", type="pil", label="Upload", layers=False, eraser=True, transforms=[]),
                            gr.Textbox(label="Description for redesigning masked object")]
            with gr.Column():
                outputs = [gr.Image(label="Image with new designed object")]
                
        submit_btn = gr.Button("Redesign!")
        submit_btn.click(casa_ai_run_tab2, inputs=inputs, outputs=outputs)

    # Recommendation: crop an object and look up similar products.
    with gr.Tab("Recommendation"):
        with gr.Row():
            with gr.Column():
                # Only the crop transform is enabled (no brush, no eraser).
                # casa_ai_run_tab3 reads the editor value's "composite" array.
                inputs = [
                            gr.ImageEditor(sources='upload', elem_id="image_upload", type="numpy", label="Upload", layers=False, eraser=False, brush=False, transforms=['crop']),
                            ]
            with gr.Column():
                outputs = [gr.Gallery(label="Similar products")]
                
        submit_btn = gr.Button("Find similar products!")
        submit_btn.click(casa_ai_run_tab3, inputs=inputs, outputs=outputs)


# Start the Gradio server; this runs whenever the module is executed.
# NOTE(review): there is no `if __name__ == "__main__":` guard, so importing
# this module (e.g. from a test) also launches the app.
casa.launch()