RohanVashisht committed on
Commit
0861ec7
·
verified ·
1 Parent(s): 2a48d44

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +65 -0
app.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import faiss
2
+ import numpy as np
3
+ from fastapi import FastAPI, Query
4
+ from datasets import load_dataset
5
+ from sentence_transformers import SentenceTransformer
6
+
7
# ASGI application object; the route handlers in this module register on it.
app = FastAPI()

# Row keys whose values are stringified and joined into the text that is
# embedded for each repository.
FIELDS = (
    "full_name",
    "description",
    "watchers_count",
    "forks_count",
    "license",
    "default_branch",
    "has_build_zig",
    "has_build_zig_zon",
    "fork",
    "open_issues",
    "stargazers_count",
    "updated_at",
    "created_at",
    "size",
)

# Sentence-embedding model shared by corpus indexing and query encoding.
model = SentenceTransformer("all-MiniLM-L6-v2")
17
+
18
def load_dataset_with_fields(name, include_readme=False):
    """Load the ``train`` split of a dataset and build one search text per row.

    Args:
        name: Dataset identifier passed to ``load_dataset``.
        include_readme: When true, each row's ``readme_content`` is appended
            to its search text and the rows are returned untouched. When
            false, the README is left out of the text and the
            ``readme_content`` key is stripped from every returned row.

    Returns:
        A ``(dataset, repo_texts)`` pair. ``dataset`` is the loaded split when
        ``include_readme`` is true, otherwise a list of dicts without the
        ``readme_content`` key. ``repo_texts[i]`` is the text built from row
        ``i``: the ``FIELDS`` values, optionally the README, then the topics,
        separated by single spaces.
    """
    dataset = load_dataset(name)["train"]

    def row_text(row):
        # ``dict.get(key, default)`` only falls back when the key is absent;
        # a key that is present with a null value yields None, so ``or`` is
        # used to coerce missing README/topics to an empty value.
        text = " ".join(str(row.get(field, "")) for field in FIELDS)
        if include_readme:
            text += " " + (row.get("readme_content") or "")
        return text + " " + " ".join(row.get("topics") or [])

    repo_texts = [row_text(row) for row in dataset]

    if not include_readme:
        # Drop the README payload from the rows handed back to callers.
        dataset = [
            {k: v for k, v in item.items() if k != "readme_content"}
            for item in dataset
        ]
    return dataset, repo_texts
29
+
30
# Full rows (README included) plus their search texts, keyed by category.
datasets = {
    "packages": load_dataset_with_fields("zigistry/packages", include_readme=True),
    "programs": load_dataset_with_fields("zigistry/programs", include_readme=True),
}

# One exact L2 index per category, paired with the rows it was built from so
# that a search hit at position i maps back to dataset[i].
indices = {}
for key, (dataset, repo_texts) in datasets.items():
    repo_embeddings = model.encode(repo_texts)
    index = faiss.IndexFlatL2(repo_embeddings.shape[1])
    index.add(np.array(repo_embeddings))
    indices[key] = (index, dataset)


def _without_readme(rows):
    """Return the rows as plain dicts with the ``readme_content`` key removed."""
    return [{k: v for k, v in row.items() if k != "readme_content"} for row in rows]


# README-free rows served page by page. They are derived from the splits
# loaded above instead of loading each dataset and rebuilding its text corpus
# a second time.
scroll_data = {
    "infiniteScrollPackages": _without_readme(datasets["packages"][0]),
    "infiniteScrollPrograms": _without_readme(datasets["programs"][0]),
}
46
+
47
@app.get("/fetch_data/")
def fetch_data(category: str, page_number: int = Query(0, ge=0)):
    """Return one page of ten rows from an infinite-scroll listing.

    Args:
        category: Key into ``scroll_data`` selecting which listing to page.
        page_number: Zero-based page index; validated as non-negative.

    Returns:
        The slice of up to ten rows for the requested page (empty once the
        page is past the end), or ``{"error": "Invalid category"}`` when
        ``category`` is not a key of ``scroll_data``.
    """
    rows = scroll_data.get(category)
    if rows is None:
        return {"error": "Invalid category"}
    offset = 10 * page_number
    return rows[offset:offset + 10]
53
+
54
@app.get("/search_repositories/")
def search_repositories(category: str, query: str):
    """Return the repositories whose embeddings are closest to ``query``.

    The query is embedded with the shared model and matched against the L2
    index for the selected category. A row is kept when its distance is at
    most 1.5 times the distance of the best match; at most 280 rows are
    returned, nearest first.

    Args:
        category: ``"SearchPackages"`` selects the packages index; any other
            value selects the programs index.
        query: Free-text search string.

    Returns:
        A list of matching dataset rows (empty when the category has no
        rows), or ``{"error": "Invalid category"}`` if the selected index is
        missing.
    """
    max_results = 280

    # NOTE(review): every value other than "SearchPackages" falls through to
    # "programs", so the error branch below can only fire if an index is
    # missing, never for a mistyped category. Routing is left as-is because
    # the category strings clients send are not visible here — confirm and
    # tighten to an explicit mapping.
    key = "packages" if category == "SearchPackages" else "programs"
    if key not in indices:
        return {"error": "Invalid category"}
    index, dataset = indices[key]

    # Hits come back ordered by increasing distance (the original relied on
    # this by treating the first hit as the minimum), so the rows passing the
    # threshold form a prefix and fetching only the first ``max_results``
    # neighbours gives the same answer as searching the whole index and
    # truncating afterwards.
    k = min(len(dataset), max_results)
    if k == 0:
        # Nothing indexed: there is no best match to derive a threshold from.
        return []

    query_embedding = model.encode([query])
    distances, indices_ = index.search(np.array(query_embedding), k)
    threshold = distances[0][0] * 1.5
    return [
        dataset[int(i)]
        for d, i in zip(distances[0], indices_[0])
        if d <= threshold
    ]