SharafeevRavil commited on
Commit
23eb257
·
verified ·
1 Parent(s): 5468c3d

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +218 -0
app.py ADDED
@@ -0,0 +1,218 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ !pip install gradio
2
+
3
+ import gradio as gr
4
+ import torch
5
+ from transformers import pipeline
6
+ from huggingface_hub import InferenceClient
7
+
8
+ from PIL import Image
9
+ import numpy as np
10
+ import cv2
11
+
12
+
13
+ # Инициализация моделей
14
+ # segmentation = pipeline("image-segmentation", model="nvidia/segformer-b5-finetuned-ade-640-640")
15
+ classification = pipeline("image-classification", model="google/vit-base-patch16-224")
16
+ upscaling_client = InferenceClient(model="stabilityai/stable-diffusion-x4-upscaler")
17
+ inpainting_client = InferenceClient(model="stabilityai/stable-diffusion-inpainting")
18
+ trellis_client = InferenceClient(model="microsoft/TRELLIS")
19
+
20
+
21
+ from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
22
+ import torch
23
+ from PIL import Image
24
+ import numpy as np
25
+
26
+ device = "cuda" if torch.cuda.is_available() else "cpu"
27
+ processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
28
+ model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny").to(device)
29
+
30
+
31
+
32
+
33
+
34
+ # Функции для обработки изображений
35
+ from PIL import Image, ImageDraw
36
+ import numpy as np
37
+
38
+ def segment_image(image):
39
+ image = Image.fromarray(image)
40
+
41
+ # Изменяем task_input на "panoptic"
42
+ inputs = processor(image, task_inputs=["panoptic"], return_tensors="pt")
43
+
44
+ with torch.no_grad():
45
+ outputs = model(**inputs)
46
+
47
+ # post-process the raw predictions
48
+ predicted_panoptic_map = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
49
+
50
+ # Extract segment ids and masks
51
+ segmentation_map = predicted_panoptic_map["segmentation"].cpu().numpy()
52
+ segments_info = predicted_panoptic_map["segments_info"]
53
+
54
+ # Create cropped masks
55
+ cropped_masks_with_labels = []
56
+ label_counts = {}
57
+
58
+ for segment in segments_info:
59
+ mask = (segmentation_map == segment["id"]).astype(np.uint8) * 255
60
+ cropped_image = cv2.bitwise_and(np.array(image), np.array(image), mask=mask)
61
+
62
+ label = model.config.id2label[segment["label_id"]]
63
+
64
+ # Check if label already exists
65
+ if label in label_counts:
66
+ label_counts[label] += 1
67
+ else:
68
+ label_counts[label] = 1
69
+ label = f"{label}_{label_counts[label] - 1}" # Append _0, _1, etc.
70
+
71
+ cropped_masks_with_labels.append((cropped_image, label))
72
+
73
+ return cropped_masks_with_labels
74
+
75
+ def merge_segments_by_labels(gallery_images, labels_input):
76
+ """
77
+ Объединяет сегменты из галереи изображений в одно изображение,
78
+ основываясь на введенных пользователем метках.
79
+
80
+ Args:
81
+ gallery_images: Список изображений сегментов (кортежи (изображение, метка)).
82
+ labels_input: Строка с метками, разделенными точкой с запятой.
83
+
84
+ Returns:
85
+ Список изображений, где выбранные сегменты объединены в одно.
86
+ """
87
+ # 1. Разделяем входную строку с метками на список
88
+ labels_to_merge = [label.strip() for label in labels_input.split(";")]
89
+
90
+ # 2. Создаем пустое изображение для объединения
91
+ merged_image = None
92
+
93
+ # 3. Инициализируем список для хранения индексов объединенных сегментов
94
+ merged_indices = []
95
+
96
+ # 4. Проходим по всем изображениям в галерее
97
+ for i, (image_path, label) in enumerate(gallery_images):
98
+ # 5. Если метка сегмента в списке меток для объединения
99
+ if label in labels_to_merge:
100
+ image = cv2.imread(image_path)
101
+ # 6. Если это первый сегмент для объединения
102
+ if merged_image is None:
103
+ # 7. Создаем копию изображения как основу для объединения
104
+ merged_image = image.copy()
105
+ else:
106
+ # 8. Объединяем текущее изображение с merged_image
107
+ # Используем cv2.add для наложения изображений,
108
+ # предполагая, что сегменты не перекрываются
109
+ merged_image = cv2.add(merged_image, image)
110
+
111
+ # 9. Добавляем индекс объединенного сегмента в список
112
+ merged_indices.append(i)
113
+
114
+ # 10. Если сегменты были объединены
115
+ if merged_image is not None:
116
+ # 11. Создаем новый список изображений, удаляя объединенные сегменты
117
+ # и добавляя объединенное изображение с новой меткой
118
+ new_gallery_images = [
119
+ item for i, item in enumerate(gallery_images) if i not in merged_indices
120
+ ]
121
+
122
+ new_name = labels_to_merge[0]
123
+ new_gallery_images.append((merged_image, new_name))
124
+
125
+ return new_gallery_images
126
+ else:
127
+ # 12. Если не было меток для объединения, возвращаем исходный список
128
+ return gallery_images
129
+
130
+ def set_client_for_session(request: gr.Request):
131
+ x_ip_token = request.headers['x-ip-token']
132
+ # The "JeffreyXiang/TRELLIS" space is a ZeroGPU space
133
+ return Client("JeffreyXiang/TRELLIS", headers={"X-IP-Token": x_ip_token})
134
+
135
+
136
+ def generate_3d_model(client, segment_output, segment_name):
137
+ for i, (image_path, label) in enumerate(segment_output):
138
+ if label == segment_name:
139
+ result = client.predict(
140
+ image=handle_file(image_path),
141
+ multiimages=[],
142
+ seed=0,
143
+ ss_guidance_strength=7.5,
144
+ ss_sampling_steps=12,
145
+ slat_guidance_strength=3,
146
+ slat_sampling_steps=12,
147
+ multiimage_algo="stochastic",
148
+ api_name="/image_to_3d"
149
+ )
150
+ break
151
+ print(result)
152
+ return result["video"]
153
+
154
+ def classify_segments(segments):
155
+ # Предполагается, что segments - список изображений сегментов
156
+ results = []
157
+ for segment in segments:
158
+ results.append(classification(segment))
159
+ return results # Вернем список классификаций
160
+
161
+ def upscale_segment(segment):
162
+ upscaled = upscaling_client.image_to_image(segment)
163
+ return upscaled
164
+
165
+ def inpaint_image(image, mask, prompt):
166
+ inpainted = inpainting_client.text_to_image(prompt, image=image, mask=mask)
167
+ return inpainted
168
+
169
+
170
+
171
+
172
+ with gr.Blocks() as demo:
173
+ client = gr.State()
174
+
175
+ gr.Markdown("# Анализ и редактирование помещений")
176
+
177
+ with gr.Tab("Сканирование"):
178
+ with gr.Row():
179
+ with gr.Column(scale=5):
180
+ image_input = gr.Image()
181
+ segment_button = gr.Button("Сегментировать")
182
+ with gr.Column(scale=5):
183
+ segment_output = gr.Gallery()
184
+ merge_segments_input = gr.Textbox(label="Сегменты для объединения (через точку с запятой, например: \"wall_0; tv_0\")")
185
+ merge_segments_button = gr.Button("Соединить сегменты")
186
+ merge_segments_button.click(merge_segments_by_labels, inputs=[segment_output, merge_segments_input], outputs=segment_output)
187
+ with gr.Row():
188
+ with gr.Column(scale=5):
189
+ trellis_input = gr.Textbox(label="Имя сегмента для 3D")
190
+ trellis_button = gr.Button("3D Trellis")
191
+ with gr.Column(scale=5):
192
+ trellis_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
193
+ trellis_button.click(generate_3d_model, inputs=[client, segment_output, trellis_input], outputs=trellis_output)
194
+
195
+ segment_button.click(segment_image, inputs=image_input, outputs=segment_output)
196
+ # segment_button.click(segment_full_image, inputs=image_input, outputs=segment_output)
197
+
198
+ with gr.Tab("Редактирование"):
199
+ segment_input = gr.Image()
200
+ upscale_output = gr.Image()
201
+ upscale_button = gr.Button("Upscale")
202
+ upscale_button.click(upscale_segment, inputs=segment_input, outputs=upscale_output)
203
+
204
+ mask_input = gr.Image()
205
+ prompt_input = gr.Textbox()
206
+ inpaint_output = gr.Image()
207
+ inpaint_button = gr.Button("Inpaint")
208
+ inpaint_button.click(inpaint_image, inputs=[segment_input, mask_input, prompt_input], outputs=inpaint_output)
209
+
210
+ with gr.Tab("Создание 3D моделей"):
211
+ segment_input_3d = gr.Image()
212
+ model_output = gr.File()
213
+ model_button = gr.Button("Создать 3D модель")
214
+ model_button.click(generate_3d_model, inputs=segment_input_3d, outputs=model_output)
215
+
216
+ demo.load(set_client_for_session, None, client)
217
+
218
+ demo.launch(debug=True, show_error=True)