|
import asyncio
import json

import numpy as np
from fastapi import FastAPI, HTTPException, Request
from langchain_community.embeddings.ollama import OllamaEmbeddings
from requests import Response
|
|
|
# The FastAPI application object. Its title and description only feed the
# generated OpenAPI schema / interactive docs page.
app = FastAPI(
    description="Game of matching fetish of users",
    title="FetishTest",
)
|
|
|
|
|
# Embedding model (served by Ollama) used by the /fetish endpoint to embed
# the user's space-joined label string at request time.
embed_model = OllamaEmbeddings(model="bge-m3")

# Both data files are read once at import time. The paths are relative, so
# they resolve against the process's current working directory.
# JSON is UTF-8 by specification; pass the encoding explicitly so loading
# does not depend on the platform's locale default (e.g. cp1252 on Windows),
# which would garble or reject non-ASCII labels.

# Reference characters. The endpoint iterates this as a list of dicts and
# reads their "key", "embedding", "name" and "label" entries.
with open("standard_character.json", "r", encoding="utf-8") as f:
    standard_character = json.load(f)

# Answer -> label lookup, indexed as answer2label[question_index][answer].
with open("Q&A.json", "r", encoding="utf-8") as f:
    answer2label = json.load(f)
|
|
|
|
|
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity of two vectors.

    Args:
        vec1: First vector (any array-like of numbers).
        vec2: Second vector, with the same length as ``vec1``.

    Returns:
        The cosine of the angle between the vectors as a Python ``float``
        (nominally in ``[-1.0, 1.0]``). If either vector has zero norm,
        ``0.0`` is returned instead of the ``nan`` (plus RuntimeWarning)
        that the raw division would produce, because a ``nan`` score makes
        any subsequent ranking by similarity meaningless.

    Raises:
        ValueError: Propagated from numpy if the vector lengths differ.
    """
    a = np.asarray(vec1, dtype=float)
    b = np.asarray(vec2, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        return 0.0
    return float(a.dot(b) / denom)
|
|
|
def _answers_to_labels(answer):
    """Return the sorted, de-duplicated labels for a user's answers.

    ``answer[i]`` is translated via ``answer2label[i][answer[i]]``. Lookup
    failures propagate as ``IndexError`` / ``KeyError`` / ``TypeError``.
    """
    return sorted({answer2label[idx][ans] for idx, ans in enumerate(answer)})


def _best_character(user_embedding):
    """Return ``(character, similarity)`` for the most similar reference entry.

    Each entry of ``standard_character`` is scored by cosine similarity of
    its ``"embedding"`` against *user_embedding*. Ties go to the earliest
    entry. ``standard_character`` must be non-empty.
    """
    return max(
        (
            (character, cosine_similarity(user_embedding, character["embedding"]))
            for character in standard_character
        ),
        key=lambda pair: pair[1],
    )


@app.post("/fetish")
async def matching(request: Request):
    """Match a user's questionnaire answers to the closest reference character.

    Expects a JSON body ``{"answer": [...]}``, where ``answer[i]`` is looked up
    as ``answer2label[i][answer[i]]``. The distinct labels are sorted, joined
    with spaces and embedded with ``embed_model``. The result is compared
    against every entry of ``standard_character`` by cosine similarity.

    Returns:
        ``{"result": <"key" of the best-matching character>}``.

    Raises:
        HTTPException: 400 if the body is not JSON, lacks ``"answer"``, or an
            answer does not map to a label; 500 if no reference characters
            are loaded.
    """
    try:
        payload = await request.json()
    except ValueError as err:  # json.JSONDecodeError is a ValueError
        raise HTTPException(status_code=400, detail="Request body must be valid JSON") from err
    if not isinstance(payload, dict) or "answer" not in payload:
        raise HTTPException(status_code=400, detail='Request body must contain "answer"')

    answer = payload["answer"]
    print(f"user_input: {answer}")

    try:
        user_labels = _answers_to_labels(answer)
    except (IndexError, KeyError, TypeError) as err:
        raise HTTPException(status_code=400, detail=f"Invalid answer: {err!r}") from err
    print(f"user_labels: {user_labels}")

    if not standard_character:
        raise HTTPException(status_code=500, detail="No reference characters loaded")

    # embed_query is a synchronous call; run it in the default thread pool so
    # it does not stall this worker's event loop while the model responds.
    loop = asyncio.get_running_loop()
    user_embedding = await loop.run_in_executor(
        None, embed_model.embed_query, " ".join(user_labels)
    )

    best, sim = _best_character(user_embedding)
    matched = best["key"]

    print(f"matched: {matched}")
    # Read name/label from the matched entry itself: its "key" value is not
    # necessarily its position in the standard_character list.
    print(f"{best['name']}: {best['label']} -- score: {sim}")

    return {"result": matched}
|
|
|
if __name__ == "__main__":
    from pathlib import Path

    import uvicorn

    # uvicorn needs an import string (not the app object) when workers > 1.
    # Derive the module name from this file instead of hard-coding "api", so
    # the server still starts if the file is renamed; for api.py this is
    # exactly "api:app".
    uvicorn.run(
        f"{Path(__file__).stem}:app",
        host="0.0.0.0",
        port=8002,
        loop="asyncio",
        workers=8,
        limit_concurrency=10,
        timeout_keep_alive=60,
        access_log=True,
    )