# app.py by jonathanjordan21 — commit 197e027 ("Update app.py"), 1.28 kB.
from typing import List

import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from sentence_transformers import CrossEncoder
# FastAPI application instance; the routes in this module are registered on it
# via the @app.get / @app.post decorators.
app: FastAPI = FastAPI()
# Request body for POST /predict_list: a batch of query/document pairs.
# The two lists are paired positionally (keywords[i] with contents[i]).
class InputListModel(BaseModel):
    # Query strings, one per pair.
    keywords: List[str]
    # Documents to score, one per pair; expected to match `keywords` in length.
    contents: List[str]
# Request body for POST /predict: a single query/document pair to score.
class InputModel(BaseModel):
    # Query text scored against `content`.
    keyword: str
    # Document text scored against `keyword`.
    content: str
# Hugging Face model id of the reranker.
# Alternative reranker: "jinaai/jina-reranker-v2-base-multilingual".
_RERANKER_ID = "Alibaba-NLP/gte-multilingual-reranker-base"

# Built once when the module is imported and shared by the request handlers.
# NOTE(review): trust_remote_code=True runs Python code shipped with the model
# repository; only keep it for repositories you trust.
model = CrossEncoder(_RERANKER_ID, trust_remote_code=True)
@app.get("/")
def greet_json():
    """Return a fixed JSON greeting from the root path."""
    payload = {"Hello": "World!"}
    return payload
@app.post("/predict_list")
async def predict_list(inp: InputListModel):
    """Score each (keyword, content) pair with the cross-encoder.

    Pairs are formed positionally: ``inp.keywords[i]`` is scored against
    ``inp.contents[i]``.

    Returns:
        ``{"results": [...]}`` with one model output per pair, in request
        order. This is an empty list when both input lists are empty.

    Raises:
        HTTPException: 422 if ``keywords`` and ``contents`` differ in length.
    """
    if len(inp.keywords) != len(inp.contents):
        # zip() would silently drop the unmatched tail, returning fewer
        # scores than the client sent pairs for, with no indication why.
        raise HTTPException(
            status_code=422,
            detail="keywords and contents must have the same length",
        )

    sentence_pairs = [[query, doc] for query, doc in zip(inp.keywords, inp.contents)]
    if not sentence_pairs:
        # Some sentence-transformers versions index sentences[0] inside
        # predict() and raise on an empty batch; answer directly instead.
        return {"results": []}

    # model.predict is synchronous, compute-heavy work. Calling it directly in
    # an async endpoint would block the event loop (and every other request),
    # so run it in the threadpool instead.
    scores = await run_in_threadpool(
        model.predict, sentence_pairs, convert_to_tensor=False
    )
    return {"results": scores.tolist()}
@app.post("/predict")
async def predict(inp: InputModel):
    """Score a single (keyword, content) pair with the cross-encoder.

    Returns:
        ``{"results": value}``, where ``value`` is the first element of the
        model's output list for the one-pair batch.
    """
    sentence_pairs = [[inp.keyword, inp.content]]
    # Offload the synchronous model call so the event loop stays responsive
    # to other requests while inference runs.
    scores = await run_in_threadpool(
        model.predict, sentence_pairs, convert_to_tensor=False
    )
    return {"results": scores.tolist()[0]}
# keywords = model.encode(inp.keywords)
# contents = model.encode(inp.contents)
# return {"results":np.linalg.norm(contents-keywords).tolist()}