SharafeevRavil committed (verified)
Commit a74e5af · 1 Parent(s): d79965e

Update app.py

Files changed (1): app.py (+187 -108)
app.py CHANGED
@@ -1,34 +1,24 @@
-# !pip install gradio
-
 import gradio as gr
 import torch
 from transformers import pipeline
 from huggingface_hub import InferenceClient
-
 from PIL import Image
 import numpy as np
 import cv2
 
 
 # Model initialization
-# segmentation = pipeline("image-segmentation", model="nvidia/segformer-b5-finetuned-ade-640-640")
-classification = pipeline("image-classification", model="google/vit-base-patch16-224")
-upscaling_client = InferenceClient(model="stabilityai/stable-diffusion-x4-upscaler")
-inpainting_client = InferenceClient(model="stabilityai/stable-diffusion-inpainting")
-trellis_client = InferenceClient(model="microsoft/TRELLIS")
-
-
 from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
-import torch
-from PIL import Image
-import numpy as np
 
 device = "cuda" if torch.cuda.is_available() else "cpu"
-processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
-model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny").to(device)
-
 
+# oneFormer segmentation
+oneFormer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
+oneFormer_model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny").to(device)
 
+# classification = pipeline("image-classification", model="google/vit-base-patch16-224")
+# upscaling_client = InferenceClient(model="stabilityai/stable-diffusion-x4-upscaler")
+# inpainting_client = InferenceClient(model="stabilityai/stable-diffusion-inpainting")
 
 
 # Functions for image processing
@@ -38,15 +28,13 @@ import numpy as np
 
 def segment_image(image):
     image = Image.fromarray(image)
-
-    # Change task_input to "panoptic"
-    inputs = processor(image, task_inputs=["panoptic"], return_tensors="pt")
+    inputs = oneFormer_processor(image, task_inputs=["panoptic"], return_tensors="pt")
 
     with torch.no_grad():
-        outputs = model(**inputs)
+        outputs = oneFormer_model(**inputs)
 
     # post-process the raw predictions
-    predicted_panoptic_map = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
+    predicted_panoptic_map = oneFormer_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
 
     # Extract segment ids and masks
     segmentation_map = predicted_panoptic_map["segmentation"].cpu().numpy()
@@ -58,9 +46,13 @@ def segment_image(image):
 
     for segment in segments_info:
         mask = (segmentation_map == segment["id"]).astype(np.uint8) * 255
-        cropped_image = cv2.bitwise_and(np.array(image), np.array(image), mask=mask)
+        # cropped_image = cv2.bitwise_and(np.array(image), np.array(image), mask=mask)
 
-        label = model.config.id2label[segment["label_id"]]
+        cropped_image = np.zeros((image.height, image.width, 4), dtype=np.uint8)
+        cropped_image[mask != 0, :3] = np.array(image)[mask != 0]
+        cropped_image[mask != 0, 3] = 255
+
+        label = oneFormer_model.config.id2label[segment["label_id"]]
 
         # Check if label already exists
         if label in label_counts:
@@ -73,105 +65,174 @@ def segment_image(image):
 
     return cropped_masks_with_labels
 
+# def merge_segments_by_labels(gallery_images, labels_input):
+#     """
+#     Merges segments from the image gallery into a single image,
+#     based on the labels entered by the user.
+
+#     Args:
+#         gallery_images: List of segment images (tuples of (image, label)).
+#         labels_input: String of labels separated by semicolons.
+
+#     Returns:
+#         List of images in which the selected segments are merged into one.
+#     """
+#     labels_to_merge = [label.strip() for label in labels_input.split(";")]
+#     merged_image = None
+#     merged_indices = []
+
+#     for i, (image_path, label) in enumerate(gallery_images):
+#         if label in labels_to_merge:
+#             image = cv2.imread(image_path)
+#             if merged_image is None:
+#                 merged_image = image.copy()
+#             else:
+#                 merged_image = cv2.add(merged_image, image)
+#             merged_indices.append(i)
+#     if merged_image is not None:
+#         new_gallery_images = [
+#             item for i, item in enumerate(gallery_images) if i not in merged_indices
+#         ]
+
+#         new_name = labels_to_merge[0]
+#         new_gallery_images.append((merged_image, new_name))
+
+#         return new_gallery_images
+#     else:
+#         return gallery_images
+
+
 def merge_segments_by_labels(gallery_images, labels_input):
-    """
-    Merges segments from the image gallery into a single image,
-    based on the labels entered by the user.
-
-    Args:
-        gallery_images: List of segment images (tuples of (image, label)).
-        labels_input: String of labels separated by semicolons.
-
-    Returns:
-        List of images in which the selected segments are merged into one.
-    """
-    # 1. Split the input label string into a list
     labels_to_merge = [label.strip() for label in labels_input.split(";")]
-
-    # 2. Create an empty image for merging
     merged_image = None
-
-    # 3. Initialize a list to store the indices of merged segments
     merged_indices = []
 
-    # 4. Iterate over all images in the gallery
-    for i, (image_path, label) in enumerate(gallery_images):
-        # 5. If the segment label is in the list of labels to merge
+    for i, (image_path, label) in enumerate(gallery_images):  # Fixed: image_path
        if label in labels_to_merge:
-            image = cv2.imread(image_path)
-            # 6. If this is the first segment to merge
+            # Load the image with PIL, keeping the alpha channel
+            image = Image.open(image_path).convert("RGBA")
+
            if merged_image is None:
-                # 7. Use a copy of the image as the base for merging
                merged_image = image.copy()
            else:
-                # 8. Merge the current image with merged_image
-                # Use cv2.add to overlay the images,
-                # assuming the segments do not overlap
-                merged_image = cv2.add(merged_image, image)
-
-            # 9. Add the index of the merged segment to the list
+                # Combine the images, taking the alpha channel into account
+                merged_image = Image.alpha_composite(merged_image, image)
            merged_indices.append(i)
 
-    # 10. If any segments were merged
    if merged_image is not None:
-        # 11. Build a new image list, removing the merged segments
-        # and adding the merged image with a new label
+        # Convert the merged image to a numpy array
+        merged_image_np = np.array(merged_image)
+
        new_gallery_images = [
            item for i, item in enumerate(gallery_images) if i not in merged_indices
        ]
-
        new_name = labels_to_merge[0]
-        new_gallery_images.append((merged_image, new_name))
-
+        new_gallery_images.append((merged_image_np, new_name))
        return new_gallery_images
    else:
-        # 12. If there were no labels to merge, return the original list
        return gallery_images
 
-def set_client_for_session(request: gr.Request):
-    x_ip_token = request.headers['x-ip-token']
-    # The "JeffreyXiang/TRELLIS" space is a ZeroGPU space
-    return Client("JeffreyXiang/TRELLIS", headers={"X-IP-Token": x_ip_token})
 
+# def set_client_for_session(request: gr.Request):
+#     x_ip_token = request.headers['x-ip-token']
+#     return Client("JeffreyXiang/TRELLIS", headers={"X-IP-Token": x_ip_token})
+
+def set_hunyuan_client(request: gr.Request):
+    try:
+        x_ip_token = request.headers['x-ip-token']
+        return Client("tencent/Hunyuan3D-2", headers={"X-IP-Token": x_ip_token})
+    except:
+        return Client("tencent/Hunyuan3D-2")
+
+def set_vFusion_client(request: gr.Request):
+    try:
+        x_ip_token = request.headers['x-ip-token']
+        return Client("facebook/VFusion3D", headers={"X-IP-Token": x_ip_token})
+    except:
+        return Client("facebook/VFusion3D")
+
+# def generate_3d_model(client, segment_output, segment_name):
+#     for i, (image_path, label) in enumerate(segment_output):
+#         if label == segment_name:
+#             result = client.predict(
+#                 image=handle_file(image_path),
+#                 multiimages=[],
+#                 seed=0,
+#                 ss_guidance_strength=7.5,
+#                 ss_sampling_steps=12,
+#                 slat_guidance_strength=3,
+#                 slat_sampling_steps=12,
+#                 multiimage_algo="stochastic",
+#                 api_name="/image_to_3d"
+#             )
+#             break
+#     print(result)
+#     return result["video"]
 
 def generate_3d_model(client, segment_output, segment_name):
+    for i, (image_path, label) in enumerate(segment_output):
+        if label == segment_name:
+            result = client.predict(
+                caption="",
+                image=handle_file(image_path),
+                steps=50,
+                guidance_scale=5.5,
+                seed=1234,
+                octree_resolution="256",
+                check_box_rembg=True,
+                api_name="/shape_generation"
+            )
+            print(result)
+            return result[0]
+
+def generate_3d_model_texture(client, segment_output, segment_name):
+    for i, (image_path, label) in enumerate(segment_output):
+        if label == segment_name:
+            result = client.predict(
+                caption="",
+                image=handle_file(image_path),
+                steps=50,
+                guidance_scale=5.5,
+                seed=1234,
+                octree_resolution="256",
+                check_box_rembg=True,
+                api_name="/generation_all"
+            )
+            print(result)
+            return result[1]
+
+def generate_3d_model2(client, segment_output, segment_name):
    for i, (image_path, label) in enumerate(segment_output):
        if label == segment_name:
            result = client.predict(
                image=handle_file(image_path),
-                multiimages=[],
-                seed=0,
-                ss_guidance_strength=7.5,
-                ss_sampling_steps=12,
-                slat_guidance_strength=3,
-                slat_sampling_steps=12,
-                multiimage_algo="stochastic",
-                api_name="/image_to_3d"
+                api_name="/step_1_generate_obj"
            )
-            break
-    print(result)
-    return result["video"]
+            print(result)
+            return result[0]
 
-def classify_segments(segments):
-    # segments is assumed to be a list of segment images
-    results = []
-    for segment in segments:
-        results.append(classification(segment))
-    return results  # return the list of classifications
 
-def upscale_segment(segment):
-    upscaled = upscaling_client.image_to_image(segment)
-    return upscaled
+# def classify_segments(segments):
+#     # segments is assumed to be a list of segment images
+#     results = []
+#     for segment in segments:
+#         results.append(classification(segment))
+#     return results  # return the list of classifications
 
-def inpaint_image(image, mask, prompt):
-    inpainted = inpainting_client.text_to_image(prompt, image=image, mask=mask)
-    return inpainted
+# def upscale_segment(segment):
+#     upscaled = upscaling_client.image_to_image(segment)
+#     return upscaled
 
+# def inpaint_image(image, mask, prompt):
+#     inpainted = inpainting_client.text_to_image(prompt, image=image, mask=mask)
+#     return inpainted
 
 
+from gradio_litmodel3d import LitModel3D
 
 with gr.Blocks() as demo:
-    client = gr.State()
+    hunyuan_client = gr.State()
+    vFusion_client = gr.State()
 
    gr.Markdown("# Анализ и редактирование помещений")
 
@@ -188,32 +249,50 @@ with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column(scale=5):
                trellis_input = gr.Textbox(label="Имя сегмента для 3D")
-                trellis_button = gr.Button("3D Trellis")
+                hunyuan_button = gr.Button("Hunyuan3D-2")
+                hunyuan_button_texture = gr.Button("Hunyuan3D-2 (with texture)")
+                vFusion_button = gr.Button("VFusion3D")
            with gr.Column(scale=5):
-                trellis_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
-                trellis_button.click(generate_3d_model, inputs=[client, segment_output, trellis_input], outputs=trellis_output)
+                # trellis_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
+                trellis_output2 = LitModel3D(
+                    clear_color=[0.1, 0.1, 0.1, 0],  # can adjust background color for better contrast
+                    label="3D Model Visualization",
+                    scale=1.0,
+                    tonemapping="aces",  # can use aces tonemapping for more realistic lighting
+                    exposure=1.0,  # can adjust exposure to control brightness
+                    contrast=1.1,  # can slightly increase contrast for better depth
+                    camera_position=(0, 0, 2),  # will set initial camera position to center the model
+                    zoom_speed=0.5,  # will adjust zoom speed for better control
+                    pan_speed=0.5,  # will adjust pan speed for better control
+                    interactive=True  # this allows users to interact with the model
+                )
+                # trellis_button.click(generate_3d_model, inputs=[client, segment_output, trellis_input], outputs=trellis_output)
+                hunyuan_button.click(generate_3d_model, inputs=[hunyuan_client, segment_output, trellis_input], outputs=trellis_output2)
+                hunyuan_button_texture.click(generate_3d_model_texture, inputs=[hunyuan_client, segment_output, trellis_input], outputs=trellis_output2)
+                vFusion_button.click(generate_3d_model2, inputs=[vFusion_client, segment_output, trellis_input], outputs=trellis_output2)
 
        segment_button.click(segment_image, inputs=image_input, outputs=segment_output)
        # segment_button.click(segment_full_image, inputs=image_input, outputs=segment_output)
 
-    with gr.Tab("Редактирование"):
-        segment_input = gr.Image()
-        upscale_output = gr.Image()
-        upscale_button = gr.Button("Upscale")
-        upscale_button.click(upscale_segment, inputs=segment_input, outputs=upscale_output)
-
-        mask_input = gr.Image()
-        prompt_input = gr.Textbox()
-        inpaint_output = gr.Image()
-        inpaint_button = gr.Button("Inpaint")
-        inpaint_button.click(inpaint_image, inputs=[segment_input, mask_input, prompt_input], outputs=inpaint_output)
-
-    with gr.Tab("Создание 3D моделей"):
-        segment_input_3d = gr.Image()
-        model_output = gr.File()
-        model_button = gr.Button("Создать 3D модель")
-        model_button.click(generate_3d_model, inputs=segment_input_3d, outputs=model_output)
-
-    demo.load(set_client_for_session, None, client)
+    # with gr.Tab("Редактирование"):
+    #     segment_input = gr.Image()
+    #     upscale_output = gr.Image()
+    #     upscale_button = gr.Button("Upscale")
+    #     upscale_button.click(upscale_segment, inputs=segment_input, outputs=upscale_output)
+
+    #     mask_input = gr.Image()
+    #     prompt_input = gr.Textbox()
+    #     inpaint_output = gr.Image()
+    #     inpaint_button = gr.Button("Inpaint")
+    #     inpaint_button.click(inpaint_image, inputs=[segment_input, mask_input, prompt_input], outputs=inpaint_output)
+
+    # with gr.Tab("Создание 3D моделей"):
+    #     segment_input_3d = gr.Image()
+    #     model_output = gr.File()
+    #     model_button = gr.Button("Создать 3D модель")
+    #     model_button.click(generate_3d_model, inputs=segment_input_3d, outputs=model_output)
+
+    demo.load(set_hunyuan_client, None, hunyuan_client)
+    demo.load(set_vFusion_client, None, vFusion_client)
 
  demo.launch(debug=True, show_error=True)
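
Note on the code above: both the old and the new 3D helpers call Client(...) and handle_file(...), but none of the hunks shown here imports either name. Either the import already sits in one of the unchanged parts of app.py that this diff does not display, or it is missing, in which case the 3D buttons would fail with a NameError at click time. Assuming the names are meant to come from the gradio_client package (which exports both), the required line would look like this:

    from gradio_client import Client, handle_file  # assumed import; not visible in this diff

A similar caveat applies to the new LitModel3D viewer: it is imported mid-file from gradio_litmodel3d, so that package has to be installed in the Space (e.g. listed in requirements.txt) for the import to succeed.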