import gradio as gr
import torch
from transformers import pipeline
from huggingface_hub import InferenceClient
from PIL import Image, ImageDraw
from gradio_client import Client, handle_file
import numpy as np
import cv2
import os
# Initialize models
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
device = "cuda" if torch.cuda.is_available() else "cpu"
# OneFormer panoptic segmentation
oneFormer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
oneFormer_model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny").to(device)
# classification = pipeline("image-classification", model="google/vit-base-patch16-224")
# upscaling_client = InferenceClient(model="stabilityai/stable-diffusion-x4-upscaler")
# inpainting_client = InferenceClient(model="stabilityai/stable-diffusion-inpainting")
# Image processing functions
def segment_image(image):
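    """Run OneFormer panoptic segmentation on an image and return a list of
    (RGBA crop, unique label) pairs, one per detected segment."""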
image = Image.fromarray(image)
    # Move the processor outputs to the same device as the model
    inputs = oneFormer_processor(image, task_inputs=["panoptic"], return_tensors="pt").to(device)
with torch.no_grad():
outputs = oneFormer_model(**inputs)
# post-process the raw predictions
predicted_panoptic_map = oneFormer_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
# Extract segment ids and masks
segmentation_map = predicted_panoptic_map["segmentation"].cpu().numpy()
segments_info = predicted_panoptic_map["segments_info"]
# Create cropped masks
cropped_masks_with_labels = []
label_counts = {}
for segment in segments_info:
mask = (segmentation_map == segment["id"]).astype(np.uint8) * 255
# cropped_image = cv2.bitwise_and(np.array(image), np.array(image), mask=mask)
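        # Build an RGBA crop: copy RGB pixels where the mask is set and leave
        # everything else fully transparent.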
cropped_image = np.zeros((image.height, image.width, 4), dtype=np.uint8)
cropped_image[mask != 0, :3] = np.array(image)[mask != 0]
cropped_image[mask != 0, 3] = 255
label = oneFormer_model.config.id2label[segment["label_id"]]
# Check if label already exists
if label in label_counts:
label_counts[label] += 1
else:
label_counts[label] = 1
label = f"{label}_{label_counts[label] - 1}" # Append _0, _1, etc.
cropped_masks_with_labels.append((cropped_image, label))
return cropped_masks_with_labels
def merge_segments_by_labels(gallery_images, labels_input):
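    """Alpha-composite all gallery segments whose labels appear in the
    semicolon-separated labels_input into a single segment named after the
    first listed label; segments not listed are kept unchanged."""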
labels_to_merge = [label.strip() for label in labels_input.split(";")]
merged_image = None
merged_indices = []
    for i, (image_path, label) in enumerate(gallery_images):  # gallery items are (image_path, label) tuples
if label in labels_to_merge:
            # Load the image with PIL, preserving the alpha channel
image = Image.open(image_path).convert("RGBA")
if merged_image is None:
merged_image = image.copy()
else:
                # Composite onto the already-merged segments, respecting transparency
merged_image = Image.alpha_composite(merged_image, image)
merged_indices.append(i)
if merged_image is not None:
        # Convert the merged image to a numpy array for the gallery
merged_image_np = np.array(merged_image)
new_gallery_images = [
item for i, item in enumerate(gallery_images) if i not in merged_indices
]
new_name = labels_to_merge[0]
new_gallery_images.append((merged_image_np, new_name))
return new_gallery_images
else:
return gallery_images
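# Forwarding the visitor's "x-ip-token" header to a downstream Space lets that
# Space attribute GPU usage to the visiting user (the ZeroGPU quota pattern);
# when the header is absent, a shared anonymous client is used instead.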
def set_hunyuan_client(request: gr.Request):
try:
        x_ip_token = request.headers['x-ip-token']
        print("tencent/Hunyuan3D-2: using per-user IP token")
        return Client("tencent/Hunyuan3D-2", headers={"X-IP-Token": x_ip_token})
    except Exception:
        print("tencent/Hunyuan3D-2: no IP token available")
        return Client("tencent/Hunyuan3D-2")
def set_vFusion_client(request: gr.Request):
try:
        x_ip_token = request.headers['x-ip-token']
        print("facebook/VFusion3D: using per-user IP token")
        return Client("facebook/VFusion3D", headers={"X-IP-Token": x_ip_token})
    except Exception:
        print("facebook/VFusion3D: no IP token available")
        return Client("facebook/VFusion3D")
def generate_3d_model(client, segment_output, segment_name):
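    """Send the gallery segment named segment_name to Hunyuan3D-2's
    /shape_generation endpoint and return the first file of the result
    (the untextured mesh)."""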
    for image_path, label in segment_output:
if label == segment_name:
result = client.predict(
caption="",
image=handle_file(image_path),
steps=50,
guidance_scale=5.5,
seed=1234,
octree_resolution="256",
check_box_rembg=True,
api_name="/shape_generation"
)
print(result)
return result[0]
def generate_3d_model_texture(client, segment_output, segment_name):
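    """Same as generate_3d_model, but calls /generation_all, which also bakes a
    texture; the textured mesh is the second element of the API result."""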
    for image_path, label in segment_output:
if label == segment_name:
result = client.predict(
caption="",
image=handle_file(image_path),
steps=50,
guidance_scale=5.5,
seed=1234,
octree_resolution="256",
check_box_rembg=True,
api_name="/generation_all"
)
print(result)
return result[1]
def generate_3d_model2(client, segment_output, segment_name):
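    """Send the segment to VFusion3D's /step_1_generate_obj endpoint and return
    the generated OBJ mesh (first element of the API result)."""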
    for image_path, label in segment_output:
if label == segment_name:
result = client.predict(
image=handle_file(image_path),
api_name="/step_1_generate_obj"
)
print(result)
return result[0]
# def classify_segments(segments):
# # Assumes segments is a list of segment images
# results = []
# for segment in segments:
# results.append(classification(segment))
# return results # return the list of classifications
# def upscale_segment(segment):
# upscaled = upscaling_client.image_to_image(segment)
# return upscaled
# def inpaint_image(image, mask, prompt):
# inpainted = inpainting_client.text_to_image(prompt, image=image, mask=mask)
# return inpainted
with gr.Blocks() as demo:
hunyuan_client = gr.State()
vFusion_client = gr.State()
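    # Per-session API clients; populated by the demo.load() handlers below so
    # each visitor gets a client carrying their own IP token.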
    gr.Markdown("# Room analysis and editing")
    with gr.Tab("Scanning"):
with gr.Row():
with gr.Column(scale=5):
image_input = gr.Image()
                segment_button = gr.Button("Segment")
with gr.Column(scale=5):
segment_output = gr.Gallery()
                merge_segments_input = gr.Textbox(label="Segments to merge (semicolon-separated, e.g. \"wall_0; tv_0\")")
                merge_segments_button = gr.Button("Merge segments")
merge_segments_button.click(merge_segments_by_labels, inputs=[segment_output, merge_segments_input], outputs=segment_output)
with gr.Row():
with gr.Column(scale=5):
                trellis_input = gr.Textbox(label="Segment name for 3D")
hunyuan_button = gr.Button("Hunyuan3D-2")
hunyuan_button_texture = gr.Button("Hunyuan3D-2 (with texture)")
vFusion_button = gr.Button("VFusion3D")
with gr.Column(scale=5):
# trellis_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
# trellis_output2 = LitModel3D(
# clear_color=[0.1, 0.1, 0.1, 0], # can adjust background color for better contrast
# label="3D Model Visualization",
# scale=1.0,
# tonemapping="aces", # can use aces tonemapping for more realistic lighting
# exposure=1.0, # can adjust exposure to control brightness
# contrast=1.1, # can slightly increase contrast for better depth
# camera_position=(0, 0, 2), # will set initial camera position to center the model
# zoom_speed=0.5, # will adjust zoom speed for better control
# pan_speed=0.5, # will adjust pan speed for better control
# interactive=True # this allow users to interact with the model
# )
trellis_output = gr.Model3D(label="3D Model")
# trellis_button.click(generate_3d_model, inputs=[client, segment_output, trellis_input], outputs=trellis_output)
hunyuan_button.click(generate_3d_model, inputs=[hunyuan_client, segment_output, trellis_input], outputs=trellis_output)
hunyuan_button_texture.click(generate_3d_model_texture, inputs=[hunyuan_client, segment_output, trellis_input], outputs=trellis_output)
vFusion_button.click(generate_3d_model2, inputs=[vFusion_client, segment_output, trellis_input], outputs=trellis_output)
segment_button.click(segment_image, inputs=image_input, outputs=segment_output)
# segment_button.click(segment_full_image, inputs=image_input, outputs=segment_output)
    # with gr.Tab("Editing"):
# segment_input = gr.Image()
# upscale_output = gr.Image()
# upscale_button = gr.Button("Upscale")
# upscale_button.click(upscale_segment, inputs=segment_input, outputs=upscale_output)
# mask_input = gr.Image()
# prompt_input = gr.Textbox()
# inpaint_output = gr.Image()
# inpaint_button = gr.Button("Inpaint")
# inpaint_button.click(inpaint_image, inputs=[segment_input, mask_input, prompt_input], outputs=inpaint_output)
    # with gr.Tab("3D model creation"):
# segment_input_3d = gr.Image()
# model_output = gr.File()
    # model_button = gr.Button("Create 3D model")
# model_button.click(generate_3d_model, inputs=segment_input_3d, outputs=model_output)
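    # Initialize the per-session clients as soon as the page loads.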
demo.load(set_hunyuan_client, None, hunyuan_client)
demo.load(set_vFusion_client, None, vFusion_client)
demo.launch(debug=True, show_error=True)