RohanVashisht commited on
Commit
820aa6d
·
verified ·
1 Parent(s): 6153fbc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -60
app.py CHANGED
@@ -8,74 +8,79 @@ from typing import List, Dict
8
 
9
  app = FastAPI()
10
 
11
- # Constants
12
- EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
13
- SEARCH_THRESHOLD_RATIO = 0.3
14
- FIELDS_TO_INCLUDE = [
15
- "full_name", "description", "default_branch", "open_issues",
16
- "stargazers_count", "forks_count", "watchers_count", "license",
17
- "size", "fork", "updated_at", "has_build_zig",
18
- "has_build_zig_zon", "created_at"
19
- ]
 
 
 
 
 
 
 
20
 
21
- # Load embedding model
22
- model = SentenceTransformer(EMBEDDING_MODEL_NAME)
23
 
24
- # Helper functions
25
- def prepare_text(entry: Dict, include_readme: bool = True) -> str:
26
- parts = [str(entry.get(field, "")) for field in FIELDS_TO_INCLUDE]
27
- if include_readme:
28
- parts.append(entry.get("readme_content", ""))
29
- parts.extend(entry.get("topics", []))
30
- return " ".join(parts)
 
 
 
 
31
 
32
- def load_and_encode_dataset(name: str, include_readme: bool = True):
33
- raw_dataset = load_dataset(name)["train"]
34
- texts = [prepare_text(item, include_readme) for item in raw_dataset]
35
- embeddings = model.encode(texts)
36
- return raw_dataset, np.array(embeddings)
37
-
38
- def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatL2:
39
- index = faiss.IndexFlatL2(embeddings.shape[1])
40
- index.add(embeddings)
41
- return index
42
 
43
- def search_index(index: faiss.IndexFlatL2, query: str, embeddings: np.ndarray, dataset: List[Dict]) -> List[Dict]:
44
- query_vector = model.encode([query])
45
- distances, indices = index.search(np.array(query_vector), len(dataset))
46
- return filter_by_distance(distances[0], indices[0], dataset)
 
 
47
 
48
- def filter_by_distance(distances: np.ndarray, indices: np.ndarray, dataset: List[Dict], ratio: float = SEARCH_THRESHOLD_RATIO) -> List[Dict]:
49
  if len(distances) == 0:
50
  return []
51
- min_d, max_d = np.min(distances), np.max(distances)
52
- threshold = min_d + (max_d - min_d) * ratio
53
- return [dataset[i] for d, i in zip(distances, indices) if d <= threshold]
54
-
55
- # Load datasets and create indices
56
- data_configs = {
57
- "packages": "zigistry/packages",
58
- "programs": "zigistry/programs"
59
- }
60
-
61
- data_store = {}
62
 
63
- for key, dataset_name in data_configs.items():
64
- dataset, embeddings = load_and_encode_dataset(dataset_name, include_readme=True)
65
- index = build_faiss_index(embeddings)
66
- data_store[key] = {
67
- "dataset": dataset,
68
- "index": index,
69
- "embeddings": embeddings
70
- }
71
 
72
- # FastAPI endpoints
73
- @app.get("/search/{category}/")
74
- def search(category: str, q: str = Query(...)):
75
- if category not in data_store:
76
- return JSONResponse(status_code=404, content={"error": "Invalid category"})
 
 
 
 
77
 
78
- store = data_store[category]
79
- results = search_index(store["index"], q, store["embeddings"], store["dataset"])
 
 
 
 
 
80
  headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
81
- return JSONResponse(content=results, headers=headers)
 
8
 
9
  app = FastAPI()
10
 
11
# Repository metadata fields whose values are concatenated into the text
# that gets embedded for semantic search over each dataset entry.
FIELDS = (
    "full_name", "description", "default_branch", "open_issues",
    "stargazers_count", "forks_count", "watchers_count", "license",
    "size", "fork", "updated_at", "has_build_zig",
    "has_build_zig_zon", "created_at",
)
27
 
28
+ model = SentenceTransformer("all-MiniLM-L6-v2")
 
29
 
30
def _entry_text(entry, include_readme):
    """Build the whitespace-joined text used to embed one dataset entry."""
    parts = [str(entry.get(field, "")) for field in FIELDS]
    if include_readme:
        parts.append(str(entry.get("readme_content", "") or ""))
    # `topics` may be absent OR present with a null value; `.get("topics", [])`
    # would return None in the latter case and crash the join, so use `or []`.
    # Items are coerced to str defensively in case of non-string topics.
    parts.extend(str(topic) for topic in (entry.get("topics") or []))
    return " ".join(parts)


def load_dataset_with_fields(name, include_readme=False):
    """Load the ``train`` split of dataset *name* and prepare embedding texts.

    Returns a tuple ``(dataset, repo_texts)`` where ``repo_texts[i]`` is the
    concatenated searchable text for ``dataset[i]``.  When *include_readme*
    is False, the bulky ``readme_content`` field is stripped from the
    returned entries so it is never sent back to API clients.
    """
    dataset = load_dataset(name)["train"]
    repo_texts = [_entry_text(x, include_readme) for x in dataset]
    if not include_readme:
        dataset = [
            {k: v for k, v in item.items() if k != "readme_content"}
            for item in dataset
        ]
    return dataset, repo_texts
41
 
42
# Corpus name -> (entries, embedding texts). READMEs are folded into the
# embedded text for both corpora to make search semantically richer.
datasets = {
    key: load_dataset_with_fields(f"zigistry/{key}", include_readme=True)
    for key in ("packages", "programs")
}
 
 
 
 
 
 
46
 
47
# category -> (flat L2 faiss index over the corpus embeddings, corpus entries)
indices = {}
for category, (entries, texts) in datasets.items():
    vectors = np.asarray(model.encode(texts))
    flat_index = faiss.IndexFlatL2(vectors.shape[1])
    flat_index.add(vectors)
    indices[category] = (flat_index, entries)
53
 
54
def filter_results_by_distance(distances, idxs, dataset, threshold_ratio=0.3):
    """Keep only the dataset entries whose search distance is "close enough".

    The cutoff is relative: entries within *threshold_ratio* of the
    min->max distance span are kept.  ``idxs[k]`` is the dataset position
    of the entry whose distance is ``distances[k]`` (faiss search output).
    Returns a list of matching entries; empty input yields an empty list.
    """
    if len(distances) == 0:
        return []
    lo = np.min(distances)
    hi = np.max(distances)
    cutoff = lo + (hi - lo) * threshold_ratio
    kept = []
    for dist, pos in zip(distances, idxs):
        if dist <= cutoff:
            kept.append(dataset[int(pos)])
    return kept
 
 
67
 
68
@app.get("/searchPackages/")
def search_packages(q: str):
    """Semantic search over the packages corpus; returns matches as JSON."""
    flat_index, entries = indices["packages"]
    query_vector = np.array(model.encode([q]))
    distances, positions = flat_index.search(query_vector, len(entries))
    matches = filter_results_by_distance(distances[0], positions[0], entries)
    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
    return JSONResponse(content=matches, headers=headers)
77
 
78
@app.get("/searchPrograms/")
def search_programs(q: str):
    """Semantic search over the programs corpus; returns matches as JSON."""
    flat_index, entries = indices["programs"]
    query_vector = np.array(model.encode([q]))
    distances, positions = flat_index.search(query_vector, len(entries))
    matches = filter_results_by_distance(distances[0], positions[0], entries)
    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
    return JSONResponse(content=matches, headers=headers)