Spaces:
Running
Running
sanbo
commited on
Commit
·
124ac36
1
Parent(s):
5331238
update sth. at 2025-01-16 22:25:39
Browse files
app.py
CHANGED
@@ -8,7 +8,6 @@ from pydantic import BaseModel
|
|
8 |
from typing import List, Dict
|
9 |
from functools import lru_cache
|
10 |
import uvicorn
|
11 |
-
import psutil
|
12 |
import numpy as np
|
13 |
|
14 |
class EmbeddingRequest(BaseModel):
|
@@ -23,15 +22,12 @@ class EmbeddingService:
|
|
23 |
def __init__(self):
|
24 |
self.model_name = "jinaai/jina-embeddings-v3"
|
25 |
self.max_length = 512
|
26 |
-
self.batch_size = 8
|
27 |
self.device = torch.device("cpu")
|
28 |
-
self.num_threads = min(psutil.cpu_count(), 4) # 限制CPU线程数
|
29 |
self.model = None
|
30 |
self.tokenizer = None
|
31 |
self.setup_logging()
|
32 |
-
|
33 |
-
|
34 |
-
torch.set_num_threads(self.num_threads)
|
35 |
|
36 |
def setup_logging(self):
|
37 |
logging.basicConfig(
|
@@ -49,46 +45,49 @@ class EmbeddingService:
|
|
49 |
)
|
50 |
self.model = AutoModel.from_pretrained(
|
51 |
self.model_name,
|
52 |
-
trust_remote_code=True
|
53 |
-
torch_dtype=torch.float32 # CPU使用float32
|
54 |
).to(self.device)
|
55 |
-
|
56 |
self.model.eval()
|
57 |
torch.set_grad_enabled(False)
|
58 |
-
self.logger.info(f"
|
59 |
except Exception as e:
|
60 |
self.logger.error(f"模型初始化失败: {str(e)}")
|
61 |
raise
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
@lru_cache(maxsize=1000)
|
64 |
-
|
|
|
|
|
65 |
try:
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
truncation=True,
|
70 |
-
max_length=self.max_length,
|
71 |
-
padding=True
|
72 |
-
)
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
return outputs.numpy().tolist()[0]
|
77 |
-
except Exception as e:
|
78 |
-
self.logger.error(f"生成嵌入向量失败: {str(e)}")
|
79 |
-
raise
|
80 |
-
|
81 |
-
# FastAPI应用初始化
|
82 |
app = FastAPI(
|
83 |
title="Jina Embeddings API",
|
84 |
description="Text embedding generation service using jina-embeddings-v3",
|
85 |
version="1.0.0"
|
86 |
)
|
87 |
|
88 |
-
# 初始化服务
|
89 |
-
embedding_service = EmbeddingService()
|
90 |
-
|
91 |
-
# CORS配置
|
92 |
app.add_middleware(
|
93 |
CORSMiddleware,
|
94 |
allow_origins=["*"],
|
@@ -97,7 +96,6 @@ app.add_middleware(
|
|
97 |
allow_headers=["*"],
|
98 |
)
|
99 |
|
100 |
-
# API端点
|
101 |
@app.post("/generate_embeddings", response_model=EmbeddingResponse)
|
102 |
@app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
|
103 |
@app.post("/hf/v1/embeddings", response_model=EmbeddingResponse)
|
@@ -105,11 +103,13 @@ app.add_middleware(
|
|
105 |
@app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
|
106 |
async def generate_embeddings(request: EmbeddingRequest):
|
107 |
try:
|
108 |
-
embedding =
|
109 |
return EmbeddingResponse(
|
110 |
status="success",
|
111 |
embeddings=[embedding]
|
112 |
)
|
|
|
|
|
113 |
except Exception as e:
|
114 |
raise HTTPException(status_code=500, detail=str(e))
|
115 |
|
@@ -118,14 +118,13 @@ async def root():
|
|
118 |
return {
|
119 |
"status": "active",
|
120 |
"model": embedding_service.model_name,
|
121 |
-
"device": str(embedding_service.device)
|
122 |
-
"cpu_threads": embedding_service.num_threads
|
123 |
}
|
124 |
|
125 |
# Gradio界面
|
126 |
def gradio_interface(text: str) -> Dict:
|
127 |
try:
|
128 |
-
embedding =
|
129 |
return {
|
130 |
"status": "success",
|
131 |
"embeddings": [embedding]
|
@@ -150,17 +149,11 @@ async def startup_event():
|
|
150 |
await embedding_service.initialize()
|
151 |
|
152 |
if __name__ == "__main__":
|
153 |
-
# 初始化服务
|
154 |
asyncio.run(embedding_service.initialize())
|
155 |
-
|
156 |
-
# 挂载Gradio应用
|
157 |
gr.mount_gradio_app(app, iface, path="/ui")
|
158 |
-
|
159 |
-
# 启动服务
|
160 |
uvicorn.run(
|
161 |
app,
|
162 |
host="0.0.0.0",
|
163 |
port=7860,
|
164 |
-
workers=1
|
165 |
-
loop="asyncio"
|
166 |
)
|
|
|
8 |
from typing import List, Dict
|
9 |
from functools import lru_cache
|
10 |
import uvicorn
|
|
|
11 |
import numpy as np
|
12 |
|
13 |
class EmbeddingRequest(BaseModel):
|
|
|
22 |
def __init__(self):
|
23 |
self.model_name = "jinaai/jina-embeddings-v3"
|
24 |
self.max_length = 512
|
|
|
25 |
self.device = torch.device("cpu")
|
|
|
26 |
self.model = None
|
27 |
self.tokenizer = None
|
28 |
self.setup_logging()
|
29 |
+
# CPU优化
|
30 |
+
torch.set_num_threads(4)
|
|
|
31 |
|
32 |
def setup_logging(self):
|
33 |
logging.basicConfig(
|
|
|
45 |
)
|
46 |
self.model = AutoModel.from_pretrained(
|
47 |
self.model_name,
|
48 |
+
trust_remote_code=True
|
|
|
49 |
).to(self.device)
|
|
|
50 |
self.model.eval()
|
51 |
torch.set_grad_enabled(False)
|
52 |
+
self.logger.info(f"模型加载成功,使用设备: {self.device}")
|
53 |
except Exception as e:
|
54 |
self.logger.error(f"模型初始化失败: {str(e)}")
|
55 |
raise
|
56 |
|
57 |
+
async def _generate_embedding_internal(self, text: str) -> List[float]:
|
58 |
+
"""内部嵌入生成函数"""
|
59 |
+
if not text.strip():
|
60 |
+
raise ValueError("输入文本不能为空")
|
61 |
+
|
62 |
+
inputs = self.tokenizer(
|
63 |
+
text,
|
64 |
+
return_tensors="pt",
|
65 |
+
truncation=True,
|
66 |
+
max_length=self.max_length,
|
67 |
+
padding=True
|
68 |
+
)
|
69 |
+
|
70 |
+
with torch.no_grad():
|
71 |
+
outputs = self.model(**inputs).last_hidden_state.mean(dim=1)
|
72 |
+
return outputs.numpy().tolist()[0]
|
73 |
+
|
74 |
@lru_cache(maxsize=1000)
|
75 |
+
def get_cached_embedding(self, text: str) -> List[float]:
|
76 |
+
"""缓存包装函数"""
|
77 |
+
loop = asyncio.new_event_loop()
|
78 |
try:
|
79 |
+
return loop.run_until_complete(self._generate_embedding_internal(text))
|
80 |
+
finally:
|
81 |
+
loop.close()
|
|
|
|
|
|
|
|
|
82 |
|
83 |
+
# 初始化服务
|
84 |
+
embedding_service = EmbeddingService()
|
|
|
|
|
|
|
|
|
|
|
|
|
85 |
app = FastAPI(
|
86 |
title="Jina Embeddings API",
|
87 |
description="Text embedding generation service using jina-embeddings-v3",
|
88 |
version="1.0.0"
|
89 |
)
|
90 |
|
|
|
|
|
|
|
|
|
91 |
app.add_middleware(
|
92 |
CORSMiddleware,
|
93 |
allow_origins=["*"],
|
|
|
96 |
allow_headers=["*"],
|
97 |
)
|
98 |
|
|
|
99 |
@app.post("/generate_embeddings", response_model=EmbeddingResponse)
|
100 |
@app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
|
101 |
@app.post("/hf/v1/embeddings", response_model=EmbeddingResponse)
|
|
|
103 |
@app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
|
104 |
async def generate_embeddings(request: EmbeddingRequest):
|
105 |
try:
|
106 |
+
embedding = embedding_service.get_cached_embedding(request.input)
|
107 |
return EmbeddingResponse(
|
108 |
status="success",
|
109 |
embeddings=[embedding]
|
110 |
)
|
111 |
+
except ValueError as e:
|
112 |
+
raise HTTPException(status_code=400, detail=str(e))
|
113 |
except Exception as e:
|
114 |
raise HTTPException(status_code=500, detail=str(e))
|
115 |
|
|
|
118 |
return {
|
119 |
"status": "active",
|
120 |
"model": embedding_service.model_name,
|
121 |
+
"device": str(embedding_service.device)
|
|
|
122 |
}
|
123 |
|
124 |
# Gradio界面
|
125 |
def gradio_interface(text: str) -> Dict:
|
126 |
try:
|
127 |
+
embedding = embedding_service.get_cached_embedding(text)
|
128 |
return {
|
129 |
"status": "success",
|
130 |
"embeddings": [embedding]
|
|
|
149 |
await embedding_service.initialize()
|
150 |
|
151 |
if __name__ == "__main__":
|
|
|
152 |
asyncio.run(embedding_service.initialize())
|
|
|
|
|
153 |
gr.mount_gradio_app(app, iface, path="/ui")
|
|
|
|
|
154 |
uvicorn.run(
|
155 |
app,
|
156 |
host="0.0.0.0",
|
157 |
port=7860,
|
158 |
+
workers=1
|
|
|
159 |
)
|