# QijiYuntai2.0 / app.py
# Last commit by Reduxxxx: 9180c95 "Update app.py" (verified), 1.69 kB
# app.py
import functools
import logging
import os

import gradio as gr
import numpy as np
import torch
from PIL import Image
from transformers import AutoModel
@functools.lru_cache(maxsize=1)
def load_model():
    """Load the VFusion3D model once and place it on the best available device.

    The result is cached, so every later call (``process_image`` calls this
    once per request) reuses the already-loaded model instead of reading the
    weights again. Per the ``functools.lru_cache`` docs, concurrent first
    calls may each run the load before the result is cached.

    Returns:
        The model from ``AutoModel.from_pretrained``, moved to CUDA when
        ``torch.cuda.is_available()`` is true, otherwise to CPU, and switched
        to inference mode.
    """
    # trust_remote_code=True executes Python code shipped with the model
    # repository; only keep it for a repository you trust.
    model = AutoModel.from_pretrained("jadechoghari/vfusion3d", trust_remote_code=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Inference only: make sure dropout / normalization layers are in eval mode
    # (a no-op if from_pretrained already did this).
    model.eval()
    return model
def process_image(input_image):
try:
# 确保输入图像是PIL Image格式
if not isinstance(input_image, Image.Image):
input_image = Image.fromarray(input_image)
# 加载模型
model = load_model()
# 图像预处理
input_image = input_image.resize((256, 256))
# 转换为tensor
image_tensor = torch.from_numpy(np.array(input_image)).float()
image_tensor = image_tensor.permute(2, 0, 1).unsqueeze(0)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
image_tensor = image_tensor.to(device)
# 模型推理
with torch.no_grad():
output = model(image_tensor)
return output
except Exception as e:
return f"错误: {str(e)}"
# Build the Gradio UI: one image upload in, a 3D viewer plus a status line out.
image_input = gr.Image(type="pil", label="上传图片")
model_output = gr.Model3D(label="生成的3D模型")
status_output = gr.Text(label="处理状态")

demo = gr.Interface(
    fn=process_image,
    inputs=[image_input],
    outputs=[model_output, status_output],
    title="麒迹云台 - 2D转3D模型生成器",
    description="上传一张图片,AI将自动生成对应的3D模型。支持格式:jpg, png, jpeg",
    theme=gr.themes.Soft(),
)

if __name__ == "__main__":
    # Serve the app only when run as a script, not when imported.
    demo.launch()