"""FastAPI service for semantic search over the zigistry package/program datasets.

At import time the module loads a SentenceTransformer model, downloads each
dataset listed in ``data_configs``, embeds every row and builds one exact-L2
FAISS index per category. ``GET /search/{category}/?q=...`` then returns the
rows whose distance to the query lies within the closest
``SEARCH_THRESHOLD_RATIO`` fraction of that query's observed distance range.
"""

from typing import Dict, List

import faiss
import numpy as np
from datasets import load_dataset
from fastapi import FastAPI, Query
from fastapi.encoders import jsonable_encoder
from fastapi.responses import JSONResponse
from sentence_transformers import SentenceTransformer

app = FastAPI()

# Constants
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
# Fraction of the per-query [min, max] distance range that still counts as a
# match: 0.0 keeps only the closest hit(s), 1.0 keeps every row.
SEARCH_THRESHOLD_RATIO = 0.3
# Row fields concatenated (in this order) into the text that gets embedded.
FIELDS_TO_INCLUDE = [
    "full_name",
    "description",
    "default_branch",
    "open_issues",
    "stargazers_count",
    "forks_count",
    "watchers_count",
    "license",
    "size",
    "fork",
    "updated_at",
    "has_build_zig",
    "has_build_zig_zon",
    "created_at",
]
# Attached to every response (success and error) so browser front-ends served
# from other origins can read the API's replies.
CORS_HEADERS = {"Access-Control-Allow-Origin": "*"}

# Load embedding model
model = SentenceTransformer(EMBEDDING_MODEL_NAME)


# Helper functions
def _to_text(value) -> str:
    """Render a row value as text, mapping missing/null values to ""."""
    return "" if value is None else str(value)


def prepare_text(entry: Dict, include_readme: bool = True) -> str:
    """Flatten one dataset row into a single space-separated string for embedding.

    The fields named in ``FIELDS_TO_INCLUDE`` come first, then (optionally) the
    ``readme_content`` field, then every entry of ``topics``.

    Args:
        entry: A dataset row, read via ``entry.get(...)``.
        include_readme: Whether to append the README text.

    Returns:
        The concatenated text. Missing or null values contribute an empty
        string rather than the literal text "None".
    """
    parts = [_to_text(entry.get(field)) for field in FIELDS_TO_INCLUDE]
    if include_readme:
        parts.append(_to_text(entry.get("readme_content")))
    # ``or []`` also covers rows where "topics" exists but is null.
    parts.extend(_to_text(topic) for topic in entry.get("topics") or [])
    return " ".join(parts)


def load_and_encode_dataset(name: str, include_readme: bool = True):
    """Load the "train" split of Hugging Face dataset *name* and embed every row.

    Args:
        name: Dataset identifier passed to ``datasets.load_dataset``.
        include_readme: Forwarded to ``prepare_text`` for every row.

    Returns:
        ``(dataset, embeddings)`` where ``embeddings[i]`` is the float32
        embedding of ``prepare_text(dataset[i], include_readme)``.
    """
    raw_dataset = load_dataset(name)["train"]
    texts = [prepare_text(item, include_readme) for item in raw_dataset]
    embeddings = model.encode(texts)
    # FAISS only accepts float32 vectors; make the dtype explicit.
    return raw_dataset, np.asarray(embeddings, dtype=np.float32)


def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatL2:
    """Build an exact (brute-force) L2 index holding every row of *embeddings*.

    Args:
        embeddings: Matrix with one embedding per row.

    Returns:
        A ``faiss.IndexFlatL2`` whose vector ids match the row positions.
    """
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    index = faiss.IndexFlatL2(vectors.shape[1])
    index.add(vectors)
    return index


def search_index(
    index: faiss.IndexFlatL2,
    query: str,
    embeddings: np.ndarray,
    dataset: List[Dict],
) -> List[Dict]:
    """Rank every indexed row against *query* and keep the closest ones.

    Args:
        index: Index built from the embeddings of *dataset*.
        query: Free-text search query.
        embeddings: Unused here; kept so existing callers keep working.
        dataset: Rows aligned with the index's vector ids.

    Returns:
        Matching rows ordered by increasing distance (see ``filter_by_distance``).
    """
    # Never ask for more neighbours than are stored. Otherwise FAISS pads the
    # result with -1 ids. An empty index has nothing to search.
    k = min(len(dataset), index.ntotal)
    if k <= 0:
        return []
    query_vector = np.asarray(model.encode([query]), dtype=np.float32)
    distances, indices = index.search(query_vector, k)
    return filter_by_distance(distances[0], indices[0], dataset)


def filter_by_distance(
    distances: np.ndarray,
    indices: np.ndarray,
    dataset: List[Dict],
    ratio: float = SEARCH_THRESHOLD_RATIO,
) -> List[Dict]:
    """Keep rows whose distance is within the closest *ratio* of the distance range.

    The threshold is ``min_d + (max_d - min_d) * ratio``, computed over this
    query's results only, so it adapts to how spread out the distances are.

    Args:
        distances: Distances returned by FAISS for one query.
        indices: Row ids aligned with *distances*; ``-1`` marks an empty slot.
        dataset: Rows addressed by *indices*.
        ratio: Fraction of the distance range to accept.

    Returns:
        The selected rows, in the same order as *distances*/*indices*.
    """
    distances = np.asarray(distances)
    indices = np.asarray(indices)
    # Drop FAISS's -1 padding. Its placeholder distances would inflate max_d,
    # and -1 would silently address the last row of the dataset.
    valid = indices >= 0
    distances, indices = distances[valid], indices[valid]
    if distances.size == 0:
        return []
    min_d, max_d = distances.min(), distances.max()
    threshold = min_d + (max_d - min_d) * ratio
    # int() turns numpy integers into plain ints before indexing the dataset.
    return [dataset[int(i)] for d, i in zip(distances, indices) if d <= threshold]


# Load datasets and create indices
data_configs = {
    "packages": "zigistry/packages",
    "programs": "zigistry/programs",
}

data_store = {}
for key, dataset_name in data_configs.items():
    dataset, embeddings = load_and_encode_dataset(dataset_name, include_readme=True)
    index = build_faiss_index(embeddings)
    data_store[key] = {
        "dataset": dataset,
        "index": index,
        "embeddings": embeddings,
    }


# FastAPI endpoints
@app.get("/search/{category}/")
def search(category: str, q: str = Query(...)):
    """Semantic search within one category (a key of ``data_configs``).

    Returns a 404 JSON error ``{"error": "Invalid category"}`` for unknown
    categories. Otherwise it returns the JSON list of matching dataset rows,
    closest first.
    """
    store = data_store.get(category)
    if store is None:
        # Send the CORS header on errors too, so browsers can read this body.
        return JSONResponse(
            status_code=404,
            content={"error": "Invalid category"},
            headers=CORS_HEADERS,
        )
    results = search_index(store["index"], q, store["embeddings"], store["dataset"])
    # Convert any non-JSON-native values a row might hold (e.g. datetimes, if
    # the dataset exposes any -- unverified) before JSONResponse serializes it.
    return JSONResponse(content=jsonable_encoder(results), headers=CORS_HEADERS)