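"""Room-scanning demo: segment an interior photo with OneFormer, clean up a
chosen segment in an image editor, and send it to a remote Space
(Hunyuan3D-2 or VFusion3D) to generate a 3D model."""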
import gradio as gr
import torch
from transformers import pipeline
from huggingface_hub import InferenceClient
from PIL import Image, ImageDraw
from gradio_client import Client, handle_file
import numpy as np
import cv2
import os
import tempfile
import io
import base64
import requests
# Initialize models
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
device = "cuda" if torch.cuda.is_available() else "cpu"
# oneFormer segmentation
oneFormer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
oneFormer_model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny").to(device)
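# The processor and model are loaded once at startup. ADE20K is a scene-parsing
# dataset with 150 categories (wall, floor, tv, ...), which is what makes the
# per-segment labels below possible.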
# classification = pipeline("image-classification", model="google/vit-base-patch16-224")
# upscaling_client = InferenceClient(model="stabilityai/stable-diffusion-x4-upscaler")
# inpainting_client = InferenceClient(model="stabilityai/stable-diffusion-inpainting")
# Image-processing functions
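# segment_image returns a list of (RGBA ndarray, unique label) pairs, which is
# a format gr.Gallery accepts directly.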
def segment_image(image):
    inputs = oneFormer_processor(image, task_inputs=["panoptic"], return_tensors="pt")
    inputs = inputs.to(device)  # move inputs to the same device as the model
    with torch.no_grad():
        outputs = oneFormer_model(**inputs)
    # Post-process the raw predictions
    predicted_panoptic_map = oneFormer_processor.post_process_panoptic_segmentation(outputs, target_sizes=[image.size[::-1]])[0]
    # Extract segment ids and masks
    segmentation_map = predicted_panoptic_map["segmentation"].cpu().numpy()
    segments_info = predicted_panoptic_map["segments_info"]
    # Create cropped masks
    cropped_masks_with_labels = []
    label_counts = {}
    for segment in segments_info:
        mask = (segmentation_map == segment["id"]).astype(np.uint8) * 255
        # cropped_image = cv2.bitwise_and(np.array(image), np.array(image), mask=mask)
        # Keep only this segment's pixels; everything else stays transparent
        cropped_image = np.zeros((image.height, image.width, 4), dtype=np.uint8)
        cropped_image[mask != 0, :3] = np.array(image)[mask != 0]
        cropped_image[mask != 0, 3] = 255
        label = oneFormer_model.config.id2label[segment["label_id"]]
        # Count occurrences of each label
        if label in label_counts:
            label_counts[label] += 1
        else:
            label_counts[label] = 1
        label = f"{label}_{label_counts[label] - 1}"  # append _0, _1, etc.
        cropped_masks_with_labels.append((cropped_image, label))
    return cropped_masks_with_labels
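# Note: gr.Gallery hands its value back as (file path, label) pairs, so the
# merge below re-opens each segment from disk. Image.alpha_composite requires
# equal image sizes, which holds here because every segment cutout is full-frame.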
def merge_segments_by_labels(gallery_images, labels_input):
    labels_to_merge = [label.strip() for label in labels_input.split(";")]
    merged_image = None
    merged_indices = []
    for i, (image_path, label) in enumerate(gallery_images):  # gallery items are (image path, label)
        if label in labels_to_merge:
            # Load the image with PIL, preserving the alpha channel
            image = Image.open(image_path).convert("RGBA")
            if merged_image is None:
                merged_image = image.copy()
            else:
                # Composite the images, respecting the alpha channel
                merged_image = Image.alpha_composite(merged_image, image)
            merged_indices.append(i)
    if merged_image is not None:
        # Convert the merged image to a numpy array
        merged_image_np = np.array(merged_image)
        new_gallery_images = [
            item for i, item in enumerate(gallery_images) if i not in merged_indices
        ]
        new_name = labels_to_merge[0]
        new_gallery_images.append((merged_image_np, new_name))
        return new_gallery_images
    else:
        return gallery_images
def select_segment(segment_output, segment_name):
    # Return the image path of the gallery entry whose label matches, or None
    for i, (image_path, label) in enumerate(segment_output):
        if label == segment_name:
            return image_path
    return None
# Image editing helpers
def return_image(imageEditor):
    # The ImageEditor value is a dict; "composite" is the flattened result
    return imageEditor['composite']
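# Each remote-Space client tries to forward the caller's "x-ip-token" header,
# presumably so ZeroGPU quota is charged to the end user rather than to this
# Space; if the header is missing, it falls back to an anonymous client.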
def rembg_client(request: gr.Request):
    try:
        client = Client("KenjieDec/RemBG", headers={"X-IP-Token": request.headers['x-ip-token']})
        print("KenjieDec/RemBG: using IP token")
        return client
    except Exception:
        print("KenjieDec/RemBG: no IP token")
        return Client("KenjieDec/RemBG")
def autocrop_image(imageEditor, border=0):
    # Crop the composite to its non-transparent bounding box, plus an optional border
    image = imageEditor['composite']
    bbox = image.getbbox()
    image = image.crop(bbox)
    (width, height) = image.size
    width += border * 2
    height += border * 2
    cropped_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
    cropped_image.paste(image, (border, border))
    return cropped_image
def remove_black_make_transparent(imageEditor):
    # Make pure-black pixels fully transparent
    image_pil = imageEditor['composite']
    if image_pil.mode != "RGBA":
        image_pil = image_pil.convert("RGBA")
    image_np = np.array(image_pil)
    black_pixels_mask = np.all(image_np[:, :, :3] == [0, 0, 0], axis=-1)
    image_np[black_pixels_mask, 3] = 0
    transparent_image = Image.fromarray(image_np)
    return transparent_image
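# gradio_client downloads file outputs locally and returns their paths, so the
# result below can be fed straight back into the ImageEditor component.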
def rembg(imageEditor, request: gr.Request):
    # Save the composite to a temp PNG and send it to the RemBG Space
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        imageEditor['composite'].save(temp_file.name)
        temp_file_path = temp_file.name
    client = rembg_client(request)
    result = client.predict(
        file=handle_file(temp_file_path),
        mask="Default",
        model="birefnet-general-lite",
        x=0,
        y=0,
        api_name="/inference"
    )
    print(result)
    return result
def add_transparent_border(imageEditor, border_size=200):
    image = imageEditor['composite']
    width, height = image.size
    new_width = width + 2 * border_size
    new_height = height + 2 * border_size
    new_image = Image.new("RGBA", (new_width, new_height), (0, 0, 0, 0))
    new_image.paste(image, (border_size, border_size))
    return new_image
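# The upscaler below is called through the legacy Gradio REST endpoint
# (/api/predict) with a base64-encoded PNG, rather than through gradio_client.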
def upscale(imageEditor, scale, request: gr.Request):
    return upscale_image(imageEditor['composite'], version="v1.4", rescaling_factor=scale)

def upscale_image(image_pil, version="v1.4", rescaling_factor=None):
    buffered = io.BytesIO()
    image_pil.save(buffered, format="PNG")  # save as PNG
    img_str = base64.b64encode(buffered.getvalue()).decode()
    # Build the payload with a PNG data URI
    data = {"data": [f"data:image/png;base64,{img_str}", version, rescaling_factor]}
    # Send the request to the API
    response = requests.post("https://nightfury-image-face-upscale-restoration-gfpgan.hf.space/api/predict", json=data)
    response.raise_for_status()
    # Get the base64 data from the response
    base64_data = response.json()["data"][0]
    base64_data = base64_data.split(",")[1]  # strip the "data:image/png;base64," prefix
    # Convert base64 back to a PIL Image
    image_bytes = base64.b64decode(base64_data)
    upscaled_image = Image.open(io.BytesIO(image_bytes))
    return upscaled_image
# 3D model generation
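# Both 3D backends run as remote Spaces; they return mesh file paths that
# gr.Model3D can display.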
def hunyuan_client(request: gr.Request):
    try:
        client = Client("tencent/Hunyuan3D-2", headers={"X-IP-Token": request.headers['x-ip-token']})
        print("tencent/Hunyuan3D-2: using IP token")
        return client
    except Exception:
        print("tencent/Hunyuan3D-2: no IP token")
        return Client("tencent/Hunyuan3D-2")

def vFusion_client(request: gr.Request):
    try:
        client = Client("facebook/VFusion3D", headers={"X-IP-Token": request.headers['x-ip-token']})
        print("facebook/VFusion3D: using IP token")
        return client
    except Exception:
        print("facebook/VFusion3D: no IP token")
        return Client("facebook/VFusion3D")
def generate_3d_model(image_pil, rembg_Hunyuan, request: gr.Request):
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        image_pil.save(temp_file.name)
        temp_file_path = temp_file.name
    client = hunyuan_client(request)
    result = client.predict(
        caption="",
        image=handle_file(temp_file_path),
        steps=50,
        guidance_scale=5.5,
        seed=1234,
        octree_resolution="256",
        check_box_rembg=rembg_Hunyuan,
        api_name="/shape_generation"
    )
    print(result)
    return result[0]

def generate_3d_model_texture(image_pil, rembg_Hunyuan, request: gr.Request):
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        image_pil.save(temp_file.name)
        temp_file_path = temp_file.name
    client = hunyuan_client(request)
    result = client.predict(
        caption="",
        image=handle_file(temp_file_path),
        steps=50,
        guidance_scale=5.5,
        seed=1234,
        octree_resolution="256",
        check_box_rembg=rembg_Hunyuan,
        api_name="/generation_all"
    )
    print(result)
    return result[1]

def generate_3d_model2(image_pil, request: gr.Request):
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        image_pil.save(temp_file.name)
        temp_file_path = temp_file.name
    client = vFusion_client(request)
    result = client.predict(
        image=handle_file(temp_file_path),
        api_name="/step_1_generate_obj"
    )
    print(result)
    return result[0]
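# UI layout: three tabs mirroring the pipeline: scan and segment, edit a
# segment, generate a 3D model.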
########## GRADIO ##########
with gr.Blocks() as demo:
    gr.Markdown("# Room analysis and editing")
    with gr.Tab("Scanning"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                image_input = gr.Image(type="pil", label="Source image", height=400)
                segment_button = gr.Button("Segment")
            with gr.Column(scale=5):
                segments_output = gr.Gallery(label="Image segments")
                merge_segments_input = gr.Textbox(label="Segments to merge (semicolon-separated, e.g. \"wall_0; tv_0\")")
                merge_segments_button = gr.Button("Merge segments")
                merge_segments_button.click(merge_segments_by_labels, inputs=[segments_output, merge_segments_input], outputs=segments_output)
        with gr.Row(equal_height=True):
            segment_text_input = gr.Textbox(label="Segment name to use for further editing")
            select_segment_button = gr.Button("Use segment")
    with gr.Tab("Editing"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                segment_input = gr.ImageEditor(type="pil", label="Segment to edit")
            with gr.Column(scale=5):
                crop_button = gr.Button("Crop segment")
                with gr.Row(equal_height=True):
                    upscale_slider = gr.Slider(minimum=1, maximum=5, value=2, step=0.1, label="upscale factor")
                    upscale_button = gr.Button("Upscale")
                rembg_button = gr.Button("Rembg")
                remove_background_button = gr.Button("Remove black background")
                with gr.Row(equal_height=True):
                    add_transparent_border_slider = gr.Slider(minimum=10, maximum=500, value=200, step=10, label="border size in pixels")
                    add_transparent_border_button = gr.Button("Add transparent border")
                use_button = gr.Button("Use segment for 3D")
    with gr.Tab("3D creation"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                segment_3d_input = gr.Image(type="pil", image_mode="RGBA", label="Segment for 3D", height=600)
                rembg_Hunyuan = gr.Checkbox(label="Hunyuan3D-2 rembg Enabled", info="Enable rembg for Hunyuan3D-2?")
                hunyuan_button = gr.Button("Hunyuan3D-2 (no texture) [ZeroGPU = 100s]")
                hunyuan_button_texture = gr.Button("Hunyuan3D-2 (with texture) [ZeroGPU = 150s]")
                vFusion_button = gr.Button("VFusion3D [if your ZeroGPU quota is completely gone]")
            with gr.Column(scale=5):
                trellis_output = gr.Model3D(label="3D Model")
    # Tab 1: scanning
    segment_button.click(segment_image, inputs=image_input, outputs=segments_output)
    select_segment_button.click(select_segment, inputs=[segments_output, segment_text_input], outputs=segment_input)
    # Tab 2: editing
    crop_button.click(autocrop_image, inputs=segment_input, outputs=segment_input)
    upscale_button.click(upscale, inputs=[segment_input, upscale_slider], outputs=segment_input)
    rembg_button.click(rembg, inputs=segment_input, outputs=segment_input)
    remove_background_button.click(remove_black_make_transparent, inputs=segment_input, outputs=segment_input)
    add_transparent_border_button.click(add_transparent_border, inputs=[segment_input, add_transparent_border_slider], outputs=segment_input)
    use_button.click(return_image, inputs=segment_input, outputs=segment_3d_input)
    # Tab 3: 3D generation
    hunyuan_button.click(generate_3d_model, inputs=[segment_3d_input, rembg_Hunyuan], outputs=trellis_output)
    hunyuan_button_texture.click(generate_3d_model_texture, inputs=[segment_3d_input, rembg_Hunyuan], outputs=trellis_output)
    vFusion_button.click(generate_3d_model2, inputs=segment_3d_input, outputs=trellis_output)

demo.launch(debug=True, show_error=True)