File size: 2,065 Bytes
11ac6f7
 
 
 
 
2f34495
2624c54
 
2f34495
11ac6f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import os

# Cache/queue configuration must be exported BEFORE importing infinity_emb /
# transformers: both read HF_HOME, TRANSFORMERS_CACHE and INFINITY_QUEUE_SIZE
# at import time, so setting them after the imports (as this file originally
# did) can silently have no effect.
os.environ["HF_HOME"] = "/app"
os.environ["TRANSFORMERS_CACHE"] = "/app"
os.environ["INFINITY_QUEUE_SIZE"] = "512000"

import asyncio

import numpy as np
import pandas as pd
from infinity_emb import AsyncEmbeddingEngine, EngineArgs
from usearch.index import Index, Matches

# Global embedding engine wrapping the jina code-embedding model. Constructed
# once at import time; embed_texts() enters/exits its async context per call.
engine = AsyncEmbeddingEngine.from_args(
    EngineArgs(
        model_name_or_path="michaelfeil/jina-embeddings-v2-base-code",
        batch_size=8,  # NOTE(review): small batch — presumably to bound memory; confirm
    )
)


async def embed_texts(texts: list[str]) -> np.ndarray:
    """Embed *texts* with the global engine and return a 2-D numpy array.

    NOTE(review): the engine context is entered (started) and exited (stopped)
    on every call — fine for a one-shot script, wasteful under repeated use.
    """
    async with engine:
        response = await engine.embed(texts)
    # engine.embed returns a tuple whose first element is the embedding list.
    return np.array(response[0])

def embed_texts_sync(texts: list[str]) -> np.ndarray:
    """Blocking wrapper around embed_texts().

    The original created a fresh event loop on every call and never closed
    it, leaking loops. asyncio.run() creates, runs, and properly closes a
    loop per call, which is the idiomatic sync->async bridge here.
    """
    return asyncio.run(embed_texts(texts))

# Module-level singletons populated by build_index():
index = None  # usearch Index over the document embeddings
docs_index = None  # list of document texts; list position == index key


def build_index(demo_mode=True):
    """Create the global usearch index and, in demo mode, seed it with four
    sample torch snippets; their embeddings are keyed 0..n-1 so docs_index
    positions line up with index keys."""
    global index, docs_index

    # Probe the model with a throwaway string to learn the vector size.
    ndim = embed_texts_sync(["Hi"]).shape[-1]
    index = Index(
        ndim=ndim,  # Define the number of dimensions in input vectors
        metric="cos",  # Choose 'l2sq', 'haversine' or other metric, default = 'ip'
        dtype="f16",  # Quantize to 'f16' or 'i8' if needed, default = 'f32'
        connectivity=16,  # How frequent should the connections in the graph be, optional
        expansion_add=128,  # Control the recall of indexing, optional
        expansion_search=64,  # Control the quality of search, optional
    )

    if not demo_mode:
        # TODO: Michael, load parquet with embeddings
        return

    docs_index = [
        "torch.add(*demo)",
        "torch.mul(*demo)",
        "torch.div(*demo)",
        "torch.sub(*demo)",
    ]
    vectors = embed_texts_sync(docs_index)
    index.add(np.arange(len(docs_index)), vectors)


# Build the demo index once at import time so answer_query() is usable
# immediately (presumably this module is imported by a serving layer).
if index is None:
    build_index()


def answer_query(query: str) -> list[str]:
    """Return up to 10 indexed document texts nearest to *query* (cosine).

    Bug fix: embed_texts_sync([query]) returns shape (1, ndim). Passing the
    2-D array to Index.search yields a BatchMatches, whose iteration produces
    per-query Matches objects that have no .key attribute, so the original
    lookup would fail. Searching with the 1-D vector returns Matches, whose
    iteration yields individual Match entries with .key as intended (and
    matches the Matches import at the top of the file).
    """
    embedding = embed_texts_sync([query])[0]
    matches = index.search(embedding, 10)
    # Keys were assigned as 0..n-1 by build_index, so they index docs_index.
    return [docs_index[int(match.key)] for match in matches]


if __name__ == "__main__":
    # Smoke test: nearest docs for a slightly perturbed demo snippet.
    print(answer_query("torch.mul(*demo2)"))