"""Semantic search API over the zigistry package and program datasets, backed by FAISS and a sentence-transformers model."""

import faiss
import numpy as np
from fastapi import FastAPI, Query, HTTPException
from fastapi.responses import JSONResponse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from typing import List, Dict, Tuple

app = FastAPI()
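
# Repository metadata fields concatenated into the text that is embedded for each entry.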
FIELDS = (
    "full_name",
    "description",
    "default_branch",
    "open_issues",
    "stargazers_count",
    "forks_count",
    "watchers_count",
    "license",
    "size",
    "fork",
    "updated_at",
    "has_build_zig",
    "has_build_zig_zon",
    "created_at",
)
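
# Shared embedding model; the same model encodes both the indexed documents and incoming queries.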
print("Loading sentence transformer model (all-MiniLM-L6-v2)...")
model = SentenceTransformer("all-MiniLM-L6-v2")
print("Model loaded successfully.")


def load_and_index_dataset(name: str, include_readme: bool = False) -> Tuple[faiss.IndexFlatL2, List[Dict]]:
    """Load a Hugging Face dataset, embed every entry, and return a FAISS L2 index plus the raw records."""
    try:
        print(f"Loading dataset '{name}'...")
        dataset = load_dataset(name)["train"]

        # Build one text document per repository: metadata fields, optionally the README, and the topic list.
        repo_texts = [
            " ".join(str(x.get(field, "")) for field in FIELDS)
            + (" " + (x.get("readme_content") or "") if include_readme else "")
            + " " + " ".join(x.get("topics") or [])
            for x in dataset
        ]

        # When the README is not indexed, drop it from the records returned to clients.
        if not include_readme:
            dataset = [{k: v for k, v in item.items() if k != "readme_content"} for item in dataset]

        print(f"Creating embeddings for {len(repo_texts)} documents in '{name}'...")
        repo_embeddings = model.encode(repo_texts, show_progress_bar=True)

        embedding_dim = repo_embeddings.shape[1]
        index = faiss.IndexFlatL2(embedding_dim)
        index.add(np.array(repo_embeddings, dtype=np.float32))

        print(f"'{name}' dataset indexed with {index.ntotal} vectors.")
        return index, list(dataset)
    except Exception as e:
        print(f"Error loading dataset '{name}': {e}")
        raise RuntimeError(f"Dataset loading/indexing failed: {name}") from e
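

# Build one in-memory index per dataset at startup; a failed load leaves a (None, []) placeholder
# so the API can still start and report the problem at query time.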
indices: Dict[str, Tuple[faiss.IndexFlatL2, List[Dict]]] = {}

for key, readme_flag in {"packages": True, "programs": True}.items():
    try:
        index, data = load_and_index_dataset(f"zigistry/{key}", include_readme=readme_flag)
        indices[key] = (index, data)
    except Exception as e:
        print(f"Failed to prepare index for {key}: {e}")
        indices[key] = (None, [])


def perform_search(query: str, dataset_key: str, k: int) -> List[Dict]:
    """Encode the query, search the requested index, and return the k nearest records."""
    index, dataset = indices.get(dataset_key, (None, []))
    if index is None:
        raise HTTPException(status_code=500, detail=f"Index not available for {dataset_key}")

    try:
        query_embedding = model.encode([query])
        distances, idxs = index.search(np.array(query_embedding, dtype=np.float32), k)

        results = []
        for dist, idx in zip(distances[0], idxs[0]):
            if idx == -1:  # FAISS pads results with -1 when fewer than k vectors are indexed
                continue
            item = dataset[int(idx)].copy()
            # IndexFlatL2 returns squared L2 distances; fold them into a rough relevance score.
            item["relevance_score"] = float(1.0 - dist / 2.0)
            results.append(item)

        return results
    except Exception as e:
        print(f"Error during search: {e}")
        raise HTTPException(status_code=500, detail="Search failed")
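

# Search endpoints. CORS is handled per response via a wildcard Access-Control-Allow-Origin header.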
@app.get("/searchPackages/")
def search_packages(q: str = Query(...), k: int = Query(10)) -> JSONResponse:
    if not q:
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required.")
    results = perform_search(q, "packages", k)
    return JSONResponse(content=results, headers={"Access-Control-Allow-Origin": "*"})


@app.get("/searchPrograms/")
def search_programs(q: str = Query(...), k: int = Query(10)) -> JSONResponse:
    if not q:
        raise HTTPException(status_code=400, detail="Query parameter 'q' is required.")
    results = perform_search(q, "programs", k)
    return JSONResponse(content=results, headers={"Access-Control-Allow-Origin": "*"})
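

# Example usage (assuming this file is saved as main.py; the filename is illustrative):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl "http://localhost:8000/searchPackages/?q=http%20server&k=5"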