import faiss
import numpy as np
from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from typing import List, Dict

app = FastAPI()

# Constants
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
SEARCH_THRESHOLD_RATIO = 0.3  # keep matches within 30% of the distance spread above the best hit
FIELDS_TO_INCLUDE = [
"full_name", "description", "default_branch", "open_issues",
"stargazers_count", "forks_count", "watchers_count", "license",
"size", "fork", "updated_at", "has_build_zig",
"has_build_zig_zon", "created_at"
]

# Load embedding model
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Helper functions
def prepare_text(entry: Dict, include_readme: bool = True) -> str:
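    """Concatenate the selected metadata fields, the repo topics, and (optionally) the README into one text blob for embedding."""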
parts = [str(entry.get(field, "")) for field in FIELDS_TO_INCLUDE]
if include_readme:
parts.append(entry.get("readme_content", ""))
parts.extend(entry.get("topics", []))
return " ".join(parts)

def load_and_encode_dataset(name: str, include_readme: bool = True):
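    """Load the "train" split of a Hugging Face dataset and embed every entry with the sentence-transformer model."""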
raw_dataset = load_dataset(name)["train"]
texts = [prepare_text(item, include_readme) for item in raw_dataset]
embeddings = model.encode(texts)
return raw_dataset, np.array(embeddings)

def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatL2:
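    """Build an exact (flat) L2 FAISS index over the embedding matrix."""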
index = faiss.IndexFlatL2(embeddings.shape[1])
index.add(embeddings)
return index

def search_index(index: faiss.IndexFlatL2, query: str, embeddings: np.ndarray, dataset: List[Dict]) -> List[Dict]:
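    """Embed the query, rank all dataset entries by L2 distance, and return those that pass the
    distance filter. The `embeddings` argument is accepted but not used here."""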
query_vector = model.encode([query])
distances, indices = index.search(np.array(query_vector), len(dataset))
return filter_by_distance(distances[0], indices[0], dataset)

def filter_by_distance(distances: np.ndarray, indices: np.ndarray, dataset: List[Dict], ratio: float = SEARCH_THRESHOLD_RATIO) -> List[Dict]:
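    """Keep entries whose distance lies within `ratio` of the spread between the closest and farthest match."""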
if len(distances) == 0:
return []
min_d, max_d = np.min(distances), np.max(distances)
threshold = min_d + (max_d - min_d) * ratio
    # Cast FAISS's numpy int64 indices to plain ints before indexing the dataset
    return [dataset[int(i)] for d, i in zip(distances, indices) if d <= threshold]

# Load datasets and create indices
data_configs = {
"packages": "zigistry/packages",
"programs": "zigistry/programs"
}
data_store = {}
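# Build a FAISS index per category and keep the raw dataset, index, and embeddings together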
for key, dataset_name in data_configs.items():
dataset, embeddings = load_and_encode_dataset(dataset_name, include_readme=True)
index = build_faiss_index(embeddings)
data_store[key] = {
"dataset": dataset,
"index": index,
"embeddings": embeddings
}

# FastAPI endpoints
@app.get("/search/{category}/")
def search(category: str, q: str = Query(...)):
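    """Semantic search over one category ("packages" or "programs"); returns the matching entries as JSON."""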
if category not in data_store:
return JSONResponse(status_code=404, content={"error": "Invalid category"})
store = data_store[category]
results = search_index(store["index"], q, store["embeddings"], store["dataset"])
headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
    return JSONResponse(content=results, headers=headers)