test / app.py
SharafeevRavil's picture
Create app.py
23eb257 verified
raw
history blame
9.33 kB
!pip install gradio
import gradio as gr
import torch
from transformers import pipeline
from huggingface_hub import InferenceClient
from PIL import Image
import numpy as np
import cv2
# Инициализация моделей
# segmentation = pipeline("image-segmentation", model="nvidia/segformer-b5-finetuned-ade-640-640")
classification = pipeline("image-classification", model="google/vit-base-patch16-224")
upscaling_client = InferenceClient(model="stabilityai/stable-diffusion-x4-upscaler")
inpainting_client = InferenceClient(model="stabilityai/stable-diffusion-inpainting")
trellis_client = InferenceClient(model="microsoft/TRELLIS")
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
import torch
from PIL import Image
import numpy as np
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny").to(device)
# Функции для обработки изображений
from PIL import Image, ImageDraw
import numpy as np
def segment_image(image):
image = Image.fromarray(image)
# Изменяем task_input на "panoptic"
inputs = processor(image, task_inputs=["panoptic"], return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
# post-process the raw predictions
predicted_panoptic_map = processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
# Extract segment ids and masks
segmentation_map = predicted_panoptic_map["segmentation"].cpu().numpy()
segments_info = predicted_panoptic_map["segments_info"]
# Create cropped masks
cropped_masks_with_labels = []
label_counts = {}
for segment in segments_info:
mask = (segmentation_map == segment["id"]).astype(np.uint8) * 255
cropped_image = cv2.bitwise_and(np.array(image), np.array(image), mask=mask)
label = model.config.id2label[segment["label_id"]]
# Check if label already exists
if label in label_counts:
label_counts[label] += 1
else:
label_counts[label] = 1
label = f"{label}_{label_counts[label] - 1}" # Append _0, _1, etc.
cropped_masks_with_labels.append((cropped_image, label))
return cropped_masks_with_labels
def merge_segments_by_labels(gallery_images, labels_input):
"""
Объединяет сегменты из галереи изображений в одно изображение,
основываясь на введенных пользователем метках.
Args:
gallery_images: Список изображений сегментов (кортежи (изображение, метка)).
labels_input: Строка с метками, разделенными точкой с запятой.
Returns:
Список изображений, где выбранные сегменты объединены в одно.
"""
# 1. Разделяем входную строку с метками на список
labels_to_merge = [label.strip() for label in labels_input.split(";")]
# 2. Создаем пустое изображение для объединения
merged_image = None
# 3. Инициализируем список для хранения индексов объединенных сегментов
merged_indices = []
# 4. Проходим по всем изображениям в галерее
for i, (image_path, label) in enumerate(gallery_images):
# 5. Если метка сегмента в списке меток для объединения
if label in labels_to_merge:
image = cv2.imread(image_path)
# 6. Если это первый сегмент для объединения
if merged_image is None:
# 7. Создаем копию изображения как основу для объединения
merged_image = image.copy()
else:
# 8. Объединяем текущее изображение с merged_image
# Используем cv2.add для наложения изображений,
# предполагая, что сегменты не перекрываются
merged_image = cv2.add(merged_image, image)
# 9. Добавляем индекс объединенного сегмента в список
merged_indices.append(i)
# 10. Если сегменты были объединены
if merged_image is not None:
# 11. Создаем новый список изображений, удаляя объединенные сегменты
# и добавляя объединенное изображение с новой меткой
new_gallery_images = [
item for i, item in enumerate(gallery_images) if i not in merged_indices
]
new_name = labels_to_merge[0]
new_gallery_images.append((merged_image, new_name))
return new_gallery_images
else:
# 12. Если не было меток для объединения, возвращаем исходный список
return gallery_images
def set_client_for_session(request: gr.Request):
x_ip_token = request.headers['x-ip-token']
# The "JeffreyXiang/TRELLIS" space is a ZeroGPU space
return Client("JeffreyXiang/TRELLIS", headers={"X-IP-Token": x_ip_token})
def generate_3d_model(client, segment_output, segment_name):
for i, (image_path, label) in enumerate(segment_output):
if label == segment_name:
result = client.predict(
image=handle_file(image_path),
multiimages=[],
seed=0,
ss_guidance_strength=7.5,
ss_sampling_steps=12,
slat_guidance_strength=3,
slat_sampling_steps=12,
multiimage_algo="stochastic",
api_name="/image_to_3d"
)
break
print(result)
return result["video"]
def classify_segments(segments):
# Предполагается, что segments - список изображений сегментов
results = []
for segment in segments:
results.append(classification(segment))
return results # Вернем список классификаций
def upscale_segment(segment):
upscaled = upscaling_client.image_to_image(segment)
return upscaled
def inpaint_image(image, mask, prompt):
inpainted = inpainting_client.text_to_image(prompt, image=image, mask=mask)
return inpainted
with gr.Blocks() as demo:
client = gr.State()
gr.Markdown("# Анализ и редактирование помещений")
with gr.Tab("Сканирование"):
with gr.Row():
with gr.Column(scale=5):
image_input = gr.Image()
segment_button = gr.Button("Сегментировать")
with gr.Column(scale=5):
segment_output = gr.Gallery()
merge_segments_input = gr.Textbox(label="Сегменты для объединения (через точку с запятой, например: \"wall_0; tv_0\")")
merge_segments_button = gr.Button("Соединить сегменты")
merge_segments_button.click(merge_segments_by_labels, inputs=[segment_output, merge_segments_input], outputs=segment_output)
with gr.Row():
with gr.Column(scale=5):
trellis_input = gr.Textbox(label="Имя сегмента для 3D")
trellis_button = gr.Button("3D Trellis")
with gr.Column(scale=5):
trellis_output = gr.Video(label="Generated 3D Asset", autoplay=True, loop=True, height=300)
trellis_button.click(generate_3d_model, inputs=[client, segment_output, trellis_input], outputs=trellis_output)
segment_button.click(segment_image, inputs=image_input, outputs=segment_output)
# segment_button.click(segment_full_image, inputs=image_input, outputs=segment_output)
with gr.Tab("Редактирование"):
segment_input = gr.Image()
upscale_output = gr.Image()
upscale_button = gr.Button("Upscale")
upscale_button.click(upscale_segment, inputs=segment_input, outputs=upscale_output)
mask_input = gr.Image()
prompt_input = gr.Textbox()
inpaint_output = gr.Image()
inpaint_button = gr.Button("Inpaint")
inpaint_button.click(inpaint_image, inputs=[segment_input, mask_input, prompt_input], outputs=inpaint_output)
with gr.Tab("Создание 3D моделей"):
segment_input_3d = gr.Image()
model_output = gr.File()
model_button = gr.Button("Создать 3D модель")
model_button.click(generate_3d_model, inputs=segment_input_3d, outputs=model_output)
demo.load(set_client_for_session, None, client)
demo.launch(debug=True, show_error=True)