import faiss
import numpy as np
from fastapi import FastAPI
from fastapi.responses import JSONResponse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
app = FastAPI()
# Metadata fields from each repository record that are concatenated into the text used for embedding.
FIELDS = (
    "full_name",
    "description",
    "default_branch",
    "open_issues",
    "stargazers_count",
    "forks_count",
    "watchers_count",
    "license",
    "size",
    "fork",
    "updated_at",
    "has_build_zig",
    "has_build_zig_zon",
    "created_at",
)
print("Loading sentence transformer model (all-MiniLM-L6-v2)...")
model = SentenceTransformer("all-MiniLM-L6-v2")
print("Model loaded successfully.")
def load_and_index_dataset(name: str, include_readme: bool = False):
    """Load a Hugging Face dataset, embed each repository record, and build a FAISS index over it."""
    print(f"Loading dataset '{name}'...")
    dataset = load_dataset(name)["train"]
    # Concatenate the selected metadata fields, the topics list, and optionally the README into one text per repository.
    repo_texts = [
        " ".join(str(x.get(field, "")) for field in FIELDS) +
        (" " + x.get("readme_content", "") if include_readme else "") +
        " " + " ".join(x.get("topics", []))
        for x in dataset
    ]
    if not include_readme:
        # Drop the README body from the records returned to clients to keep responses small.
        dataset = [{k: v for k, v in item.items() if k != "readme_content"} for item in dataset]
    print(f"Creating embeddings for {len(repo_texts)} documents in '{name}'...")
    repo_embeddings = model.encode(repo_texts, show_progress_bar=True)
    print(f"Building FAISS index for '{name}'...")
    embedding_dim = repo_embeddings.shape[1]
    index = faiss.IndexFlatL2(embedding_dim)
    index.add(np.array(repo_embeddings, dtype=np.float32))
    print(f"'{name}' dataset indexed with {index.ntotal} vectors.")
    return index, list(dataset)
# Build one index per corpus; both include README content in their embeddings.
indices = {}
for key, readme_flag in {"packages": True, "programs": True}.items():
    index, data = load_and_index_dataset(f"zigistry/{key}", include_readme=readme_flag)
    indices[key] = (index, data)
def perform_search(query: str, dataset_key: str, k: int):
    """Embed the query and return the k nearest repositories from the requested index."""
    index, dataset = indices[dataset_key]
    query_embedding = model.encode([query])
    query_embedding = np.array(query_embedding, dtype=np.float32)
    distances, idxs = index.search(query_embedding, k)
    results = []
    for dist, idx in zip(distances[0], idxs[0]):
        if idx == -1:
            # FAISS pads results with -1 when fewer than k vectors are available.
            continue
        item = dataset[int(idx)].copy()
        # Map the L2 distance to a rough relevance score (smaller distance -> higher score);
        # cast to a plain float so the value is JSON-serializable.
        item['relevance_score'] = 1.0 - (float(dist) / 2.0)
        results.append(item)
    return results
@app.get("/searchPackages/")
def search_packages(q: str, k: int = 10):
results = perform_search(query=q, dataset_key="packages", k=k)
headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
return JSONResponse(content=results, headers=headers)
@app.get("/searchPrograms/")
def search_programs(q: str, k: int = 10):
results = perform_search(query=q, dataset_key="programs", k=k)
headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
return JSONResponse(content=results, headers=headers) |
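
# A minimal sketch of how this service could be launched and queried locally.
# The host, port, and example query below are illustrative assumptions, not part of the
# original deployment setup; in practice the app may be served directly via `uvicorn`.
if __name__ == "__main__":
    import uvicorn

    # Example request once the server is up (hypothetical URL/port):
    #   curl "http://127.0.0.1:8000/searchPackages/?q=http%20server&k=5"
    uvicorn.run(app, host="127.0.0.1", port=8000)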