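"""Hugging Face Space: segment a room photo with OneFormer, clean up a chosen
segment in an image editor, and turn it into a 3D model via the public
Hunyuan3D-2 / VFusion3D Spaces."""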
import gradio as gr
import torch
from transformers import pipeline
from huggingface_hub import InferenceClient
from PIL import Image, ImageDraw
from gradio_client import Client, handle_file
import numpy as np
import cv2
import os
import tempfile
import io
import base64
import requests
# Model initialization
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation

device = "cuda" if torch.cuda.is_available() else "cpu"

# OneFormer panoptic segmentation (ADE20K, Swin-Tiny backbone)
oneFormer_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny")
oneFormer_model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_ade20k_swin_tiny").to(device)

# classification = pipeline("image-classification", model="google/vit-base-patch16-224")
# upscaling_client = InferenceClient(model="stabilityai/stable-diffusion-x4-upscaler")
# inpainting_client = InferenceClient(model="stabilityai/stable-diffusion-inpainting")
# Image-processing functions
def segment_image(image):
    # Move the processed inputs to the same device as the model
    inputs = oneFormer_processor(image, task_inputs=["panoptic"], return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = oneFormer_model(**inputs)
    # Post-process the raw predictions
    predicted_panoptic_map = oneFormer_processor.post_process_panoptic_segmentation(
        outputs, target_sizes=[image.size[::-1]]
    )[0]
    # Extract segment ids and masks
    segmentation_map = predicted_panoptic_map["segmentation"].cpu().numpy()
    segments_info = predicted_panoptic_map["segments_info"]
    # Build one RGBA cutout per segment, labeled "wall_0", "wall_1", etc.
    cropped_masks_with_labels = []
    label_counts = {}
    for segment in segments_info:
        mask = (segmentation_map == segment["id"]).astype(np.uint8) * 255
        # cropped_image = cv2.bitwise_and(np.array(image), np.array(image), mask=mask)
        cropped_image = np.zeros((image.height, image.width, 4), dtype=np.uint8)
        cropped_image[mask != 0, :3] = np.array(image)[mask != 0]
        cropped_image[mask != 0, 3] = 255
        label = oneFormer_model.config.id2label[segment["label_id"]]
        label_counts[label] = label_counts.get(label, 0) + 1
        label = f"{label}_{label_counts[label] - 1}"  # Append _0, _1, etc.
        cropped_masks_with_labels.append((cropped_image, label))
    return cropped_masks_with_labels
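# Quick local sanity check (hypothetical, not wired into the UI; "room.png" is
# a placeholder path):
#
# if __name__ == "__main__":
#     for _, label in segment_image(Image.open("room.png").convert("RGB")):
#         print(label)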
def merge_segments_by_labels(gallery_images, labels_input):
    labels_to_merge = [label.strip() for label in labels_input.split(";")]
    merged_image = None
    merged_indices = []
    # As an input component, the Gallery yields (file path, caption) pairs
    for i, (image_path, label) in enumerate(gallery_images):
        if label in labels_to_merge:
            # Load with PIL, keeping the alpha channel
            image = Image.open(image_path).convert("RGBA")
            if merged_image is None:
                merged_image = image.copy()
            else:
                # Composite the images, respecting the alpha channel
                merged_image = Image.alpha_composite(merged_image, image)
            merged_indices.append(i)
    if merged_image is not None:
        # Convert the merged image to a numpy array for the Gallery
        merged_image_np = np.array(merged_image)
        new_gallery_images = [
            item for i, item in enumerate(gallery_images) if i not in merged_indices
        ]
        new_name = labels_to_merge[0]
        new_gallery_images.append((merged_image_np, new_name))
        return new_gallery_images
    else:
        return gallery_images
def select_segment(segment_output, segment_name):
    for image_path, label in segment_output:
        if label == segment_name:
            return image_path

# Image editing
def return_image(imageEditor):
    return imageEditor['composite']
def rembg_client(request: gr.Request):
    # Forward the caller's X-IP-Token header when available; fall back to an anonymous client
    try:
        client = Client("KenjieDec/RemBG", headers={"X-IP-Token": request.headers['x-ip-token']})
        print("KenjieDec/RemBG: IP token")
        return client
    except Exception:
        print("KenjieDec/RemBG: no token")
        return Client("KenjieDec/RemBG")
def autocrop_image(imageEditor, border=0):
    image = imageEditor['composite']
    # getbbox() returns None for a fully transparent image
    bbox = image.getbbox()
    if bbox is not None:
        image = image.crop(bbox)
    (width, height) = image.size
    width += border * 2
    height += border * 2
    cropped_image = Image.new("RGBA", (width, height), (0, 0, 0, 0))
    cropped_image.paste(image, (border, border))
    return cropped_image
def remove_black_make_transparent(imageEditor):
    image_pil = imageEditor['composite']
    if image_pil.mode != "RGBA":
        image_pil = image_pil.convert("RGBA")
    image_np = np.array(image_pil)
    # Zero out the alpha channel wherever the pixel is pure black
    black_pixels_mask = np.all(image_np[:, :, :3] == [0, 0, 0], axis=-1)
    image_np[black_pixels_mask, 3] = 0
    transparent_image = Image.fromarray(image_np)
    return transparent_image
def rembg(imageEditor, request: gr.Request):
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        imageEditor['composite'].save(temp_file.name)
        temp_file_path = temp_file.name
    client = rembg_client(request)
    result = client.predict(
        file=handle_file(temp_file_path),
        mask="Default",
        model="birefnet-general-lite",
        x=0,
        y=0,
        api_name="/inference"
    )
    print(result)
    return result
def add_transparent_border(imageEditor, border_size=200):
    image = imageEditor['composite']
    width, height = image.size
    new_width = width + 2 * border_size
    new_height = height + 2 * border_size
    new_image = Image.new("RGBA", (new_width, new_height), (0, 0, 0, 0))
    new_image.paste(image, (border_size, border_size))
    return new_image
def upscale(imageEditor, scale, request: gr.Request):
    return upscale_image(imageEditor['composite'], version="v1.4", rescaling_factor=scale)

def upscale_image(image_pil, version="v1.4", rescaling_factor=None):
    # Encode the image as a base64 PNG
    buffered = io.BytesIO()
    image_pil.save(buffered, format="PNG")
    img_str = base64.b64encode(buffered.getvalue()).decode()
    data = {"data": [f"data:image/png;base64,{img_str}", version, rescaling_factor]}
    # Call the GFPGAN upscaling Space through its legacy REST endpoint
    response = requests.post("https://nightfury-image-face-upscale-restoration-gfpgan.hf.space/api/predict", json=data)
    response.raise_for_status()
    # Strip the "data:image/png;base64," prefix from the returned image
    base64_data = response.json()["data"][0]
    base64_data = base64_data.split(",")[1]
    # Convert base64 back to a PIL Image
    image_bytes = base64.b64decode(base64_data)
    upscaled_image = Image.open(io.BytesIO(image_bytes))
    return upscaled_image
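# The /api/predict route above is Gradio's pre-4.x REST API and may disappear
# if the upstream Space upgrades. A sketch of the same call via gradio_client
# instead; the Space name is inferred from the subdomain above, and the
# argument order and api_name are assumptions, not verified:
#
# def upscale_image_via_client(image_path, version="v1.4", rescaling_factor=2):
#     client = Client("nightfury/Image_Face_Upscale_Restoration-GFPGAN")
#     return client.predict(handle_file(image_path), version, rescaling_factor,
#                           api_name="/predict")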
# 3D model Space clients
def hunyuan_client(request: gr.Request):
    try:
        client = Client("tencent/Hunyuan3D-2", headers={"X-IP-Token": request.headers['x-ip-token']})
        print("tencent/Hunyuan3D-2: IP token")
        return client
    except Exception:
        print("tencent/Hunyuan3D-2: no token")
        return Client("tencent/Hunyuan3D-2")

def vFusion_client(request: gr.Request):
    try:
        client = Client("facebook/VFusion3D", headers={"X-IP-Token": request.headers['x-ip-token']})
        print("facebook/VFusion3D: IP token")
        return client
    except Exception:
        print("facebook/VFusion3D: no token")
        return Client("facebook/VFusion3D")
def generate_3d_model(image_pil, rembg_Hunyuan, request: gr.Request):
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        image_pil.save(temp_file.name)
        temp_file_path = temp_file.name
    client = hunyuan_client(request)
    result = client.predict(
        caption="",
        image=handle_file(temp_file_path),
        steps=50,
        guidance_scale=5.5,
        seed=1234,
        octree_resolution="256",
        check_box_rembg=rembg_Hunyuan,
        api_name="/shape_generation"
    )
    print(result)
    return result[0]  # untextured mesh file

def generate_3d_model_texture(image_pil, rembg_Hunyuan, request: gr.Request):
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        image_pil.save(temp_file.name)
        temp_file_path = temp_file.name
    client = hunyuan_client(request)
    result = client.predict(
        caption="",
        image=handle_file(temp_file_path),
        steps=50,
        guidance_scale=5.5,
        seed=1234,
        octree_resolution="256",
        check_box_rembg=rembg_Hunyuan,
        api_name="/generation_all"
    )
    print(result)
    return result[1]  # textured mesh file

def generate_3d_model2(image_pil, request: gr.Request):
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as temp_file:
        image_pil.save(temp_file.name)
        temp_file_path = temp_file.name
    client = vFusion_client(request)
    result = client.predict(
        image=handle_file(temp_file_path),
        api_name="/step_1_generate_obj"
    )
    print(result)
    return result[0]
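# The NamedTemporaryFile(delete=False) pattern above never deletes the PNGs it
# writes. A minimal cleanup helper (a sketch, not wired into the handlers):
import contextlib

@contextlib.contextmanager
def temp_png(image_pil):
    """Save a PIL image to a temporary PNG and remove it when done."""
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as f:
        image_pil.save(f.name)
    try:
        yield f.name
    finally:
        os.remove(f.name)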
########## GRADIO ##########
with gr.Blocks() as demo:
    gr.Markdown("# Room analysis and editing")
    with gr.Tab("Scanning"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                image_input = gr.Image(type="pil", label="Source image", height=400)
                segment_button = gr.Button("Segment")
            with gr.Column(scale=5):
                segments_output = gr.Gallery(label="Image segments")
                merge_segments_input = gr.Textbox(label="Segments to merge (semicolon-separated, e.g. \"wall_0; tv_0\")")
                merge_segments_button = gr.Button("Merge segments")
                merge_segments_button.click(merge_segments_by_labels, inputs=[segments_output, merge_segments_input], outputs=segments_output)
        with gr.Row(equal_height=True):
            segment_text_input = gr.Textbox(label="Segment name for further editing")
            select_segment_button = gr.Button("Use segment")
    with gr.Tab("Editing"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                segment_input = gr.ImageEditor(type="pil", label="Segment to edit")
            with gr.Column(scale=5):
                crop_button = gr.Button("Autocrop segment")
                with gr.Row(equal_height=True):
                    upscale_slider = gr.Slider(minimum=1, maximum=5, value=2, step=0.1, label="scale factor")
                    upscale_button = gr.Button("Upscale")
                rembg_button = gr.Button("Rembg")
                remove_background_button = gr.Button("Remove black background")
                with gr.Row(equal_height=True):
                    add_transparent_border_slider = gr.Slider(minimum=10, maximum=500, value=200, step=10, label="in pixels")
                    add_transparent_border_button = gr.Button("Add transparent border")
                use_button = gr.Button("Use segment for 3D")
    with gr.Tab("3D generation"):
        with gr.Row(equal_height=True):
            with gr.Column(scale=5):
                segment_3d_input = gr.Image(type="pil", image_mode="RGBA", label="Segment for 3D", height=600)
                rembg_Hunyuan = gr.Checkbox(label="Hunyuan3D-2 rembg Enabled", info="Enable rembg for Hunyuan3D-2?")
                hunyuan_button = gr.Button("Hunyuan3D-2 (no texture) [ZeroGPU = 100s]")
                hunyuan_button_texture = gr.Button("Hunyuan3D-2 (with texture) [ZeroGPU = 150s]")
                vFusion_button = gr.Button("VFusion3D [if you are completely out of ZeroGPU quota]")
            with gr.Column(scale=5):
                trellis_output = gr.Model3D(label="3D Model")

    # Tab 1: scanning
    segment_button.click(segment_image, inputs=image_input, outputs=segments_output)
    select_segment_button.click(select_segment, inputs=[segments_output, segment_text_input], outputs=segment_input)
    # Tab 2: editing
    crop_button.click(autocrop_image, inputs=segment_input, outputs=segment_input)
    upscale_button.click(upscale, inputs=[segment_input, upscale_slider], outputs=segment_input)
    rembg_button.click(rembg, inputs=segment_input, outputs=segment_input)
    remove_background_button.click(remove_black_make_transparent, inputs=segment_input, outputs=segment_input)
    add_transparent_border_button.click(add_transparent_border, inputs=[segment_input, add_transparent_border_slider], outputs=segment_input)
    use_button.click(return_image, inputs=segment_input, outputs=segment_3d_input)
    # Tab 3: 3D generation
    hunyuan_button.click(generate_3d_model, inputs=[segment_3d_input, rembg_Hunyuan], outputs=trellis_output)
    hunyuan_button_texture.click(generate_3d_model_texture, inputs=[segment_3d_input, rembg_Hunyuan], outputs=trellis_output)
    vFusion_button.click(generate_3d_model2, inputs=segment_3d_input, outputs=trellis_output)

demo.launch(debug=True, show_error=True)