import faiss | |
import numpy as np | |
from fastapi import FastAPI, Query | |
from fastapi.middleware.cors import CORSMiddleware | |
from datasets import load_dataset | |
from sentence_transformers import SentenceTransformer | |
app = FastAPI()

# Permissive CORS: every origin, HTTP method and request header is accepted.
# NOTE(review): combining a wildcard origin with allow_credentials=True is
# discouraged by the CORS spec and the Starlette docs; list explicit origins
# in allow_origins (or turn credentials off) if authenticated endpoints are added.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)
# Row columns whose string forms are space-joined, in this order, at the
# start of each repository's search text.
FIELDS = (
    "full_name",
    "description",
    "watchers_count",
    "forks_count",
    "license",
    "default_branch",
    "has_build_zig",
    "has_build_zig_zon",
    "fork",
    "open_issues",
    "stargazers_count",
    "updated_at",
    "created_at",
    "size",
)
# Sentence-embedding model used both to embed the repository texts when the
# indices are built and to embed incoming search queries.
model = SentenceTransformer(model_name_or_path="all-MiniLM-L6-v2")
def load_dataset_with_fields(name, include_readme=False):
    """Load the ``train`` split of dataset *name* and build one search text per row.

    Each text is the string form of every column in ``FIELDS`` (space-joined),
    optionally followed by the row's ``readme_content``, followed by its topics.

    Args:
        name: Dataset identifier passed to ``load_dataset``.
        include_readme: When true, the README is appended to each text and the
            loaded split is returned unchanged. When false, the README is left
            out of the texts and the rows are returned as plain dicts with the
            ``readme_content`` key removed.

    Returns:
        ``(rows, repo_texts)`` where ``repo_texts[i]`` describes ``rows[i]``.
    """
    dataset = load_dataset(name)["train"]

    def _repo_text(row):
        # Build the search text for a single row.
        text = " ".join(str(row.get(field, "")) for field in FIELDS)
        if include_readme:
            # `or ""` also covers a column that is present but None, which
            # previously raised TypeError on string concatenation.
            text += " " + (row.get("readme_content") or "")
        topics = row.get("topics") or []  # None-valued topics would break join()
        return text + " " + " ".join(str(topic) for topic in topics)

    repo_texts = [_repo_text(row) for row in dataset]
    if not include_readme:
        dataset = [
            {k: v for k, v in row.items() if k != "readme_content"}
            for row in dataset
        ]
    return dataset, repo_texts
# Per-category (rows, repo_texts) pairs. README text is included so that it
# contributes to the embeddings; these rows therefore still carry readme_content.
datasets = {
    "packages": load_dataset_with_fields("zigistry/packages", include_readme=True),
    "programs": load_dataset_with_fields("zigistry/programs", include_readme=True),
}

# One exact (brute-force) L2 index per category over the embedded repo texts,
# stored alongside the rows its ids refer to.
indices = {}
for key, (dataset, repo_texts) in datasets.items():
    repo_embeddings = model.encode(repo_texts)
    index = faiss.IndexFlatL2(repo_embeddings.shape[1])
    index.add(np.array(repo_embeddings))
    indices[key] = (index, dataset)

# README-free row lists served by the infinite-scroll handlers. They are
# derived from the rows already loaded above rather than loading each dataset
# a second time (which also rebuilt and then discarded every search text).
scroll_data = {
    "infiniteScrollPackages": [
        {k: v for k, v in row.items() if k != "readme_content"}
        for row in datasets["packages"][0]
    ],
    "infiniteScrollPrograms": [
        {k: v for k, v in row.items() if k != "readme_content"}
        for row in datasets["programs"][0]
    ],
}
# Number of rows returned per infinite-scroll request.
SCROLL_PAGE_SIZE = 10


def _scroll_page(key, page):
    """Return the 0-based *page* of ``scroll_data[key]``, ``SCROLL_PAGE_SIZE`` rows long.

    Pages past the end yield an empty list because slicing clamps to the length.
    """
    start = page * SCROLL_PAGE_SIZE
    return scroll_data[key][start : start + SCROLL_PAGE_SIZE]


# NOTE(review): no route decorator is visible on these handlers; confirm they
# are registered on `app` somewhere (e.g. app.get / app.add_api_route).
def infinite_scroll_packages(q: int = Query(0, ge=0)):
    """Return page *q* of the README-free package rows (FastAPI enforces q >= 0)."""
    return _scroll_page("infiniteScrollPackages", q)


def infinite_scroll_programs(q: int = Query(0, ge=0)):
    """Return page *q* of the README-free program rows (FastAPI enforces q >= 0)."""
    return _scroll_page("infiniteScrollPrograms", q)
def _drop_readme(row):
    """Return a shallow copy of *row* without its ``readme_content`` key."""
    # Copy instead of pop() so a row object held by the index store is never mutated.
    return {k: v for k, v in row.items() if k != "readme_content"}


def _search_index(key, q, ratio=1.5, limit=280):
    """Semantic search over ``indices[key]``.

    Embeds *q* with ``model``, ranks every indexed row by L2 distance (faiss
    IndexFlatL2 reports squared distances), keeps rows whose distance is at
    most ``ratio`` times the best one, and returns up to ``limit`` of them,
    closest first, without README text.

    Returns an empty list when the index holds no rows.
    """
    index, dataset = indices[key]
    if len(dataset) == 0:
        # No rows: there is no best distance to anchor the threshold on
        # (the old code crashed indexing distances[0][0]).
        return []

    query_embedding = np.asarray(model.encode([q]))
    distances, row_ids = index.search(query_embedding, len(dataset))
    # faiss returns hits sorted by ascending distance, so [0][0] is the best.
    # NOTE(review): if the best distance is 0 (an exact duplicate of the query
    # text), the threshold is 0 and only exact duplicates are returned.
    threshold = distances[0][0] * ratio

    results = []
    for distance, row_id in zip(distances[0], row_ids[0]):
        if len(results) >= limit:
            break  # stop fetching rows once the result cap is reached
        # faiss pads with id -1 when it has fewer than k hits.
        if row_id >= 0 and distance <= threshold:
            results.append(_drop_readme(dataset[int(row_id)]))
    return results


# NOTE(review): no route decorator is visible on these handlers; confirm they
# are registered on `app` somewhere (e.g. app.get / app.add_api_route).
def search_packages(q: str):
    """Return packages semantically close to the query text *q*."""
    return _search_index("packages", q)


def search_programs(q: str):
    """Return programs semantically close to the query text *q*."""
    return _search_index("programs", q)