sanbo commited on
Commit
5f6fa69
·
1 Parent(s): d294fc6

update sth. at 2025-03-03 19:31:17

Browse files
Files changed (1) hide show
  1. app.py +8 -7
app.py CHANGED
@@ -4,11 +4,8 @@ import torch
4
  import gradio as gr
5
  from fastapi import FastAPI, HTTPException
6
  from fastapi.middleware.cors import CORSMiddleware
7
- from pydantic import BaseModel, Field, root_validator
8
- from typing import List, Dict, Optional
9
- from functools import lru_cache
10
- from threading import Lock
11
- import uvicorn
12
 
13
  class EmbeddingRequest(BaseModel):
14
  # 强制锁定模型参数
@@ -23,7 +20,8 @@ class EmbeddingRequest(BaseModel):
23
  prompt: Optional[str] = Field(None, description="输入文本(兼容Ollama格式)")
24
 
25
  # 自动合并输入字段
26
- @root_validator(pre=True)
 
27
  def merge_input_fields(cls, values):
28
  input_fields = ["inputs", "input", "prompt"]
29
  for field in input_fields:
@@ -36,10 +34,13 @@ class EmbeddingRequest(BaseModel):
36
 
37
  class EmbeddingResponse(BaseModel):
38
  object: str = "list"
39
- data: List
40
  model: str
41
  usage: Dict[str, int]
42
 
 
 
 
43
  class EmbeddingService:
44
  def __init__(self):
45
  self._true_model_name = "jinaai/jina-embeddings-v3" # 硬编码模型名称
 
4
  import gradio as gr
5
  from fastapi import FastAPI, HTTPException
6
  from fastapi.middleware.cors import CORSMiddleware
7
+ from pydantic import BaseModel, Field, model_validator
8
+ from typing import List, Dict, Optional, Any
 
 
 
9
 
10
  class EmbeddingRequest(BaseModel):
11
  # 强制锁定模型参数
 
20
  prompt: Optional[str] = Field(None, description="输入文本(兼容Ollama格式)")
21
 
22
  # 自动合并输入字段
23
+ @model_validator(mode='before')
24
+ @classmethod
25
  def merge_input_fields(cls, values):
26
  input_fields = ["inputs", "input", "prompt"]
27
  for field in input_fields:
 
34
 
35
  class EmbeddingResponse(BaseModel):
36
  object: str = "list"
37
+ data: List[Dict[str, Any]]
38
  model: str
39
  usage: Dict[str, int]
40
 
41
+ class Config:
42
+ arbitrary_types_allowed = True
43
+
44
  class EmbeddingService:
45
  def __init__(self):
46
  self._true_model_name = "jinaai/jina-embeddings-v3" # 硬编码模型名称