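"""Gradio Space: room-photo analysis and editing.

Segments an interior photo into panoptic cutouts with OneFormer, lets the
user merge cutouts by label, and sends a chosen segment to the remote
tencent/Hunyuan3D-2 or facebook/VFusion3D Spaces to generate a 3D model.
"""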
import gradio as gr
import torch
from transformers import pipeline
from huggingface_hub import InferenceClient
from PIL import Image, ImageDraw
from gradio_client import Client, handle_file
import numpy as np
import cv2
import os
# Model initialization
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation

device = "cuda" if torch.cuda.is_available() else "cpu"

# OneFormer panoptic segmentation (lightweight Swin-tiny checkpoint trained on ADE20K scene parsing)
oneFormer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
oneFormer_model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny").to(device)

# classification = pipeline("image-classification", model="google/vit-base-patch16-224")
# upscaling_client = InferenceClient(model="stabilityai/stable-diffusion-x4-upscaler")
# inpainting_client = InferenceClient(model="stabilityai/stable-diffusion-inpainting")
# Image-processing functions
def segment_image(image):
    image = Image.fromarray(image)
    # Move the processed inputs to the same device as the model
    inputs = oneFormer_processor(image, task_inputs=["panoptic"], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = oneFormer_model(**inputs)
    # Post-process the raw predictions into a panoptic map at the original resolution
    predicted_panoptic_map = oneFormer_processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    # Extract segment ids and masks
    segmentation_map = predicted_panoptic_map["segmentation"].cpu().numpy()
    segments_info = predicted_panoptic_map["segments_info"]
    # Build one RGBA cutout per segment
    cropped_masks_with_labels = []
    label_counts = {}
    for segment in segments_info:
        mask = (segmentation_map == segment["id"]).astype(np.uint8) * 255
        # cropped_image = cv2.bitwise_and(np.array(image), np.array(image), mask=mask)
        cropped_image = np.zeros((image.height, image.width, 4), dtype=np.uint8)
        cropped_image[mask != 0, :3] = np.array(image)[mask != 0]
        cropped_image[mask != 0, 3] = 255
        label = oneFormer_model.config.id2label[segment["label_id"]]
        # Disambiguate repeated classes by appending _0, _1, etc.
        label_counts[label] = label_counts.get(label, 0) + 1
        label = f"{label}_{label_counts[label] - 1}"
        cropped_masks_with_labels.append((cropped_image, label))
    return cropped_masks_with_labels
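# Quick local sanity check (hypothetical file name; labels depend on the photo):
#   masks = segment_image(np.array(Image.open("room.jpg")))
#   print([label for _, label in masks])  # e.g. ["wall_0", "floor_0", "tv_0"]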
def merge_segments_by_labels(gallery_images, labels_input):
    labels_to_merge = [label.strip() for label in labels_input.split(";")]
    merged_image = None
    merged_indices = []
    for i, (image_path, label) in enumerate(gallery_images):  # Gallery items arrive as (file path, label) tuples
        if label in labels_to_merge:
            # Load the image with PIL, preserving the alpha channel
            image = Image.open(image_path).convert("RGBA")
            if merged_image is None:
                merged_image = image.copy()
            else:
                # Composite the images using their alpha channels
                merged_image = Image.alpha_composite(merged_image, image)
            merged_indices.append(i)
    if merged_image is not None:
        # Convert the merged image back to a numpy array for the gallery
        merged_image_np = np.array(merged_image)
        new_gallery_images = [
            item for i, item in enumerate(gallery_images) if i not in merged_indices
        ]
        # Name the merged segment after the first label in the list
        new_name = labels_to_merge[0]
        new_gallery_images.append((merged_image_np, new_name))
        return new_gallery_images
    else:
        return gallery_images
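# Example (hypothetical labels, matching the UI hint below): merge the "wall_0"
# and "tv_0" cutouts into a single entry named "wall_0":
#   gallery = merge_segments_by_labels(gallery, "wall_0; tv_0")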
# Per-session clients for the remote Spaces. When this app runs on Hugging Face
# Spaces, the incoming "x-ip-token" header is forwarded so that the remote
# (ZeroGPU) Spaces can attribute GPU quota to the end user; without it we fall
# back to an anonymous client.
def set_hunyuan_client(request: gr.Request):
    try:
        x_ip_token = request.headers['x-ip-token']
        client = Client("tencent/Hunyuan3D-2", headers={"X-IP-Token": x_ip_token})
        print(x_ip_token, "tencent/Hunyuan3D-2 IP token")
        return client
    except KeyError:
        print("tencent/Hunyuan3D-2: no token")
        return Client("tencent/Hunyuan3D-2")

def set_vFusion_client(request: gr.Request):
    try:
        x_ip_token = request.headers['x-ip-token']
        client = Client("facebook/VFusion3D", headers={"X-IP-Token": x_ip_token})
        print(x_ip_token, "facebook/VFusion3D IP token")
        return client
    except KeyError:
        print("facebook/VFusion3D: no token")
        return Client("facebook/VFusion3D")
# Generate an untextured mesh for the selected segment via Hunyuan3D-2
def generate_3d_model(client, segment_output, segment_name):
    for image_path, label in segment_output:
        if label == segment_name:
            result = client.predict(
                caption="",
                image=handle_file(image_path),
                steps=50,
                guidance_scale=5.5,
                seed=1234,
                octree_resolution="256",
                check_box_rembg=True,
                api_name="/shape_generation"
            )
            print(result)
            return result[0]

# Generate a textured mesh for the selected segment via Hunyuan3D-2
def generate_3d_model_texture(client, segment_output, segment_name):
    for image_path, label in segment_output:
        if label == segment_name:
            result = client.predict(
                caption="",
                image=handle_file(image_path),
                steps=50,
                guidance_scale=5.5,
                seed=1234,
                octree_resolution="256",
                check_box_rembg=True,
                api_name="/generation_all"
            )
            print(result)
            return result[1]

# Generate a mesh for the selected segment via VFusion3D
def generate_3d_model2(client, segment_output, segment_name):
    for image_path, label in segment_output:
        if label == segment_name:
            result = client.predict(
                image=handle_file(image_path),
                api_name="/step_1_generate_obj"
            )
            print(result)
            return result[0]
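# Example (hypothetical): generate a mesh for the segment labeled "tv_0",
# assuming gallery items are (file path, label) tuples as Gradio provides them:
#   client = Client("tencent/Hunyuan3D-2")
#   mesh_path = generate_3d_model(client, gallery, "tv_0")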
# def classify_segments(segments):
#     # segments is expected to be a list of segment images
#     results = []
#     for segment in segments:
#         results.append(classification(segment))
#     return results  # Return the list of classifications

# def upscale_segment(segment):
#     upscaled = upscaling_client.image_to_image(segment)
#     return upscaled

# def inpaint_image(image, mask, prompt):
#     inpainted = inpainting_client.text_to_image(prompt, image=image, mask=mask)
#     return inpainted
with gr.Blocks() as demo:
    hunyuan_client = gr.State()
    vFusion_client = gr.State()
    gr.Markdown("# Room analysis and editing")
    with gr.Tab("Scanning"):
        with gr.Row():
            with gr.Column(scale=5):
                image_input = gr.Image()
                segment_button = gr.Button("Segment")
            with gr.Column(scale=5):
                segment_output = gr.Gallery()
                merge_segments_input = gr.Textbox(label="Segments to merge (semicolon-separated, e.g. \"wall_0; tv_0\")")
                merge_segments_button = gr.Button("Merge segments")
                merge_segments_button.click(merge_segments_by_labels, inputs=[segment_output, merge_segments_input], outputs=segment_output)
        with gr.Row():
            with gr.Column(scale=5):
                trellis_input = gr.Textbox(label="Segment name for 3D")
                hunyuan_button = gr.Button("Hunyuan3D-2")
                hunyuan_button_texture = gr.Button("Hunyuan3D-2 (with texture)")
                vFusion_button = gr.Button("VFusion3D")
            with gr.Column(scale=5):
                # trellis_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
                # trellis_output2 = LitModel3D(
                #     clear_color=[0.1, 0.1, 0.1, 0],  # can adjust background color for better contrast
                #     label="3D Model Visualization",
                #     scale=1.0,
                #     tonemapping="aces",  # can use aces tonemapping for more realistic lighting
                #     exposure=1.0,  # can adjust exposure to control brightness
                #     contrast=1.1,  # can slightly increase contrast for better depth
                #     camera_position=(0, 0, 2),  # will set initial camera position to center the model
                #     zoom_speed=0.5,  # will adjust zoom speed for better control
                #     pan_speed=0.5,  # will adjust pan speed for better control
                #     interactive=True  # this allows users to interact with the model
                # )
                trellis_output = gr.Model3D(label="3D Model")
        # trellis_button.click(generate_3d_model, inputs=[client, segment_output, trellis_input], outputs=trellis_output)
        hunyuan_button.click(generate_3d_model, inputs=[hunyuan_client, segment_output, trellis_input], outputs=trellis_output)
        hunyuan_button_texture.click(generate_3d_model_texture, inputs=[hunyuan_client, segment_output, trellis_input], outputs=trellis_output)
        vFusion_button.click(generate_3d_model2, inputs=[vFusion_client, segment_output, trellis_input], outputs=trellis_output)
        segment_button.click(segment_image, inputs=image_input, outputs=segment_output)
        # segment_button.click(segment_full_image, inputs=image_input, outputs=segment_output)

    # with gr.Tab("Editing"):
    #     segment_input = gr.Image()
    #     upscale_output = gr.Image()
    #     upscale_button = gr.Button("Upscale")
    #     upscale_button.click(upscale_segment, inputs=segment_input, outputs=upscale_output)
    #     mask_input = gr.Image()
    #     prompt_input = gr.Textbox()
    #     inpaint_output = gr.Image()
    #     inpaint_button = gr.Button("Inpaint")
    #     inpaint_button.click(inpaint_image, inputs=[segment_input, mask_input, prompt_input], outputs=inpaint_output)

    # with gr.Tab("3D model creation"):
    #     segment_input_3d = gr.Image()
    #     model_output = gr.File()
    #     model_button = gr.Button("Create 3D model")
    #     model_button.click(generate_3d_model, inputs=segment_input_3d, outputs=model_output)

    # Fetch per-session clients when the page loads
    demo.load(set_hunyuan_client, None, hunyuan_client)
    demo.load(set_vFusion_client, None, vFusion_client)

demo.launch(debug=True, show_error=True)