Spaces:
Running
Running
!pip install gradio | |
import gradio as gr | |
import torch | |
from transformers import pipeline | |
from huggingface_hub import InferenceClient | |
from PIL import Image | |
import numpy as np | |
import cv2 | |
# Инициализация моделей | |
# segmentation = pipeline("image-segmentation", model="nvidia/segformer-b5-finetuned-ade-640-640") | |
classification = pipeline("image-classification", model="google/vit-base-patch16-224") | |
upscaling_client = InferenceClient(model="stabilityai/stable-diffusion-x4-upscaler") | |
inpainting_client = InferenceClient(model="stabilityai/stable-diffusion-inpainting") | |
trellis_client = InferenceClient(model="microsoft/TRELLIS") | |
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation | |
import torch | |
from PIL import Image | |
import numpy as np | |
device = "cuda" if torch.cuda.is_available() else "cpu" | |
processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny") | |
model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny").to(device) | |
# Функции для обработки изображений | |
from PIL import Image, ImageDraw | |
import numpy as np | |
def segment_image(image): | |
image = Image.fromarray(image) | |
# Изменяем task_input на "panoptic" | |
inputs = processor(image, task_inputs=["panoptic"], return_tensors="pt") | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
# post-process the raw predictions | |
predicted_panoptic_map = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0] | |
# Extract segment ids and masks | |
segmentation_map = predicted_panoptic_map["segmentation"].cpu().numpy() | |
segments_info = predicted_panoptic_map["segments_info"] | |
# Create cropped masks | |
cropped_masks_with_labels = [] | |
label_counts = {} | |
for segment in segments_info: | |
mask = (segmentation_map == segment["id"]).astype(np.uint8) * 255 | |
cropped_image = cv2.bitwise_and(np.array(image), np.array(image), mask=mask) | |
label = model.config.id2label[segment["label_id"]] | |
# Check if label already exists | |
if label in label_counts: | |
label_counts[label] += 1 | |
else: | |
label_counts[label] = 1 | |
label = f"{label}_{label_counts[label] - 1}" # Append _0, _1, etc. | |
cropped_masks_with_labels.append((cropped_image, label)) | |
return cropped_masks_with_labels | |
def merge_segments_by_labels(gallery_images, labels_input): | |
""" | |
Объединяет сегменты из галереи изображений в одно изображение, | |
основываясь на введенных пользователем метках. | |
Args: | |
gallery_images: Список изображений сегментов (кортежи (изображение, метка)). | |
labels_input: Строка с метками, разделенными точкой с запятой. | |
Returns: | |
Список изображений, где выбранные сегменты объединены в одно. | |
""" | |
# 1. Разделяем входную строку с метками на список | |
labels_to_merge = [label.strip() for label in labels_input.split(";")] | |
# 2. Создаем пустое изображение для объединения | |
merged_image = None | |
# 3. Инициализируем список для хранения индексов объединенных сегментов | |
merged_indices = [] | |
# 4. Проходим по всем изображениям в галерее | |
for i, (image_path, label) in enumerate(gallery_images): | |
# 5. Если метка сегмента в списке меток для объединения | |
if label in labels_to_merge: | |
image = cv2.imread(image_path) | |
# 6. Если это первый сегмент для объединения | |
if merged_image is None: | |
# 7. Создаем копию изображения как основу для объединения | |
merged_image = image.copy() | |
else: | |
# 8. Объединяем текущее изображение с merged_image | |
# Используем cv2.add для наложения изображений, | |
# предполагая, что сегменты не перекрываются | |
merged_image = cv2.add(merged_image, image) | |
# 9. Добавляем индекс объединенного сегмента в список | |
merged_indices.append(i) | |
# 10. Если сегменты были объединены | |
if merged_image is not None: | |
# 11. Создаем новый список изображений, удаляя объединенные сегменты | |
# и добавляя объединенное изображение с новой меткой | |
new_gallery_images = [ | |
item for i, item in enumerate(gallery_images) if i not in merged_indices | |
] | |
new_name = labels_to_merge[0] | |
new_gallery_images.append((merged_image, new_name)) | |
return new_gallery_images | |
else: | |
# 12. Если не было меток для объединения, возвращаем исходный список | |
return gallery_images | |
def set_client_for_session(request: gr.Request): | |
x_ip_token = request.headers['x-ip-token'] | |
# The "JeffreyXiang/TRELLIS" space is a ZeroGPU space | |
return Client("JeffreyXiang/TRELLIS", headers={"X-IP-Token": x_ip_token}) | |
def generate_3d_model(client, segment_output, segment_name): | |
for i, (image_path, label) in enumerate(segment_output): | |
if label == segment_name: | |
result = client.predict( | |
image=handle_file(image_path), | |
multiimages=[], | |
seed=0, | |
ss_guidance_strength=7.5, | |
ss_sampling_steps=12, | |
slat_guidance_strength=3, | |
slat_sampling_steps=12, | |
multiimage_algo="stochastic", | |
api_name="/image_to_3d" | |
) | |
break | |
print(result) | |
return result["video"] | |
def classify_segments(segments): | |
# Предполагается, что segments - список изображений сегментов | |
results = [] | |
for segment in segments: | |
results.append(classification(segment)) | |
return results # Вернем список классификаций | |
def upscale_segment(segment): | |
upscaled = upscaling_client.image_to_image(segment) | |
return upscaled | |
def inpaint_image(image, mask, prompt): | |
inpainted = inpainting_client.text_to_image(prompt, image=image, mask=mask) | |
return inpainted | |
with gr.Blocks() as demo: | |
client = gr.State() | |
gr.Markdown("# Анализ и редактирование помещений") | |
with gr.Tab("Сканирование"): | |
with gr.Row(): | |
with gr.Column(scale=5): | |
image_input = gr.Image() | |
segment_button = gr.Button("Сегментировать") | |
with gr.Column(scale=5): | |
segment_output = gr.Gallery() | |
merge_segments_input = gr.Textbox(label="Сегменты для объединения (через точку с запятой, например: \"wall_0; tv_0\")") | |
merge_segments_button = gr.Button("Соединить сегменты") | |
merge_segments_button.click(merge_segments_by_labels, inputs=[segment_output, merge_segments_input], outputs=segment_output) | |
with gr.Row(): | |
with gr.Column(scale=5): | |
trellis_input = gr.Textbox(label="Имя сегмента для 3D") | |
trellis_button = gr.Button("3D Trellis") | |
with gr.Column(scale=5): | |
trellis_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300) | |
trellis_button.click(generate_3d_model, inputs=[client, segment_output, trellis_input], outputs=trellis_output) | |
segment_button.click(segment_image, inputs=image_input, outputs=segment_output) | |
# segment_button.click(segment_full_image, inputs=image_input, outputs=segment_output) | |
with gr.Tab("Редактирование"): | |
segment_input = gr.Image() | |
upscale_output = gr.Image() | |
upscale_button = gr.Button("Upscale") | |
upscale_button.click(upscale_segment, inputs=segment_input, outputs=upscale_output) | |
mask_input = gr.Image() | |
prompt_input = gr.Textbox() | |
inpaint_output = gr.Image() | |
inpaint_button = gr.Button("Inpaint") | |
inpaint_button.click(inpaint_image, inputs=[segment_input, mask_input, prompt_input], outputs=inpaint_output) | |
with gr.Tab("Создание 3D моделей"): | |
segment_input_3d = gr.Image() | |
model_output = gr.File() | |
model_button = gr.Button("Создать 3D модель") | |
model_button.click(generate_3d_model, inputs=segment_input_3d, outputs=model_output) | |
demo.load(set_client_for_session, None, client) | |
demo.launch(debug=True, show_error=True) |