RohanVashisht committed
Commit cf9a0c1 · verified · 1 Parent(s): 0eeaf9e

Update app.py

Files changed (1): app.py (+69 -58)
app.py CHANGED
@@ -1,10 +1,10 @@
 import faiss
 import numpy as np
-from fastapi import FastAPI
+from fastapi import FastAPI, Query, HTTPException
 from fastapi.responses import JSONResponse
 from datasets import load_dataset
 from sentence_transformers import SentenceTransformer
-from typing import List
+from typing import List, Dict, Tuple

 app = FastAPI()

@@ -29,65 +29,76 @@ print("Loading sentence transformer model (all-MiniLM-L6-v2)...")
 model = SentenceTransformer("all-MiniLM-L6-v2")
 print("Model loaded successfully.")

-def load_and_index_dataset(name: str, include_readme: bool = False):
-    print(f"Loading dataset '{name}'...")
-    dataset = load_dataset(name)["train"]
-
-    repo_texts = [
-        " ".join(str(x.get(field, "")) for field in FIELDS) +
-        (" " + x.get("readme_content", "") if include_readme else "") +
-        " " + " ".join(x.get("topics", []))
-        for x in dataset
-    ]
-
-    if not include_readme:
-        dataset = [{k: v for k, v in item.items() if k != "readme_content"} for item in dataset]
-
-    print(f"Creating embeddings for {len(repo_texts)} documents in '{name}'...")
-    repo_embeddings = model.encode(repo_texts, show_progress_bar=True)
-
-    print(f"Building FAISS index for '{name}'...")
-    embedding_dim = repo_embeddings.shape[1]
-
-    index = faiss.IndexFlatL2(embedding_dim)
-    index.add(np.array(repo_embeddings, dtype=np.float32))
-
-    print(f"'{name}' dataset indexed with {index.ntotal} vectors.")
-    return index, list(dataset)
-
-indices = {}
+def load_and_index_dataset(name: str, include_readme: bool = False) -> Tuple[faiss.IndexFlatL2, List[Dict]]:
+    try:
+        print(f"Loading dataset '{name}'...")
+        dataset = load_dataset(name)["train"]
+
+        repo_texts = [
+            " ".join(str(x.get(field, "")) for field in FIELDS) +
+            (" " + x.get("readme_content", "") if include_readme else "") +
+            " " + " ".join(x.get("topics", []))
+            for x in dataset
+        ]
+
+        if not include_readme:
+            dataset = [{k: v for k, v in item.items() if k != "readme_content"} for item in dataset]
+
+        print(f"Creating embeddings for {len(repo_texts)} documents in '{name}'...")
+        repo_embeddings = model.encode(repo_texts, show_progress_bar=True)
+
+        embedding_dim = repo_embeddings.shape[1]
+        index = faiss.IndexFlatL2(embedding_dim)
+        index.add(np.array(repo_embeddings, dtype=np.float32))
+
+        print(f"'{name}' dataset indexed with {index.ntotal} vectors.")
+        return index, list(dataset)
+    except Exception as e:
+        print(f"Error loading dataset '{name}': {e}")
+        raise RuntimeError(f"Dataset loading/indexing failed: {name}")
+
+indices: Dict[str, Tuple[faiss.IndexFlatL2, List[Dict]]] = {}

 for key, readme_flag in {"packages": True, "programs": True}.items():
-    index, data = load_and_index_dataset(f"zigistry/{key}", include_readme=readme_flag)
-    indices[key] = (index, data)
-
-def perform_search(query: str, dataset_key: str, k: int):
-    index, dataset = indices[dataset_key]
-
-    query_embedding = model.encode([query])
-    query_embedding = np.array(query_embedding, dtype=np.float32)
-
-    distances, idxs = index.search(query_embedding, k)
-
-    results = []
-    for dist, idx in zip(distances[0], idxs[0]):
-        if idx == -1:
-            continue
-
-        item = dataset[int(idx)].copy()
-        item['relevance_score'] = 1.0 - (dist / 2.0)
-        results.append(item)
-
-    return results
+    try:
+        index, data = load_and_index_dataset(f"zigistry/{key}", include_readme=readme_flag)
+        indices[key] = (index, data)
+    except Exception as e:
+        print(f"Failed to prepare index for {key}: {e}")
+        indices[key] = (None, [])
+
+def perform_search(query: str, dataset_key: str, k: int) -> List[Dict]:
+    index, dataset = indices.get(dataset_key, (None, []))
+    if not index:
+        raise HTTPException(status_code=500, detail=f"Index not available for {dataset_key}")
+
+    try:
+        query_embedding = model.encode([query])
+        distances, idxs = index.search(np.array(query_embedding, dtype=np.float32), k)
+
+        results = []
+        for dist, idx in zip(distances[0], idxs[0]):
+            if idx == -1:
+                continue
+            item = dataset[int(idx)].copy()
+            item["relevance_score"] = float(1.0 - dist / 2.0)
+            results.append(item)
+
+        return results
+    except Exception as e:
+        print(f"Error during search: {e}")
+        raise HTTPException(status_code=500, detail="Search failed")

 @app.get("/searchPackages/")
-def search_packages(q: str, k: int = 10):
-    results = perform_search(query=q, dataset_key="packages", k=k)
-    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
-    return JSONResponse(content=results, headers=headers)
+def search_packages(q: str = Query(...), k: int = Query(10)) -> JSONResponse:
+    if not q:
+        raise HTTPException(status_code=400, detail="Query parameter 'q' is required.")
+    results = perform_search(q, "packages", k)
+    return JSONResponse(content=results, headers={"Access-Control-Allow-Origin": "*"})

 @app.get("/searchPrograms/")
-def search_programs(q: str, k: int = 10):
-    results = perform_search(query=q, dataset_key="programs", k=k)
-    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
-    return JSONResponse(content=results, headers=headers)
+def search_programs(q: str = Query(...), k: int = Query(10)) -> JSONResponse:
+    if not q:
+        raise HTTPException(status_code=400, detail="Query parameter 'q' is required.")
+    results = perform_search(q, "programs", k)
+    return JSONResponse(content=results, headers={"Access-Control-Allow-Origin": "*"})
 
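Note on the scoring line: item["relevance_score"] = float(1.0 - dist / 2.0) only reads as cosine similarity under an assumption the diff does not state. faiss.IndexFlatL2 returns squared L2 distances, and for unit-length vectors ||a - b||^2 = 2 - 2*cos(a, b), so 1 - d/2 is exactly the cosine. A minimal sketch of that identity, assuming unit-normalized vectors:

# Sketch: 1 - d/2 equals cosine similarity when d is the SQUARED L2
# distance between unit vectors, which is what faiss.IndexFlatL2 reports.
import numpy as np

rng = np.random.default_rng(0)
a = rng.standard_normal(384)  # 384 = all-MiniLM-L6-v2 embedding width
b = rng.standard_normal(384)
a /= np.linalg.norm(a)  # unit-normalize, as the score formula assumes
b /= np.linalg.norm(b)

sq_l2 = np.sum((a - b) ** 2)  # IndexFlatL2-style distance
assert abs((1.0 - sq_l2 / 2.0) - a @ b) < 1e-9

SentenceTransformer.encode does not normalize by default, so unless the texts are encoded with normalize_embeddings=True, the score is a monotonic ranking proxy rather than a true cosine and can drift outside [0, 1].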
 
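Separately, the handlers still attach Access-Control-Allow-Origin by hand on each response. A sketch of the more conventional FastAPI setup, which also answers CORS preflight (OPTIONS) requests automatically; the wide-open "*" origin below simply mirrors the header already used in this diff:

# Alternative to hand-rolled CORS headers: register the middleware once.
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # same policy as the manual header in the handlers
    allow_methods=["GET"],
    allow_headers=["*"],
)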
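A quick way to exercise the new endpoints without deploying is FastAPI's TestClient. The q and k parameters come from the diff; the "name" field printed below is a guess, since FIELDS is defined outside this hunk. Note also that when k exceeds index.ntotal, FAISS pads the returned ids with -1, which the "if idx == -1: continue" guard in perform_search already skips.

# Hypothetical smoke test; assumes the module is importable as app.py and
# that the zigistry datasets download successfully at import time.
from fastapi.testclient import TestClient

from app import app

client = TestClient(app)
resp = client.get("/searchPackages/", params={"q": "http server", "k": 3})
assert resp.status_code == 200
for hit in resp.json():
    print(hit.get("name"), hit.get("relevance_score"))  # "name" is assumed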