Update app.py
Browse files
app.py
CHANGED
@@ -4,105 +4,78 @@ from fastapi import FastAPI, Query
|
|
4 |
from fastapi.responses import JSONResponse
|
5 |
from datasets import load_dataset
|
6 |
from sentence_transformers import SentenceTransformer
|
|
|
7 |
|
8 |
app = FastAPI()
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
|
13 |
-
|
14 |
-
"open_issues",
|
15 |
-
"stargazers_count",
|
16 |
-
"
|
17 |
-
"
|
18 |
-
|
19 |
-
"size",
|
20 |
-
"fork",
|
21 |
-
"updated_at",
|
22 |
-
"has_build_zig",
|
23 |
-
"has_build_zig_zon",
|
24 |
-
"created_at",
|
25 |
-
)
|
26 |
|
27 |
-
|
|
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
(
|
34 |
-
|
35 |
-
|
36 |
-
]
|
37 |
-
if not include_readme:
|
38 |
-
dataset = [{k: v for k, v in item.items() if k != "readme_content"} for item in dataset]
|
39 |
-
return dataset, repo_texts
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
|
|
45 |
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
index
|
50 |
-
index.add(np.array(repo_embeddings))
|
51 |
-
indices[key] = (index, dataset)
|
52 |
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
|
58 |
-
def
|
59 |
-
"""
|
60 |
-
Only return results that are likely relevant (distance-based filtering).
|
61 |
-
Lower distance = more similar.
|
62 |
-
Threshold is a fraction of the *minimum* distance found.
|
63 |
-
"""
|
64 |
if len(distances) == 0:
|
65 |
return []
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
dataset[int(i)]
|
70 |
-
for d, i in zip(distances, idxs)
|
71 |
-
if d <= cutoff
|
72 |
-
]
|
73 |
-
return filtered[:max_results]
|
74 |
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
return JSONResponse(content=content, headers=headers)
|
81 |
|
82 |
-
|
83 |
-
def infinite_scroll_programs(q: int = Query(0, ge=0)):
|
84 |
-
start = q * 10
|
85 |
-
content = scroll_data["infiniteScrollPrograms"][start : start + 10]
|
86 |
-
headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
|
87 |
-
return JSONResponse(content=content, headers=headers)
|
88 |
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
|
|
|
|
|
|
|
|
99 |
|
100 |
-
|
101 |
-
|
102 |
-
key = "programs"
|
103 |
-
index, dataset = indices[key]
|
104 |
-
query_embedding = model.encode([q])
|
105 |
-
distances, idxs = index.search(np.array(query_embedding), len(dataset))
|
106 |
-
results = filter_results_by_distance(distances[0], idxs[0], dataset)
|
107 |
headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
|
108 |
return JSONResponse(content=results, headers=headers)
|
|
|
from fastapi.responses import JSONResponse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from typing import List, Dict

app = FastAPI()

# Constants
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
SEARCH_THRESHOLD_RATIO = 0.3
# Repository metadata fields folded into each entry's embedding text,
# in this exact order (prepare_text joins them positionally).
FIELDS_TO_INCLUDE = [
    "full_name",
    "description",
    "default_branch",
    "open_issues",
    "stargazers_count",
    "forks_count",
    "watchers_count",
    "license",
    "size",
    "fork",
    "updated_at",
    "has_build_zig",
    "has_build_zig_zon",
    "created_at",
]

# Load embedding model (shared by dataset encoding and query encoding).
model = SentenceTransformer(EMBEDDING_MODEL_NAME)
# Helper functions
def prepare_text(entry: Dict, include_readme: bool = True) -> str:
    """Build the plain-text document that gets embedded for one repository.

    Concatenates the metadata fields listed in FIELDS_TO_INCLUDE (in order),
    optionally the README body, and any topic tags, separated by spaces.

    Args:
        entry: one repository record; missing keys are treated as empty.
        include_readme: when True, also append the "readme_content" field.

    Returns:
        A single space-joined string suitable for model.encode().
    """
    def as_text(value) -> str:
        # Map None to "" so a present-but-null field neither embeds the
        # literal token "None" nor crashes the final join; falsy-but-real
        # values such as 0 or False still stringify normally.
        return "" if value is None else str(value)

    parts = [as_text(entry.get(field)) for field in FIELDS_TO_INCLUDE]
    if include_readme:
        parts.append(as_text(entry.get("readme_content")))
    # "or []" guards an explicit None topics value; stringify defensively.
    parts.extend(as_text(topic) for topic in entry.get("topics") or [])
    return " ".join(parts)
def load_and_encode_dataset(name: str, include_readme: bool = True):
    """Download dataset *name*, build one text per entry, and embed them all.

    Returns a (dataset, embeddings) pair, where embeddings is an
    (n_entries, dim) float array aligned row-for-row with the dataset.
    """
    data = load_dataset(name)["train"]
    corpus = [prepare_text(row, include_readme) for row in data]
    return data, np.array(model.encode(corpus))
def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatL2:
    """Wrap an (n, dim) embedding matrix in an exact L2-distance index."""
    dim = embeddings.shape[1]
    idx = faiss.IndexFlatL2(dim)
    idx.add(embeddings)
    return idx
def search_index(index: faiss.IndexFlatL2, query: str, embeddings: np.ndarray, dataset: List[Dict]) -> List[Dict]:
    """Embed *query*, rank every entry in *dataset*, and return the close hits.

    NOTE(review): *embeddings* is unused here — ranking relies solely on the
    prebuilt index; the parameter is kept for interface stability.
    """
    encoded = np.array(model.encode([query]))
    dists, hit_ids = index.search(encoded, len(dataset))
    return filter_by_distance(dists[0], hit_ids[0], dataset)
def filter_by_distance(distances: np.ndarray, indices: np.ndarray, dataset: List[Dict], ratio: float = SEARCH_THRESHOLD_RATIO) -> List[Dict]:
    """Keep only the hits whose distance falls within *ratio* of the spread.

    Lower distance = more similar.  The cutoff is
    ``min + (max - min) * ratio``: ratio=0 keeps only the best match(es),
    ratio=1 keeps everything.  Results preserve the order of *indices*.

    Args:
        distances: 1-D distances for each candidate (from index.search).
        indices: matching positions into *dataset*.
        dataset: the corpus being searched.
        ratio: fraction of the min..max distance span to accept.

    Returns:
        The accepted dataset entries, best-first as ordered by *indices*.
    """
    if len(distances) == 0:
        return []
    min_d, max_d = np.min(distances), np.max(distances)
    threshold = min_d + (max_d - min_d) * ratio
    # int(i): faiss yields numpy integer ids, which some sequence types
    # (e.g. datasets.Dataset) reject as indices — cast defensively, as the
    # previous implementation did.
    return [dataset[int(i)] for d, i in zip(distances, indices) if d <= threshold]
# Load datasets and create indices
data_configs = {
    "packages": "zigistry/packages",
    "programs": "zigistry/programs"
}

data_store = {}

# At startup, materialise one searchable bundle per category:
# the raw dataset, its faiss index, and the embedding matrix.
for key, dataset_name in data_configs.items():
    ds, vecs = load_and_encode_dataset(dataset_name, include_readme=True)
    data_store[key] = {
        "dataset": ds,
        "index": build_faiss_index(vecs),
        "embeddings": vecs,
    }
# FastAPI endpoints
@app.get("/search/{category}/")
def search(category: str, q: str = Query(...)):
    """Semantic search over one dataset category ("packages" or "programs").

    Returns the dataset entries whose embeddings are close to the query,
    as filtered by search_index / filter_by_distance.
    """
    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
    if category not in data_store:
        # Fix: the error response previously omitted the CORS header, so
        # browser clients could not read the 404 body cross-origin.
        return JSONResponse(status_code=404, content={"error": "Invalid category"}, headers=headers)

    store = data_store[category]
    results = search_index(store["index"], q, store["embeddings"], store["dataset"])
    return JSONResponse(content=results, headers=headers)