import faiss
import numpy as np
from fastapi import FastAPI, Query
from fastapi.responses import JSONResponse
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from typing import List, Dict

app = FastAPI()

# Constants
EMBEDDING_MODEL_NAME = "all-MiniLM-L6-v2"  # 384-dimensional sentence embeddings
SEARCH_THRESHOLD_RATIO = 0.3  # keep results within 30% of the distance spread above the best match
FIELDS_TO_INCLUDE = [
    "full_name", "description", "default_branch", "open_issues",
    "stargazers_count", "forks_count", "watchers_count", "license",
    "size", "fork", "updated_at", "has_build_zig",
    "has_build_zig_zon", "created_at"
]

# Load embedding model
model = SentenceTransformer(EMBEDDING_MODEL_NAME)

# Helper functions
def prepare_text(entry: Dict, include_readme: bool = True) -> str:
    # Concatenate the selected metadata fields, optionally the README, and the
    # repository topics into a single string that gets embedded.
    parts = [str(entry.get(field, "")) for field in FIELDS_TO_INCLUDE]
    if include_readme:
        # `or ""` guards against fields that are present but null
        parts.append(str(entry.get("readme_content") or ""))
    parts.extend(entry.get("topics") or [])
    return " ".join(parts)

def load_and_encode_dataset(name: str, include_readme: bool = True):
    raw_dataset = load_dataset(name)["train"]
    texts = [prepare_text(item, include_readme) for item in raw_dataset]
    embeddings = model.encode(texts)
    # FAISS expects float32 vectors
    return raw_dataset, np.asarray(embeddings, dtype=np.float32)

def build_faiss_index(embeddings: np.ndarray) -> faiss.IndexFlatL2:
    # IndexFlatL2 performs exact (brute-force) nearest-neighbor search on L2 distance
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings)
    return index

def search_index(index: faiss.IndexFlatL2, query: str, dataset: List[Dict]) -> List[Dict]:
    # Embed the query and rank every dataset entry by L2 distance
    query_vector = model.encode([query])
    distances, indices = index.search(np.asarray(query_vector, dtype=np.float32), len(dataset))
    return filter_by_distance(distances[0], indices[0], dataset)

def filter_by_distance(distances: np.ndarray, indices: np.ndarray, dataset: List[Dict], ratio: float = SEARCH_THRESHOLD_RATIO) -> List[Dict]:
    # Keep only results whose distance lies within `ratio` of the spread
    # between the closest and the farthest match
    if len(distances) == 0:
        return []
    min_d, max_d = np.min(distances), np.max(distances)
    threshold = min_d + (max_d - min_d) * ratio
    return [dataset[i] for d, i in zip(distances, indices) if d <= threshold]
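# Worked example of the filter above (illustrative numbers only): with
# distances [0.2, 0.5, 1.0] and ratio 0.3, min_d = 0.2 and max_d = 1.0, so
# threshold = 0.2 + (1.0 - 0.2) * 0.3 = 0.44 and only the 0.2 match is kept.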

# Load datasets and create indices
data_configs = {
    "packages": "zigistry/packages",
    "programs": "zigistry/programs"
}

data_store = {}

for key, dataset_name in data_configs.items():
    dataset, embeddings = load_and_encode_dataset(dataset_name, include_readme=True)
    index = build_faiss_index(embeddings)
    data_store[key] = {
        "dataset": dataset,
        "index": index,
        "embeddings": embeddings
    }
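# Note: both datasets are downloaded and embedded at import time, so the first
# startup can take a while; the indices are then held entirely in memory for
# the lifetime of the process.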

# FastAPI endpoints
@app.get("/search/{category}/")
def search(category: str, q: str = Query(...)):
    if category not in data_store:
        return JSONResponse(status_code=404, content={"error": "Invalid category"})

    store = data_store[category]
    results = search_index(store["index"], q, store["dataset"])
    # Allow cross-origin requests so browser-based clients can call the API directly
    headers = {"Access-Control-Allow-Origin": "*", "Content-Type": "application/json"}
    return JSONResponse(content=results, headers=headers)
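
# Example usage (illustrative; assumes this file is saved as main.py — adjust
# the module path to match the actual file name):
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl "http://127.0.0.1:8000/search/packages/?q=http%20server"
#
# The response is a JSON array of matching repository records from the
# requested category's dataset.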