import faiss
import numpy as np
from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from typing import List, Dict

app = FastAPI()

EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
# A hit must fall within the lowest 30% of the query's distance range
# (see filter_by_distance below).
SEARCH_THRESHOLD_RATIO = 0.3
# Repository metadata fields concatenated into the text that gets embedded.
FIELDS_TO_INCLUDE = [
    "full_name", "description", "default_branch", "open_issues",
    "stargazers_count", "forks_count", "watchers_count", "license",
    "size", "fork", "updated_at", "has_build_zig",
    "has_build_zig_zon", "created_at",
]

# all-MiniLM-L6-v2 maps text to 384-dimensional sentence embeddings.
model = SentenceTransformer(EMBEDDING_MODEL_NAME)


def prepare_text(entry: Dict, include_readme: bool = True) -> str:
    """Flatten one repository entry into a single string for embedding."""
    parts = [str(entry.get(field, "")) for field in FIELDS_TO_INCLUDE]
    if include_readme:
        # Defensive fallback: treat a missing or null readme as empty text.
        parts.append(entry.get("readme_content", "") or "")
    # Defensive fallback: "topics" may be absent or null.
    parts.extend(entry.get("topics", []) or [])
    return " ".join(parts)


def load_and_encode_dataset(name: str, include_readme: bool = True):
    """Load the Hugging Face dataset and embed every entry."""
    raw_dataset = load_dataset(name)["train"]
    texts = [prepare_text(item, include_readme) for item in raw_dataset]
    embeddings = model.encode(texts)
    # FAISS expects float32 vectors.
    return raw_dataset, np.array(embeddings, dtype=np.float32)


def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatL2:
    """Build an exact (brute-force) L2 index over the embeddings."""
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index


def search_index(index: faiss.IndexFlatL2, query: str, dataset: List[Dict]) -> List[Dict]:
    """Embed the query, rank every entry by distance, then filter."""
    query_vector = model.encode([query])
    # Ask for every entry; filter_by_distance decides what to keep.
    distances, indices = index.search(np.array(query_vector, dtype=np.float32), len(dataset))
    return filter_by_distance(distances[0], indices[0], dataset)


def filter_by_distance(distances: np.ndarray, indices: np.ndarray, dataset: List[Dict], ratio: float = SEARCH_THRESHOLD_RATIO) -> List[Dict]:
    """Keep results whose distance lies within `ratio` of the distance range."""
    if len(distances) == 0:
        return []
    min_d, max_d = np.min(distances), np.max(distances)
    threshold = min_d + (max_d - min_d) * ratio
    # Cast FAISS's int64 indices to plain ints before indexing the dataset.
    return [dataset[int(i)] for d, i in zip(distances, indices) if d <= threshold]
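
# Worked example of the cutoff: with distances [0.20, 0.50, 1.00] and
# ratio 0.3, threshold = 0.20 + (1.00 - 0.20) * 0.3 = 0.44, so only the
# entry at distance 0.20 is returned.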


# Map URL categories to their Hugging Face dataset names.
data_configs = {
    "packages": "zigistry/packages",
    "programs": "zigistry/programs",
}

data_store = {}

# Load, embed, and index each dataset once at startup.
for key, dataset_name in data_configs.items():
    dataset, embeddings = load_and_encode_dataset(dataset_name, include_readme=True)
    index = build_faiss_index(embeddings)
    data_store[key] = {
        "dataset": dataset,
        "index": index,
        "embeddings": embeddings,
    }


@app.get("/search/{category}/")
def search(category: str, q: str = Query(...)):
    if category not in data_store:
        return JSONResponse(status_code=404, content={"error": "Invalid category"})

    store = data_store[category]
    results = search_index(store["index"], q, store["dataset"])
    # The wildcard CORS header lets any origin query the API directly.
    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
    return JSONResponse(content=results, headers=headers)
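

# A minimal way to serve the app locally; assumes uvicorn is installed
# (an assumption, not stated above). Example request once running:
#   curl "http://localhost:8000/search/packages/?q=http%20client"
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="127.0.0.1", port=8000)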