File size: 2,083 Bytes
318db6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import json

import numpy as np
from fastapi import FastAPI, HTTPException, Request
from fastapi.concurrency import run_in_threadpool
from langchain_community.embeddings.ollama import OllamaEmbeddings
from requests import Response

app = FastAPI(
    title="FetishTest",
    description="Game of matching fetish of users",
)

# One-time setup, executed when this module is imported.
# Embedding client for the "bge-m3" model.
embed_model = OllamaEmbeddings(model="bge-m3")

# Reference characters. The endpoint below reads the "key", "embedding",
# "name" and "label" fields of each entry.
# The encoding is explicit so the result does not depend on the platform's
# default locale encoding (JSON text is UTF-8).
with open("standard_character.json", "r", encoding="utf-8") as f:
    standard_character = json.load(f)

# Answer -> label table, indexed first by question position and then by the
# chosen answer.
with open("Q&A.json", "r", encoding="utf-8") as f:
    answer2label = json.load(f)

# cosine similarity
def cosine_similarity(vec1, vec2):
    """Return the cosine similarity between two vectors.

    Args:
        vec1: Sequence of numbers (anything ``np.asarray`` accepts).
        vec2: Sequence of numbers with the same length as ``vec1``.

    Returns:
        The dot product of the two vectors divided by the product of their
        Euclidean norms. When either vector has zero norm the similarity is
        undefined; ``0.0`` is returned instead of dividing by zero and
        propagating NaN into later comparisons.
    """
    a = np.asarray(vec1, dtype=float)
    b = np.asarray(vec2, dtype=float)
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0:
        return 0.0
    return a.dot(b) / denom

@app.post("/fetish")
async def matching(request: Request):
    """Match a user's answers to the most similar standard character.

    The JSON body must be an object with an ``"answer"`` entry holding one
    choice per question, in question order (e.g. ``["A", "B", ...]``). Each
    choice is translated to a label through ``answer2label``; the distinct
    labels are sorted, joined with spaces and embedded, and the entry of
    ``standard_character`` whose embedding has the highest cosine similarity
    to the result is selected.

    Returns:
        ``{"result": key}`` where ``key`` is the ``"key"`` field of the
        best-matching character.

    Raises:
        HTTPException: 400 if the body is not valid JSON, has no usable
            ``"answer"`` entry, or contains an answer that does not map to a
            known question/choice; 500 if no standard characters are loaded.
    """
    try:
        payload = await request.json()
    except ValueError as err:  # includes json.JSONDecodeError
        raise HTTPException(
            status_code=400, detail="Request body must be valid JSON."
        ) from err
    if not isinstance(payload, dict) or "answer" not in payload:
        raise HTTPException(
            status_code=400,
            detail='Request body must be a JSON object with an "answer" entry.',
        )

    answer = payload["answer"]  # ['A', 'B' ...]
    print(f"user_input: {answer}")

    # Translate each answer into its label; the set drops duplicates and the
    # sort makes the embedded text independent of answer order.
    try:
        user_labels = sorted(
            {answer2label[idx][ans] for idx, ans in enumerate(answer)}
        )
    except (KeyError, IndexError, TypeError) as err:
        raise HTTPException(
            status_code=400,
            detail="Answers do not match the known questions/choices.",
        ) from err
    if not user_labels:
        # Nothing to embed, so no meaningful match can be computed.
        raise HTTPException(
            status_code=400, detail='"answer" must contain at least one choice.'
        )
    print(f"user_labels: {user_labels}")

    # embed_query is a synchronous call; run it in the threadpool so it does
    # not block the event loop while waiting for the embedding backend.
    user_embedding = await run_in_threadpool(
        embed_model.embed_query, " ".join(user_labels)
    )

    if not standard_character:
        raise HTTPException(
            status_code=500, detail="No standard characters are loaded."
        )

    # Score every character and keep the best one; max() returns the first
    # entry among ties.
    scored = [
        (cosine_similarity(user_embedding, character["embedding"]), character)
        for character in standard_character
    ]
    sim, best = max(scored, key=lambda pair: pair[0])

    # Name and label are read from the matched entry itself: the collection
    # is iterated as a sequence of entries, so indexing it by the entry's
    # "key" would only work if keys happened to equal list positions.
    matched = best["key"]
    result = {
        "result": matched
    }
    print(f"matched: {matched}")
    print(f"{best['name']}: {best['label']} -- score: {sim}")
    return result
    
if __name__ == "__main__":
    # Start the API with uvicorn when this file is executed as a script.
    import uvicorn

    # The app is passed as the import string "api:app" below, so this module
    # presumably has to be importable as `api` (i.e. saved as api.py) --
    # verify against the deployment layout.
    server_options = {
        "host": "0.0.0.0",
        "port": 8002,
        "loop": "asyncio",
        "workers": 8,
        "limit_concurrency": 10,
        "timeout_keep_alive": 60,
        "access_log": True,
    }
    uvicorn.run("api:app", **server_options)