import asyncio
import logging
from functools import lru_cache
from threading import Lock
from typing import Dict, List, Optional, Tuple

import gradio as gr
import torch
import uvicorn
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field, root_validator


class EmbeddingRequest(BaseModel):
    """Request body accepted by every embedding route.

    The text may arrive as ``inputs`` (HuggingFace style), ``input``
    (OpenAI style) or ``prompt`` (Ollama style). A pre-validator copies the
    first non-empty one into ``inputs``, so handlers only need to read that
    field.
    """

    # Accepted only for API compatibility; the service always uses its own
    # hard-coded model no matter what the client sends here.
    model: str = Field(
        default="jinaai/jina-embeddings-v3",
        description="此参数仅用于API兼容,实际模型固定为jinaai/jina-embeddings-v3",
        frozen=True,  # intended to forbid modification
    )
    # The three accepted spellings of the input text.
    inputs: Optional[str] = Field(None, description="输入文本(兼容HuggingFace格式)")
    input: Optional[str] = Field(None, description="输入文本(兼容OpenAI格式)")
    prompt: Optional[str] = Field(None, description="输入文本(兼容Ollama格式)")

    @root_validator(pre=True)
    def merge_input_fields(cls, values):
        """Copy the first truthy value of inputs/input/prompt into ``inputs``.

        Fields are checked in the order inputs, input, prompt.

        Raises:
            ValueError: if none of the three fields holds a truthy value.
        """
        # In "pre" mode the raw payload arrives unvalidated. Only a mapping can
        # be inspected; anything else is left for normal field validation to
        # reject.
        if not isinstance(values, dict):
            return values
        for field in ("inputs", "input", "prompt"):
            if values.get(field):
                # Build a new dict instead of mutating the caller's payload.
                return {**values, "inputs": values[field]}
        raise ValueError("必须提供 inputs/input/prompt 任一字段")


class EmbeddingResponse(BaseModel):
    """Response body: a status string plus a list of embedding vectors."""

    status: str
    embeddings: List[List[float]]


class EmbeddingService:
    """Loads jinaai/jina-embeddings-v3 on CPU and serves mean-pooled embeddings.

    Model inference is serialised through ``self.lock``. Results are memoised
    per instance in an LRU cache keyed on the input text.
    """

    def __init__(self, cache_size: int = 1000):
        """Configure the service; the model itself is loaded by initialize().

        Args:
            cache_size: maximum number of distinct texts whose embeddings are
                kept in this instance's LRU cache.
        """
        # Hard-coded: the `model` field of incoming requests is ignored.
        self._true_model_name = "jinaai/jina-embeddings-v3"
        self.max_length = 512
        self.device = torch.device("cpu")
        self.model = None
        self.tokenizer = None
        self.lock = Lock()
        self.setup_logging()
        torch.set_num_threads(4)  # CPU tuning; this is a process-wide torch setting
        # Per-instance cache. Decorating the method with @lru_cache would key
        # on `self`, share one cache across instances, and keep every instance
        # alive for the life of the process.
        self._cached_embedding = lru_cache(maxsize=cache_size)(self._compute_embedding)

    def setup_logging(self):
        """Configure root logging at INFO and bind this module's logger."""
        logging.basicConfig(
            level=logging.INFO,
            format='%(asctime)s - %(levelname)s - %(message)s'
        )
        self.logger = logging.getLogger(__name__)

    async def initialize(self):
        """Load the tokenizer and model. Calling again after success is a no-op.

        Raises:
            Exception: whatever transformers raises while loading; it is
                logged and re-raised.
        """
        # Idempotent: both the __main__ entry point and the FastAPI startup
        # hook call this, and the model should be loaded only once.
        if self.model is not None and self.tokenizer is not None:
            return
        try:
            from transformers import AutoModel, AutoTokenizer

            # NOTE(review): trust_remote_code=True executes Python code shipped
            # in the model repository; acceptable only for a trusted model id.
            tokenizer = AutoTokenizer.from_pretrained(
                self._true_model_name, trust_remote_code=True
            )
            model = AutoModel.from_pretrained(
                self._true_model_name, trust_remote_code=True
            ).to(self.device)
            model.eval()
            # Grad mode is thread-local in PyTorch, so this affects only the
            # calling thread; inference below also uses torch.no_grad().
            torch.set_grad_enabled(False)
        except Exception as e:
            self.logger.error(f"模型初始化失败: {str(e)}")
            raise
        # Publish both objects together, and only after both loaded, so a
        # half-finished load never looks ready to get_embedding().
        with self.lock:
            self.tokenizer = tokenizer
            self.model = model
        self.logger.info(f"强制加载模型: {self._true_model_name}")

    def get_embedding(self, text: str) -> List[float]:
        """Return the embedding vector for ``text``, using the LRU cache when possible.

        A fresh list is returned on every call, so callers cannot corrupt the
        cached value.

        Raises:
            RuntimeError: if initialize() has not completed.
            Exception: any tokenizer/model error (logged and re-raised).
        """
        return list(self._cached_embedding(text))

    def _compute_embedding(self, text: str) -> Tuple[float, ...]:
        """Run the model on ``text`` and mean-pool the last hidden state.

        Returns an immutable tuple so the LRU cache can safely share it.
        """
        with self.lock:
            if self.model is None or self.tokenizer is None:
                raise RuntimeError("模型尚未初始化")
            try:
                encoded = self.tokenizer(
                    text,
                    return_tensors="pt",
                    truncation=True,
                    max_length=self.max_length,
                    padding=True,
                )
                with torch.no_grad():
                    # Mean over dim=1 of last_hidden_state, presumably the token
                    # axis of a (batch, seq, hidden) tensor. With a single
                    # input string there is no padding to exclude.
                    pooled = self.model(**encoded).last_hidden_state.mean(dim=1)
                # Row 0: the only item in the batch.
                return tuple(pooled[0].tolist())
            except Exception as e:
                self.logger.error(f"生成嵌入向量失败: {str(e)}")
                raise


embedding_service = EmbeddingService()
app = FastAPI()
# NOTE(review): wildcard origins combined with allow_credentials=True is very
# permissive; confirm this is intended for the deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


# The same handler is exposed under several paths for client compatibility.
# NOTE(review): the */chat/completions routes also return embeddings, not
# chat completions; confirm this is intended.
@app.post("/embed", response_model=EmbeddingResponse)
@app.post("/api/embeddings", response_model=EmbeddingResponse)
@app.post("/api/embed", response_model=EmbeddingResponse)
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
@app.post("/generate_embeddings", response_model=EmbeddingResponse)
@app.post("/api/v1/embeddings", response_model=EmbeddingResponse)
@app.post("/hf/v1/embeddings", response_model=EmbeddingResponse)
@app.post("/api/v1/chat/completions", response_model=EmbeddingResponse)
@app.post("/hf/v1/chat/completions", response_model=EmbeddingResponse)
async def generate_embeddings(request: EmbeddingRequest):
    """Embed ``request.inputs`` off the event loop and wrap it in a response.

    Raises:
        HTTPException: status 500 carrying the error text on any failure.
    """
    try:
        # Inference is blocking; run it in the default thread-pool executor.
        embedding = await asyncio.get_running_loop().run_in_executor(
            None,
            embedding_service.get_embedding,
            request.inputs,  # already merged by the request validator
        )
        return EmbeddingResponse(
            status="success",
            embeddings=[embedding]
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.get("/")
async def root():
    """Health/info endpoint reporting the fixed model name and device."""
    return {
        "status": "active",
        "true_model": embedding_service._true_model_name,
        "device": str(embedding_service.device)
    }


def gradio_interface(text: str) -> Dict:
    """Gradio callback: return a status dict with the embedding, or the error message."""
    try:
        embedding = embedding_service.get_embedding(text)
        return {
            "status": "success",
            "embeddings": [embedding]
        }
    except Exception as e:
        return {
            "status": "error",
            "message": str(e)
        }


iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(lines=3, label="输入文本"),
    outputs=gr.JSON(label="嵌入向量结果"),
    title="Jina Embeddings V3",
    description="强制使用jinaai/jina-embeddings-v3模型(无视请求中的model参数)",
    examples=[[
        "Represent this sentence for searching relevant passages: "
        "The sky is blue because of Rayleigh scattering"
    ]]
)


@app.on_event("startup")
async def startup_event():
    """Load the model when the server starts (no-op if already loaded)."""
    await embedding_service.initialize()


if __name__ == "__main__":
    # Load eagerly so a broken model fails before the server binds. The
    # startup hook's later call is then a no-op because initialize() is
    # idempotent.
    asyncio.run(embedding_service.initialize())
    app = gr.mount_gradio_app(app, iface, path="/ui")
    uvicorn.run(app, host="0.0.0.0", port=7860, workers=1)