import faiss
import numpy as np
from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from typing import List, Dict

app = FastAPI()
# Constants
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"
SEARCH_THRESHOLD_RATIO = 0.3
FIELDS_TO_INCLUDE = [
    "full_name", "description", "default_branch", "open_issues",
    "stargazers_count", "forks_count", "watchers_count", "license",
    "size", "fork", "updated_at", "has_build_zig",
    "has_build_zig_zon", "created_at",
]
# Load embedding model
model = SentenceTransformer(EMBEDDING_MODEL_NAME)
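# Note: all-MiniLM-L6-v2 produces 384-dimensional sentence embeddings and is
# downloaded on first use; loading it once at import time keeps request latency low.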
# Helper functions
def prepare_text(entry: Dict, include_readme: bool = True) -> str:
    """Concatenate an entry's metadata fields (and optionally its README and topics) into one searchable string."""
    parts = [str(entry.get(field, "")) for field in FIELDS_TO_INCLUDE]
    if include_readme:
        parts.append(entry.get("readme_content", ""))
    parts.extend(entry.get("topics", []))
    return " ".join(parts)
def load_and_encode_dataset(name: str, include_readme: bool = True):
    """Load the train split of a Hugging Face dataset and embed every entry."""
    raw_dataset = load_dataset(name)["train"]
    texts = [prepare_text(item, include_readme) for item in raw_dataset]
    embeddings = model.encode(texts)
    # FAISS expects float32 vectors; SentenceTransformer returns float32 by
    # default, but the explicit cast guards against other dtypes.
    return raw_dataset, np.asarray(embeddings, dtype=np.float32)
def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatL2:
    # Exact (brute-force) L2 index, sized to the embedding dimension.
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index
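
# Usage sketch (not executed here): since all-MiniLM-L6-v2 emits 384-wide
# vectors, an index built from them accepts 384-wide float32 queries, e.g.:
#   vecs = np.random.rand(10, 384).astype(np.float32)
#   idx = build_faiss_index(vecs)
#   distances, indices = idx.search(vecs[:1], 3)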
def search_index(index: faiss.IndexFlatL2, query: str, dataset: List[Dict]) -> List[Dict]:
    """Embed the query, rank every dataset entry by L2 distance, and keep the closest ones."""
    query_vector = model.encode([query])
    # k = len(dataset): retrieve every entry, then prune by the distance threshold below.
    distances, indices = index.search(np.asarray(query_vector, dtype=np.float32), len(dataset))
    return filter_by_distance(distances[0], indices[0], dataset)
def filter_by_distance(distances: np.ndarray, indices: np.ndarray, dataset: List[Dict], ratio: float = SEARCH_THRESHOLD_RATIO) -> List[Dict]:
    """Keep only results whose distance falls within `ratio` of the observed distance range."""
    if len(distances) == 0:
        return []
    min_d, max_d = np.min(distances), np.max(distances)
    threshold = min_d + (max_d - min_d) * ratio
    # int(i) guards against numpy integer indices, which some Dataset versions reject.
    return [dataset[int(i)] for d, i in zip(distances, indices) if d <= threshold]
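
# Worked example: for distances [0.2, 0.5, 1.0] and ratio 0.3 the threshold is
# 0.2 + (1.0 - 0.2) * 0.3 = 0.44, so only the entry at distance 0.2 survives.
# A larger SEARCH_THRESHOLD_RATIO therefore returns more (looser) matches.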
# Load datasets and create indices once at startup
data_configs = {
    "packages": "zigistry/packages",
    "programs": "zigistry/programs",
}
data_store = {}
for key, dataset_name in data_configs.items():
    dataset, embeddings = load_and_encode_dataset(dataset_name, include_readme=True)
    index = build_faiss_index(embeddings)
    data_store[key] = {
        "dataset": dataset,
        "index": index,
        "embeddings": embeddings,
    }
# FastAPI endpoints
@app.get("/search/{category}/")
def search(category: str, q: str = Query(...)):
    if category not in data_store:
        return JSONResponse(status_code=404, content={"error": "Invalid category"})
    store = data_store[category]
    results = search_index(store["index"], q, store["dataset"])
    # JSONResponse sets Content-Type itself; the CORS header lets browser
    # clients on any origin call the API.
    headers = {"Access-Control-Allow-Origin": "*"}
    return JSONResponse(content=results, headers=headers)
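
# Example request (hypothetical query): GET /search/packages/?q=http%20server
# returns a JSON array of package entries whose embeddings sit closest to the
# query; an unknown category such as /search/plugins/ yields a 404 JSON error.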