RohanVashisht committed
Commit 6153fbc · verified · 1 Parent(s): 2f9d235

Update app.py

Files changed (1): app.py (+58 -85)
app.py CHANGED
@@ -4,105 +4,78 @@ from fastapi import FastAPI, Query
 from fastapi.responses import JSONResponse
 from datasets import load_dataset
 from sentence_transformers import SentenceTransformer
+from typing import List, Dict
 
 app = FastAPI()
 
-FIELDS = (
-    "full_name",
-    "description",
-    "default_branch",
-    "open_issues",
-    "stargazers_count",
-    "forks_count",
-    "watchers_count",
-    "license",
-    "size",
-    "fork",
-    "updated_at",
-    "has_build_zig",
-    "has_build_zig_zon",
-    "created_at",
-)
-
-model = SentenceTransformer("all-MiniLM-L6-v2")
-
-def load_dataset_with_fields(name, include_readme=False):
-    dataset = load_dataset(name)["train"]
-    repo_texts = [
-        " ".join(str(x.get(field, "")) for field in FIELDS) +
-        (" " + x.get("readme_content", "")) * include_readme +
-        " " + " ".join(x.get("topics", []))
-        for x in dataset
-    ]
-    if not include_readme:
-        dataset = [{k: v for k, v in item.items() if k != "readme_content"} for item in dataset]
-    return dataset, repo_texts
-
-datasets = {
-    "packages": load_dataset_with_fields("zigistry/packages", include_readme=True),
-    "programs": load_dataset_with_fields("zigistry/programs", include_readme=True),
-}
-
-indices = {}
-for key, (dataset, repo_texts) in datasets.items():
-    repo_embeddings = model.encode(repo_texts)
-    index = faiss.IndexFlatL2(repo_embeddings.shape[1])
-    index.add(np.array(repo_embeddings))
-    indices[key] = (index, dataset)
-
-scroll_data = {
-    "infiniteScrollPackages": load_dataset_with_fields("zigistry/packages", include_readme=False)[0],
-    "infiniteScrollPrograms": load_dataset_with_fields("zigistry/programs", include_readme=False)[0],
-}
-
-def filter_results_by_distance(distances, idxs, dataset, max_results=50, threshold=0.6):
-    """
-    Only return results that are likely relevant (distance-based filtering).
-    Lower distance = more similar.
-    Threshold is a fraction of the *minimum* distance found.
-    """
+# Constants
+EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
+SEARCH_THRESHOLD_RATIO = 0.3
+FIELDS_TO_INCLUDE = [
+    "full_name", "description", "default_branch", "open_issues",
+    "stargazers_count", "forks_count", "watchers_count", "license",
+    "size", "fork", "updated_at", "has_build_zig",
+    "has_build_zig_zon", "created_at"
+]
+
+# Load embedding model
+model = SentenceTransformer(EMBEDDING_MODEL_NAME)
+
+# Helper functions
+def prepare_text(entry: Dict, include_readme: bool = True) -> str:
+    parts = [str(entry.get(field, "")) for field in FIELDS_TO_INCLUDE]
+    if include_readme:
+        parts.append(entry.get("readme_content", ""))
+    parts.extend(entry.get("topics", []))
+    return " ".join(parts)
+
+def load_and_encode_dataset(name: str, include_readme: bool = True):
+    raw_dataset = load_dataset(name)["train"]
+    texts = [prepare_text(item, include_readme) for item in raw_dataset]
+    embeddings = model.encode(texts)
+    return raw_dataset, np.array(embeddings)
+
+def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatL2:
+    index = faiss.IndexFlatL2(embeddings.shape[1])
+    index.add(embeddings)
+    return index
+
+def search_index(index: faiss.IndexFlatL2, query: str, embeddings: np.ndarray, dataset: List[Dict]) -> List[Dict]:
+    query_vector = model.encode([query])
+    distances, indices = index.search(np.array(query_vector), len(dataset))
+    return filter_by_distance(distances[0], indices[0], dataset)
+
+def filter_by_distance(distances: np.ndarray, indices: np.ndarray, dataset: List[Dict], ratio: float = SEARCH_THRESHOLD_RATIO) -> List[Dict]:
     if len(distances) == 0:
         return []
-    min_dist = np.min(distances)
-    cutoff = min_dist + ((max(distances) - min_dist) * threshold)
-    filtered = [
-        dataset[int(i)]
-        for d, i in zip(distances, idxs)
-        if d <= cutoff
-    ]
-    return filtered[:max_results]
-
-@app.get("/infiniteScrollPackages/")
-def infinite_scroll_packages(q: int = Query(0, ge=0)):
-    start = q * 10
-    content = scroll_data["infiniteScrollPackages"][start : start + 10]
-    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
-    return JSONResponse(content=content, headers=headers)
-
-@app.get("/infiniteScrollPrograms/")
-def infinite_scroll_programs(q: int = Query(0, ge=0)):
-    start = q * 10
-    content = scroll_data["infiniteScrollPrograms"][start : start + 10]
-    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
-    return JSONResponse(content=content, headers=headers)
-
-@app.get("/searchPackages/")
-def search_packages(q: str):
-    key = "packages"
-    index, dataset = indices[key]
-    query_embedding = model.encode([q])
-    distances, idxs = index.search(np.array(query_embedding), len(dataset))
-    # Only keep results that are likely relevant
-    results = filter_results_by_distance(distances[0], idxs[0], dataset)
-    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
-    return JSONResponse(content=results, headers=headers)
-
-@app.get("/searchPrograms/")
-def search_programs(q: str):
-    key = "programs"
-    index, dataset = indices[key]
-    query_embedding = model.encode([q])
-    distances, idxs = index.search(np.array(query_embedding), len(dataset))
-    results = filter_results_by_distance(distances[0], idxs[0], dataset)
+    min_d, max_d = np.min(distances), np.max(distances)
+    threshold = min_d + (max_d - min_d) * ratio
+    return [dataset[i] for d, i in zip(distances, indices) if d <= threshold]
+
+# Load datasets and create indices
+data_configs = {
+    "packages": "zigistry/packages",
+    "programs": "zigistry/programs"
+}
+
+data_store = {}
+
+for key, dataset_name in data_configs.items():
+    dataset, embeddings = load_and_encode_dataset(dataset_name, include_readme=True)
+    index = build_faiss_index(embeddings)
+    data_store[key] = {
+        "dataset": dataset,
+        "index": index,
+        "embeddings": embeddings
+    }
+
+# FastAPI endpoints
+@app.get("/search/{category}/")
+def search(category: str, q: str = Query(...)):
+    if category not in data_store:
+        return JSONResponse(status_code=404, content={"error": "Invalid category"})
+
+    store = data_store[category]
+    results = search_index(store["index"], q, store["embeddings"], store["dataset"])
     headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
     return JSONResponse(content=results, headers=headers)
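
A few short sketches follow; they are illustrative only, not part of the commit. First, the FAISS pattern used by search_index: with k equal to the number of stored rows, IndexFlatL2.search ranks every row by ascending L2 distance. The random vectors below stand in for model.encode(...) output; 384 matches the output dimension of all-MiniLM-L6-v2.

import numpy as np
import faiss  # assumes faiss-cpu is installed

# Stand-in for model.encode(texts): 4 vectors of dimension 384,
# the embedding size produced by all-MiniLM-L6-v2.
rng = np.random.default_rng(0)
embeddings = rng.standard_normal((4, 384)).astype("float32")

index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)

# As in search_index, k = number of stored rows ranks everything.
distances, ids = index.search(embeddings[:1], 4)
print(ids[0])        # row 0 comes first: its distance to itself is 0.0
print(distances[0])  # ascending L2 distances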
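
Second, the relative cutoff in filter_by_distance: a result is kept only if its distance lies within the bottom SEARCH_THRESHOLD_RATIO fraction of the observed min-max range. A worked example with made-up distances:

import numpy as np

distances = np.array([0.40, 0.55, 0.90, 1.20])  # hypothetical FAISS output
ratio = 0.3  # mirrors SEARCH_THRESHOLD_RATIO

min_d, max_d = np.min(distances), np.max(distances)
threshold = min_d + (max_d - min_d) * ratio  # 0.40 + 0.80 * 0.3 = 0.64
print(threshold)               # 0.64
print(distances <= threshold)  # [ True  True False False]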
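
Finally, exercising the consolidated /search/{category}/ route. This assumes the app is running locally (for example via uvicorn app:app); the host, port, and query string below are assumptions, not values from the commit.

import requests  # assumes the requests package is installed

BASE = "http://127.0.0.1:8000"  # assumed local uvicorn address

# Valid categories are the keys of data_configs: "packages" and "programs".
resp = requests.get(f"{BASE}/search/packages/", params={"q": "http server"})
print(resp.status_code)  # 200 on success, 404 for an unknown category
print(resp.json()[:3])   # first few matching repository records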