sanbo commited on
Commit
124ac36
·
1 Parent(s): 5331238

update sth. at 2025-01-16 22:25:39

Browse files
Files changed (1) hide show
  1. app.py +35 -42
app.py CHANGED
@@ -8,7 +8,6 @@ from pydantic import BaseModel
8
  from typing import List, Dict
9
  from functools import lru_cache
10
  import uvicorn
11
- import psutil
12
  import numpy as np
13
 
14
  class EmbeddingRequest(BaseModel):
@@ -23,15 +22,12 @@ class EmbeddingService:
23
  def __init__(self):
24
  self.model_name = "jinaai/jina-embeddings-v3"
25
  self.max_length = 512
26
- self.batch_size = 8
27
  self.device = torch.device("cpu")
28
- self.num_threads = min(psutil.cpu_count(), 4) # 限制CPU线程数
29
  self.model = None
30
  self.tokenizer = None
31
  self.setup_logging()
32
-
33
- # CPU优化配置
34
- torch.set_num_threads(self.num_threads)
35
 
36
  def setup_logging(self):
37
  logging.basicConfig(
@@ -49,46 +45,49 @@ class EmbeddingService:
49
  )
50
  self.model = AutoModel.from_pretrained(
51
  self.model_name,
52
- trust_remote_code=True,
53
- torch_dtype=torch.float32 # CPU使用float32
54
  ).to(self.device)
55
-
56
  self.model.eval()
57
  torch.set_grad_enabled(False)
58
- self.logger.info(f"模型加载成功,CPU线程数: {self.num_threads}")
59
  except Exception as e:
60
  self.logger.error(f"模型初始化失败: {str(e)}")
61
  raise
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  @lru_cache(maxsize=1000)
64
- async def generate_embedding(self, text: str) -> List[float]:
 
 
65
  try:
66
- inputs = self.tokenizer(
67
- text,
68
- return_tensors="pt",
69
- truncation=True,
70
- max_length=self.max_length,
71
- padding=True
72
- )
73
 
74
- with torch.no_grad():
75
- outputs = self.model(**inputs).last_hidden_state.mean(dim=1)
76
- return outputs.numpy().tolist()[0]
77
- except Exception as e:
78
- self.logger.error(f"生成嵌入向量失败: {str(e)}")
79
- raise
80
-
81
- # FastAPI应用初始化
82
  app = FastAPI(
83
  title="Jina Embeddings API",
84
  description="Text embedding generation service using jina-embeddings-v3",
85
  version="1.0.0"
86
  )
87
 
88
- # 初始化服务
89
- embedding_service = EmbeddingService()
90
-
91
- # CORS配置
92
  app.add_middleware(
93
  CORSMiddleware,
94
  allow_origins=["*"],
@@ -97,7 +96,6 @@ app.add_middleware(
97
  allow_headers=["*"],
98
  )
99
 
100
- # API端点
101
  @app.post("/generate_embeddings", response_model=EmbeddingResponse)
102
  @app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
103
  @app.post("/hf/v1/embeddings", response_model=EmbeddingResponse)
@@ -105,11 +103,13 @@ app.add_middleware(
105
  @app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
106
  async def generate_embeddings(request: EmbeddingRequest):
107
  try:
108
- embedding = await embedding_service.generate_embedding(request.input)
109
  return EmbeddingResponse(
110
  status="success",
111
  embeddings=[embedding]
112
  )
 
 
113
  except Exception as e:
114
  raise HTTPException(status_code=500, detail=str(e))
115
 
@@ -118,14 +118,13 @@ async def root():
118
  return {
119
  "status": "active",
120
  "model": embedding_service.model_name,
121
- "device": str(embedding_service.device),
122
- "cpu_threads": embedding_service.num_threads
123
  }
124
 
125
  # Gradio界面
126
  def gradio_interface(text: str) -> Dict:
127
  try:
128
- embedding = asyncio.run(embedding_service.generate_embedding(text))
129
  return {
130
  "status": "success",
131
  "embeddings": [embedding]
@@ -150,17 +149,11 @@ async def startup_event():
150
  await embedding_service.initialize()
151
 
152
  if __name__ == "__main__":
153
- # 初始化服务
154
  asyncio.run(embedding_service.initialize())
155
-
156
- # 挂载Gradio应用
157
  gr.mount_gradio_app(app, iface, path="/ui")
158
-
159
- # 启动服务
160
  uvicorn.run(
161
  app,
162
  host="0.0.0.0",
163
  port=7860,
164
- workers=1,
165
- loop="asyncio"
166
  )
 
8
  from typing import List, Dict
9
  from functools import lru_cache
10
  import uvicorn
 
11
  import numpy as np
12
 
13
  class EmbeddingRequest(BaseModel):
 
22
  def __init__(self):
23
  self.model_name = "jinaai/jina-embeddings-v3"
24
  self.max_length = 512
 
25
  self.device = torch.device("cpu")
 
26
  self.model = None
27
  self.tokenizer = None
28
  self.setup_logging()
29
+ # CPU优化
30
+ torch.set_num_threads(4)
 
31
 
32
  def setup_logging(self):
33
  logging.basicConfig(
 
45
  )
46
  self.model = AutoModel.from_pretrained(
47
  self.model_name,
48
+ trust_remote_code=True
 
49
  ).to(self.device)
 
50
  self.model.eval()
51
  torch.set_grad_enabled(False)
52
+ self.logger.info(f"模型加载成功,使用设备: {self.device}")
53
  except Exception as e:
54
  self.logger.error(f"模型初始化失败: {str(e)}")
55
  raise
56
 
57
+ async def _generate_embedding_internal(self, text: str) -> List[float]:
58
+ """内部嵌入生成函数"""
59
+ if not text.strip():
60
+ raise ValueError("输入文本不能为空")
61
+
62
+ inputs = self.tokenizer(
63
+ text,
64
+ return_tensors="pt",
65
+ truncation=True,
66
+ max_length=self.max_length,
67
+ padding=True
68
+ )
69
+
70
+ with torch.no_grad():
71
+ outputs = self.model(**inputs).last_hidden_state.mean(dim=1)
72
+ return outputs.numpy().tolist()[0]
73
+
74
  @lru_cache(maxsize=1000)
75
+ def get_cached_embedding(self, text: str) -> List[float]:
76
+ """缓存包装函数"""
77
+ loop = asyncio.new_event_loop()
78
  try:
79
+ return loop.run_until_complete(self._generate_embedding_internal(text))
80
+ finally:
81
+ loop.close()
 
 
 
 
82
 
83
# Module-level singleton service shared by the API routes and the Gradio UI.
embedding_service = EmbeddingService()
 
 
 
 
 
 
85
# FastAPI application exposing the embedding endpoints.
app = FastAPI(
    title="Jina Embeddings API",
    description="Text embedding generation service using jina-embeddings-v3",
    version="1.0.0",
)
90
 
 
 
 
 
91
  app.add_middleware(
92
  CORSMiddleware,
93
  allow_origins=["*"],
 
96
  allow_headers=["*"],
97
  )
98
 
 
99
  @app.post("/generate_embeddings", response_model=EmbeddingResponse)
100
  @app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
101
  @app.post("/hf/v1/embeddings", response_model=EmbeddingResponse)
 
103
  @app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
104
async def generate_embeddings(request: EmbeddingRequest):
    """Return the embedding of ``request.input`` wrapped in an EmbeddingResponse.

    Validation failures map to HTTP 400, anything else to HTTP 500.
    NOTE(review): ``get_cached_embedding`` is synchronous, so the model
    forward pass blocks the event loop while it runs — confirm acceptable.
    """
    try:
        vector = embedding_service.get_cached_embedding(request.input)
        return EmbeddingResponse(
            status="success",
            embeddings=[vector],
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
115
 
 
118
  return {
119
  "status": "active",
120
  "model": embedding_service.model_name,
121
+ "device": str(embedding_service.device)
 
122
  }
123
 
124
  # Gradio界面
125
  def gradio_interface(text: str) -> Dict:
126
  try:
127
+ embedding = embedding_service.get_cached_embedding(text)
128
  return {
129
  "status": "success",
130
  "embeddings": [embedding]
 
149
  await embedding_service.initialize()
150
 
151
if __name__ == "__main__":
    # Load the model before serving.
    # NOTE(review): the FastAPI startup hook also awaits initialize(),
    # so the model appears to be initialized twice — confirm intended.
    asyncio.run(embedding_service.initialize())
    # Mount the Gradio UI onto the same FastAPI app under /ui.
    gr.mount_gradio_app(app, iface, path="/ui")
    uvicorn.run(
        app,
        host="0.0.0.0",
        port=7860,
        workers=1,
    )