# BGE-M3 embedding service: FastAPI backend with a Gradio demo UI,
# returning responses in an OpenAI-compatible embeddings format.
import asyncio
import time
from typing import Any, Dict, List

import gradio as gr
import numpy as np
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoModel, AutoTokenizer
# The FastAPI application that hosts the embedding API.
app = FastAPI()

# Open up cross-origin access so browser clients can call the API directly.
# NOTE(review): browsers reject allow_credentials=True together with the
# wildcard origin "*" per the CORS spec — confirm whether credentials are
# actually needed here.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
)
# Load the BGE-M3 encoder and its tokenizer once at module import time.
# The first run downloads the weights from the Hugging Face Hub.
model_name = "BAAI/bge-m3"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(model_name)
# Inference only: switch off dropout / other train-time behavior.
model.eval()
# Request body mirroring the OpenAI embeddings API.
class EmbeddingRequest(BaseModel):
    """Request payload in the OpenAI /v1/embeddings shape."""

    # A single text or a batch of texts to embed.
    input: List[str] | str
    # Model identifier echoed back in the response; defaults to the local model.
    model: str | None = model_name
    # Accepted for OpenAI compatibility; only "float" output is produced here.
    encoding_format: str | None = "float"
    # Accepted for OpenAI compatibility; not used by this service.
    user: str | None = None
# Response body mirroring the OpenAI embeddings API.
class EmbeddingResponse(BaseModel):
    """Response envelope in the OpenAI /v1/embeddings shape."""

    # Always "list" in the OpenAI format.
    object: str = "list"
    # One {"object": "embedding", "embedding": [...], "index": i} entry per input.
    data: List[Dict[str, Any]]
    # The model name that produced the embeddings.
    model: str
    # Token accounting: {"prompt_tokens": n, "total_tokens": n}.
    usage: Dict[str, int]
def get_embedding(text: str) -> List[float]:
    """Encode *text* with BGE-M3 and return its [CLS] vector as a float list.

    The text is truncated to 512 tokens. The pooled representation is the
    first (CLS) position of the last hidden state.
    NOTE(review): the vector is returned un-normalized — confirm whether
    downstream cosine-similarity consumers expect L2-normalized embeddings.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )
    # No gradients needed for pure inference.
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    # Batch 0, CLS position, full hidden dimension.
    cls_vector = hidden[0, 0, :]
    return cls_vector.numpy().tolist()
# OpenAI-compatible embeddings endpoint.
# FIX: the original never registered this handler on the app, so the
# "endpoint" was unreachable; @app.post exposes it at the standard path.
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse:
    """Embed one or more texts and return an OpenAI-format response.

    ``request.input`` may be a single string or a list of strings. The usage
    block reports token counts computed with the model tokenizer (special
    tokens included).
    """
    # Normalize the input to a list of texts.
    if isinstance(request.input, str):
        input_texts = [request.input]
    else:
        input_texts = request.input

    embeddings: List[Dict[str, Any]] = []
    total_tokens = 0
    for index, text in enumerate(input_texts):
        # Token accounting for the usage summary.
        total_tokens += len(tokenizer.encode(text))
        embeddings.append({
            "object": "embedding",
            "embedding": get_embedding(text),
            "index": index,
        })

    return EmbeddingResponse(
        data=embeddings,
        model=request.model or model_name,
        usage={
            "prompt_tokens": total_tokens,
            "total_tokens": total_tokens,
        },
    )
# Gradio handler bridging the UI to the API logic.
def gradio_embedding(text: str) -> Dict:
    """Embed *text* via the API handler and return the OpenAI-style payload.

    FIX: ``create_embeddings`` is a coroutine; the original called it without
    awaiting, so ``response`` was a coroutine object and ``.dict()`` raised
    AttributeError. Gradio runs sync handlers in a worker thread with no
    running event loop, so ``asyncio.run`` is safe here.
    """
    request = EmbeddingRequest(input=text)
    response = asyncio.run(create_embeddings(request))
    # NOTE(review): .dict() is the pydantic v1 spelling; on pydantic v2 this
    # emits a deprecation warning — switch to model_dump() when v2-only.
    return response.dict()
# Interactive Gradio front-end: type a text, see its embedding as JSON.
demo = gr.Interface(
    fn=gradio_embedding,
    title="BGE-M3 Embeddings (OpenAI 兼容格式)",
    description="输入文本,获取其对应的嵌入向量,返回格式与 OpenAI API 兼容。",
    inputs=gr.Textbox(lines=3, placeholder="输入要进行编码的文本..."),
    outputs=gr.Json(),
    examples=[
        ["这是一个示例文本。"],
        ["人工智能正在改变世界。"],
    ],
)
# Entry point: serve the API and the Gradio UI from a single uvicorn process.
if __name__ == "__main__":
    import uvicorn

    # Enable Gradio's request queue before serving.
    demo.queue()
    # FIX: the original queued the demo but never attached it to the FastAPI
    # app, so uvicorn served only the bare API and the UI was never reachable.
    app = gr.mount_gradio_app(app, demo, path="/")
    config = uvicorn.Config(
        app=app,
        host="0.0.0.0",
        port=7860,
        log_level="info",
    )
    uvicorn.Server(config).run()