sanbo committed on
Commit
c064436
·
1 Parent(s): 068937b

update sth. at 2025-01-16 22:40:53

Browse files
Files changed (1) hide show
  1. app.py1 +0 -59
app.py1 DELETED
@@ -1,59 +0,0 @@
1
- from fastapi import FastAPI, HTTPException, Request
2
- from pydantic import BaseModel
3
- from transformers import AutoTokenizer, AutoModel
4
- import torch
5
- from typing import List, Dict
6
- import uvicorn
7
-
8
- # 定义请求和响应模型
9
- class EmbeddingRequest(BaseModel):
10
- input: str
11
- model: str = "jinaai/jina-embeddings-v3"
12
-
13
- class EmbeddingResponse(BaseModel):
14
- status: str
15
- embeddings: List[List[float]]
16
-
17
- # 创建FastAPI应用
18
- app = FastAPI(
19
- title="Jina Embeddings API",
20
- description="Text embedding generation service using jina-embeddings-v3",
21
- version="1.0.0"
22
- )
23
-
24
- # 加载模型和分词器
25
- model_name = "jinaai/jina-embeddings-v3"
26
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
27
- model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
28
-
29
- @app.post("/generate_embeddings", response_model=EmbeddingResponse)
30
- @app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
31
- @app.post("/hf/v1/embeddings", response_model=EmbeddingResponse)
32
- @app.post("/api/v1/chat/completions", response_model=EmbeddingResponse)
33
- @app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
34
- async def generate_embeddings(request: EmbeddingRequest):
35
- try:
36
- # 使用分词器处理输入文本
37
- inputs = tokenizer(request.input, return_tensors="pt", truncation=True, max_length=512)
38
-
39
- # 生成嵌入
40
- with torch.no_grad():
41
- embeddings = model(**inputs).last_hidden_state.mean(dim=1)
42
-
43
- return EmbeddingResponse(
44
- status="success",
45
- embeddings=embeddings.numpy().tolist()
46
- )
47
- except Exception as e:
48
- raise HTTPException(status_code=500, detail=str(e))
49
-
50
- @app.get("/")
51
- async def root():
52
- return {
53
- "status": "active",
54
- "model": model_name,
55
- "usage": "Send POST request to /generate_embeddings or /api/v1/embeddings or /hf/v1/embeddings"
56
- }
57
-
58
- if __name__ == "__main__":
59
- uvicorn.run(app, host="0.0.0.0", port=7860)