RohanVashisht committed on
Commit
0eeaf9e
·
verified ·
1 Parent(s): 820aa6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +45 -38
app.py CHANGED
@@ -1,10 +1,10 @@
1
  import faiss
2
  import numpy as np
3
- from fastapi import FastAPI, Query
4
  from fastapi.responses import JSONResponse
5
  from datasets import load_dataset
6
  from sentence_transformers import SentenceTransformer
7
- from typing import List, Dict
8
 
9
  app = FastAPI()
10
 
@@ -25,62 +25,69 @@ FIELDS = (
25
  "created_at",
26
  )
27
 
 
28
  model = SentenceTransformer("all-MiniLM-L6-v2")
 
29
 
30
def load_dataset_with_fields(name, include_readme=False):
    """Load the "train" split of *name* and build one search text per record.

    Each text is the space-joined values of the keys in ``FIELDS``, followed
    by the record's ``readme_content`` (only when *include_readme* is true)
    and its space-joined ``topics``.

    Args:
        name: Dataset identifier passed to ``load_dataset``.
        include_readme: When true, README text is appended to the search text.
            When false, ``readme_content`` is also dropped from the returned
            records.

    Returns:
        A ``(dataset, repo_texts)`` pair, with ``repo_texts`` in record order.
    """
    dataset = load_dataset(name)["train"]

    repo_texts = []
    for x in dataset:
        parts = [" ".join(str(x.get(field, "")) for field in FIELDS)]
        if include_readme:
            # A key that is present but None would break string building,
            # so treat it the same as a missing README.
            parts.append(x.get("readme_content") or "")
        # Same guard for topics: None is treated as "no topics".
        parts.append(" ".join(x.get("topics") or []))
        repo_texts.append(" ".join(parts))

    if not include_readme:
        dataset = [{k: v for k, v in item.items() if k != "readme_content"} for item in dataset]
    return dataset, repo_texts
41
-
42
# Both datasets are loaded at import time, with README text included in the
# text that gets embedded.
datasets = {
    "packages": load_dataset_with_fields("zigistry/packages", include_readme=True),
    "programs": load_dataset_with_fields("zigistry/programs", include_readme=True),
}

# Maps dataset key -> (FAISS index, records); the search endpoints read this.
indices = {}
for key, (dataset, repo_texts) in datasets.items():
    repo_embeddings = np.array(model.encode(repo_texts))
    index = faiss.IndexFlatL2(repo_embeddings.shape[1])
    index.add(repo_embeddings)
    indices[key] = (index, dataset)
53
 
54
def filter_results_by_distance(distances, idxs, dataset, threshold_ratio=0.3):
    """Keep the records whose distance falls in the closest part of the range.

    The cutoff is ``min + (max - min) * threshold_ratio`` over *distances*.
    A record is kept when its distance is less than or equal to that cutoff.

    Args:
        distances: Distances, aligned element-by-element with *idxs*.
        idxs: Positions into *dataset*, one per distance.
        dataset: Indexable collection of records.
        threshold_ratio: Fraction of the min-to-max distance span to accept.

    Returns:
        The selected records, in the order their distances were given.
        An empty list when *distances* is empty.
    """
    if len(distances) == 0:
        return []

    nearest = np.min(distances)
    farthest = np.max(distances)
    cutoff = nearest + ((farthest - nearest) * threshold_ratio)

    kept = []
    for distance, position in zip(distances, idxs):
        if distance <= cutoff:
            kept.append(dataset[int(position)])
    return kept
67
 
68
@app.get("/searchPackages/")
def search_packages(q: str):
    """Search the "packages" index for *q* and return the matches as JSON.

    Every record is ranked against the query; ``filter_results_by_distance``
    then decides which of them are returned.
    """
    index, dataset = indices["packages"]
    query_vector = np.array(model.encode([q]))
    distances, idxs = index.search(query_vector, len(dataset))
    results = filter_results_by_distance(distances[0], idxs[0], dataset)
    return JSONResponse(
        content=results,
        headers={"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"},
    )
77
 
78
@app.get("/searchPrograms/")
def search_programs(q: str):
    """Search the "programs" index for *q* and return the matches as JSON.

    Every record is ranked against the query; ``filter_results_by_distance``
    then decides which of them are returned.
    """
    index, dataset = indices["programs"]
    query_vector = np.array(model.encode([q]))
    distances, idxs = index.search(query_vector, len(dataset))
    results = filter_results_by_distance(distances[0], idxs[0], dataset)
    return JSONResponse(
        content=results,
        headers={"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"},
    )
 
1
  import faiss
2
  import numpy as np
3
+ from fastapi import FastAPI
4
  from fastapi.responses import JSONResponse
5
  from datasets import load_dataset
6
  from sentence_transformers import SentenceTransformer
7
+ from typing import List
8
 
9
  app = FastAPI()
10
 
 
25
  "created_at",
26
  )
27
 
28
# The embedding model is created once at import time; indexing and every
# search request reuse this module-level instance.
print("Loading sentence transformer model (all-MiniLM-L6-v2)...")
model = SentenceTransformer("all-MiniLM-L6-v2")
print("Model loaded successfully.")
31
 
32
def load_and_index_dataset(name: str, include_readme: bool = False):
    """Load the "train" split of *name*, embed every record and index it.

    Each record becomes one text: the space-joined values of the keys in
    ``FIELDS``, then ``readme_content`` (only when *include_readme* is true),
    then the space-joined ``topics``. The texts are encoded with the
    module-level ``model`` and added to a ``faiss.IndexFlatL2``.

    Args:
        name: Dataset identifier passed to ``load_dataset``.
        include_readme: When true, README text is part of the embedded text.
            When false, ``readme_content`` is also dropped from the returned
            records.

    Returns:
        An ``(index, records)`` pair, where ``records`` is a list in the same
        order as the vectors in ``index``.
    """
    print(f"Loading dataset '{name}'...")
    dataset = load_dataset(name)["train"]

    repo_texts = []
    for x in dataset:
        parts = [" ".join(str(x.get(field, "")) for field in FIELDS)]
        if include_readme:
            # A key that is present but None would break string building,
            # so treat it the same as a missing README.
            parts.append(x.get("readme_content") or "")
        # Same guard for topics: None is treated as "no topics".
        parts.append(" ".join(x.get("topics") or []))
        repo_texts.append(" ".join(parts))

    if not include_readme:
        dataset = [{k: v for k, v in item.items() if k != "readme_content"} for item in dataset]

    print(f"Creating embeddings for {len(repo_texts)} documents in '{name}'...")
    repo_embeddings = model.encode(repo_texts, show_progress_bar=True)

    print(f"Building FAISS index for '{name}'...")
    embedding_dim = repo_embeddings.shape[1]

    index = faiss.IndexFlatL2(embedding_dim)
    # The array is passed to FAISS explicitly as float32.
    index.add(np.array(repo_embeddings, dtype=np.float32))

    print(f"'{name}' dataset indexed with {index.ntotal} vectors.")
    return index, list(dataset)
57
 
58
# Maps dataset key -> (FAISS index, records); filled eagerly at import time.
indices = {}

# Both datasets currently include README text in what gets embedded.
for key, readme_flag in (("packages", True), ("programs", True)):
    index, data = load_and_index_dataset(
        f"zigistry/{key}",
        include_readme=readme_flag,
    )
    indices[key] = (index, data)
 
 
 
63
 
64
def perform_search(query: str, dataset_key: str, k: int):
    """Return up to *k* records nearest to *query* in one indexed dataset.

    Args:
        query: Free-text search string; it is embedded with ``model``.
        dataset_key: Key into the module-level ``indices`` mapping.
        k: Maximum number of results. Values below 1 yield an empty list, and
            values above the number of indexed vectors are clamped.

    Returns:
        A list of record copies, nearest first, each with an added
        ``relevance_score`` key holding a plain Python float.

    Raises:
        KeyError: If *dataset_key* is not present in ``indices``.
    """
    index, dataset = indices[dataset_key]

    # FAISS requires k >= 1. Asking for more neighbours than there are
    # vectors only produces -1 padding, so clamp to the index size.
    k = min(k, index.ntotal)
    if k < 1:
        return []

    query_embedding = np.asarray(model.encode([query]), dtype=np.float32)
    distances, idxs = index.search(query_embedding, k)

    results = []
    for dist, idx in zip(distances[0], idxs[0]):
        # FAISS uses -1 as the label for "no neighbour found".
        if idx == -1:
            continue

        item = dataset[int(idx)].copy()
        # `dist` is a NumPy float32 scalar. Convert it first, because a score
        # that stays float32 cannot be serialized by json.dumps, which is
        # what JSONResponse uses.
        # NOTE(review): 1 - d/2 equals cosine similarity only if the index
        # returns squared L2 distances between unit-length embeddings —
        # confirm the model normalizes its output.
        item['relevance_score'] = 1.0 - (float(dist) / 2.0)
        results.append(item)

    return results
82
 
83
@app.get("/searchPackages/")
def search_packages(q: str, k: int = 10):
    """Return the *k* best "packages" matches for *q* as a JSON response.

    The response carries a wildcard ``Access-Control-Allow-Origin`` header.
    """
    cors_headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
    matches = perform_search(query=q, dataset_key="packages", k=k)
    return JSONResponse(content=matches, headers=cors_headers)
88
 
89
@app.get("/searchPrograms/")
def search_programs(q: str, k: int = 10):
    """Return the *k* best "programs" matches for *q* as a JSON response.

    The response carries a wildcard ``Access-Control-Allow-Origin`` header.
    """
    cors_headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
    matches = perform_search(query=q, dataset_key="programs", k=k)
    return JSONResponse(content=matches, headers=cors_headers)