import gradio as gr
import torch
from transformers import pipeline  # only used by the commented-out classification pipeline
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
from huggingface_hub import InferenceClient  # only used by the commented-out upscaling/inpainting clients
from PIL import Image, ImageDraw
from gradio_client import Client, handle_file
import numpy as np
import cv2  # only used by the commented-out merge implementation

# Model initialization
device = "cuda" if torch.cuda.is_available() else "cpu"

# OneFormer panoptic segmentation (ADE20K, Swin-Tiny backbone)
oneFormer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
oneFormer_model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny").to(device)
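# The ADE20K checkpoint ships roughly 150 semantic classes; segment_image() below
# resolves each segment's label via oneFormer_model.config.id2label. A quick,
# commented-out sanity check (a minimal sketch, safe to run once the model has loaded):
# print(len(oneFormer_model.config.id2label))  # expected: 150 for ADE20K
# print(oneFormer_model.config.id2label[0])    # expected: "wall"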
# classification = pipeline("image-classification", model="google/vit-base-patch16-224")
# upscaling_client = InferenceClient(model="stabilityai/stable-diffusion-x4-upscaler")
# inpainting_client = InferenceClient(model="stabilityai/stable-diffusion-inpainting")

# Image processing helpers
def segment_image(image):
    image = Image.fromarray(image)
    inputs = oneFormer_processor(image, task_inputs=["panoptic"], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = oneFormer_model(**inputs)
    # Post-process the raw predictions into a panoptic segmentation map
    predicted_panoptic_map = oneFormer_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
    # Extract segment ids and masks
    segmentation_map = predicted_panoptic_map["segmentation"].cpu().numpy()
    segments_info = predicted_panoptic_map["segments_info"]
    # Create one cropped RGBA image per segment (transparent outside the mask)
    cropped_masks_with_labels = []
    label_counts = {}
    for segment in segments_info:
        mask = (segmentation_map == segment["id"]).astype(np.uint8) * 255
        # cropped_image = cv2.bitwise_and(np.array(image), np.array(image), mask=mask)
        cropped_image = np.zeros((image.height, image.width, 4), dtype=np.uint8)
        cropped_image[mask != 0, :3] = np.array(image)[mask != 0]
        cropped_image[mask != 0, 3] = 255
        label = oneFormer_model.config.id2label[segment["label_id"]]
        # Disambiguate repeated labels by appending _0, _1, etc.
        label_counts[label] = label_counts.get(label, 0) + 1
        label = f"{label}_{label_counts[label] - 1}"
        cropped_masks_with_labels.append((cropped_image, label))
    return cropped_masks_with_labels
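# segment_image() returns a list of (RGBA ndarray, "label_N") tuples, which gr.Gallery
# accepts directly as (image, caption) pairs. A minimal standalone sketch, assuming
# "room.jpg" is a local test photo (hypothetical file name):
# from PIL import Image as _Image
# segments = segment_image(np.array(_Image.open("room.jpg").convert("RGB")))
# for seg_img, seg_label in segments:
#     print(seg_label, seg_img.shape)  # e.g. "wall_0 (H, W, 4)"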
# def merge_segments_by_labels(gallery_images, labels_input):
#     """
#     Merges segments from the image gallery into a single image,
#     based on the labels entered by the user.
#     Args:
#         gallery_images: List of segment images as (image, label) tuples.
#         labels_input: String of labels separated by semicolons.
#     Returns:
#         List of images where the selected segments are merged into one.
#     """
#     labels_to_merge = [label.strip() for label in labels_input.split(";")]
#     merged_image = None
#     merged_indices = []
#     for i, (image_path, label) in enumerate(gallery_images):
#         if label in labels_to_merge:
#             image = cv2.imread(image_path)
#             if merged_image is None:
#                 merged_image = image.copy()
#             else:
#                 merged_image = cv2.add(merged_image, image)
#             merged_indices.append(i)
#     if merged_image is not None:
#         new_gallery_images = [
#             item for i, item in enumerate(gallery_images) if i not in merged_indices
#         ]
#         new_name = labels_to_merge[0]
#         new_gallery_images.append((merged_image, new_name))
#         return new_gallery_images
#     else:
#         return gallery_images
def merge_segments_by_labels(gallery_images, labels_input):
    labels_to_merge = [label.strip() for label in labels_input.split(";")]
    merged_image = None
    merged_indices = []
    for i, (image_path, label) in enumerate(gallery_images):  # fixed: iterate as (image_path, label)
        if label in labels_to_merge:
            # Load the image with PIL, preserving the alpha channel
            image = Image.open(image_path).convert("RGBA")
            if merged_image is None:
                merged_image = image.copy()
            else:
                # Composite the images, respecting the alpha channel
                merged_image = Image.alpha_composite(merged_image, image)
            merged_indices.append(i)
    if merged_image is not None:
        # Convert the merged image back to a numpy array for the gallery
        merged_image_np = np.array(merged_image)
        new_gallery_images = [
            item for i, item in enumerate(gallery_images) if i not in merged_indices
        ]
        new_name = labels_to_merge[0]
        new_gallery_images.append((merged_image_np, new_name))
        return new_gallery_images
    else:
        return gallery_images
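# Note: when a gr.Gallery value round-trips through the browser, each item typically
# arrives back as a (file_path, caption) tuple, which is why the merge above reopens
# segments from disk instead of reusing the ndarrays returned by segment_image().
# The compositing step in isolation (a sketch with hypothetical file names):
# a = Image.open("wall_0.png").convert("RGBA")
# b = Image.open("tv_0.png").convert("RGBA")
# merged = Image.alpha_composite(a, b)  # b is drawn over a wherever b's alpha > 0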
# def set_client_for_session(request: gr.Request):
#     x_ip_token = request.headers['x-ip-token']
#     return Client("JeffreyXiang/TRELLIS", headers={"X-IP-Token": x_ip_token})
def set_hunyuan_client(request: gr.Request):
    try:
        x_ip_token = request.headers['x-ip-token']
        return Client("tencent/Hunyuan3D-2", headers={"X-IP-Token": x_ip_token})
    except Exception:
        # Fall back to an anonymous client when no x-ip-token header is present
        return Client("tencent/Hunyuan3D-2")
def set_vFusion_client(request: gr.Request):
    try:
        x_ip_token = request.headers['x-ip-token']
        return Client("facebook/VFusion3D", headers={"X-IP-Token": x_ip_token})
    except Exception:
        return Client("facebook/VFusion3D")
# def generate_3d_model(client, segment_output, segment_name):
#     for i, (image_path, label) in enumerate(segment_output):
#         if label == segment_name:
#             result = client.predict(
#                 image=handle_file(image_path),
#                 multiimages=[],
#                 seed=0,
#                 ss_guidance_strength=7.5,
#                 ss_sampling_steps=12,
#                 slat_guidance_strength=3,
#                 slat_sampling_steps=12,
#                 multiimage_algo="stochastic",
#                 api_name="/image_to_3d"
#             )
#             break
#     print(result)
#     return result["video"]
def generate_3d_model(client, segment_output, segment_name):
    for i, (image_path, label) in enumerate(segment_output):
        if label == segment_name:
            result = client.predict(
                caption="",
                image=handle_file(image_path),
                steps=50,
                guidance_scale=5.5,
                seed=1234,
                octree_resolution="256",
                check_box_rembg=True,
                api_name="/shape_generation"
            )
            print(result)
            return (result[0], result[0])
    # No matching segment: surface an error in the UI instead of a NameError
    raise gr.Error(f"Segment '{segment_name}' not found in the gallery")
def generate_3d_model_texture(client, segment_output, segment_name):
    for i, (image_path, label) in enumerate(segment_output):
        if label == segment_name:
            result = client.predict(
                caption="",
                image=handle_file(image_path),
                steps=50,
                guidance_scale=5.5,
                seed=1234,
                octree_resolution="256",
                check_box_rembg=True,
                api_name="/generation_all"
            )
            print(result)
            return (result[1], result[1])
    raise gr.Error(f"Segment '{segment_name}' not found in the gallery")
def generate_3d_model2(client, segment_output, segment_name):
    for i, (image_path, label) in enumerate(segment_output):
        if label == segment_name:
            result = client.predict(
                image=handle_file(image_path),
                api_name="/step_1_generate_obj"
            )
            print(result)
            return (result[0], result[0])
    raise gr.Error(f"Segment '{segment_name}' not found in the gallery")
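# The wrappers above assume the remote Spaces' return shapes: Hunyuan3D-2's
# /shape_generation and /generation_all return tuples whose first / second element is a
# mesh file path, and VFusion3D's /step_1_generate_obj returns the OBJ path first.
# Each path is duplicated because the UI binds one result to two gr.Model3D viewers.
# A sketch of calling one wrapper outside the UI (hypothetical inputs):
# client = Client("tencent/Hunyuan3D-2")
# mesh_path, _ = generate_3d_model(client, [("sofa_0.png", "sofa_0")], "sofa_0")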
# def classify_segments(segments):
#     # Assumes segments is a list of segment images
#     results = []
#     for segment in segments:
#         results.append(classification(segment))
#     return results  # Return the list of classifications
# def upscale_segment(segment):
#     upscaled = upscaling_client.image_to_image(segment)
#     return upscaled
# def inpaint_image(image, mask, prompt):
#     inpainted = inpainting_client.text_to_image(prompt, image=image, mask=mask)
#     return inpainted
# Only needed for the commented-out LitModel3D viewer below
from gradio_litmodel3d import LitModel3D
with gr.Blocks() as demo:
    hunyuan_client = gr.State()
    vFusion_client = gr.State()
    gr.Markdown("# Room analysis and editing")
    with gr.Tab("Scanning"):
        with gr.Row():
            with gr.Column(scale=5):
                image_input = gr.Image()
                segment_button = gr.Button("Segment")
            with gr.Column(scale=5):
                segment_output = gr.Gallery()
                merge_segments_input = gr.Textbox(label="Segments to merge (semicolon-separated, e.g. \"wall_0; tv_0\")")
                merge_segments_button = gr.Button("Merge segments")
                merge_segments_button.click(merge_segments_by_labels, inputs=[segment_output, merge_segments_input], outputs=segment_output)
        with gr.Row():
            with gr.Column(scale=5):
                trellis_input = gr.Textbox(label="Segment name for 3D")
                hunyuan_button = gr.Button("Hunyuan3D-2")
                hunyuan_button_texture = gr.Button("Hunyuan3D-2 (with texture)")
                vFusion_button = gr.Button("VFusion3D")
            with gr.Column(scale=5):
                # trellis_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
                # trellis_output2 = LitModel3D(
                #     clear_color=[0.1, 0.1, 0.1, 0],  # background color, adjustable for contrast
                #     label="3D Model Visualization",
                #     scale=1.0,
                #     tonemapping="aces",  # ACES tonemapping for more realistic lighting
                #     exposure=1.0,  # exposure controls brightness
                #     contrast=1.1,  # slightly increased contrast for better depth
                #     camera_position=(0, 0, 2),  # initial camera position centered on the model
                #     zoom_speed=0.5,  # zoom speed tuned for finer control
                #     pan_speed=0.5,  # pan speed tuned for finer control
                #     interactive=True  # allow users to interact with the model
                # )
                trellis_output = gr.Model3D(label="3D Model")
                trellis_output2 = gr.Model3D(clear_color=[0.0, 0.0, 0.0, 0.0], label="3D Model Wireframe")
        # trellis_button.click(generate_3d_model, inputs=[client, segment_output, trellis_input], outputs=trellis_output)
        hunyuan_button.click(generate_3d_model, inputs=[hunyuan_client, segment_output, trellis_input], outputs=[trellis_output, trellis_output2])
        hunyuan_button_texture.click(generate_3d_model_texture, inputs=[hunyuan_client, segment_output, trellis_input], outputs=[trellis_output, trellis_output2])
        vFusion_button.click(generate_3d_model2, inputs=[vFusion_client, segment_output, trellis_input], outputs=[trellis_output, trellis_output2])
        segment_button.click(segment_image, inputs=image_input, outputs=segment_output)
        # segment_button.click(segment_full_image, inputs=image_input, outputs=segment_output)
    # with gr.Tab("Editing"):
    #     segment_input = gr.Image()
    #     upscale_output = gr.Image()
    #     upscale_button = gr.Button("Upscale")
    #     upscale_button.click(upscale_segment, inputs=segment_input, outputs=upscale_output)
    #     mask_input = gr.Image()
    #     prompt_input = gr.Textbox()
    #     inpaint_output = gr.Image()
    #     inpaint_button = gr.Button("Inpaint")
    #     inpaint_button.click(inpaint_image, inputs=[segment_input, mask_input, prompt_input], outputs=inpaint_output)
    # with gr.Tab("Creating 3D models"):
    #     segment_input_3d = gr.Image()
    #     model_output = gr.File()
    #     model_button = gr.Button("Create 3D model")
    #     model_button.click(generate_3d_model, inputs=segment_input_3d, outputs=model_output)
    demo.load(set_hunyuan_client, None, hunyuan_client)
    demo.load(set_vFusion_client, None, vFusion_client)
demo.launch(debug=True, show_error=True)