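"""Gradio demo for semantic search over an arXiv/Zotero corpus.

Chunks are stored in a LanceDB table with Cohere embeddings; queries run as
hybrid (vector + full-text) search, optionally reranked, and the top papers
are returned with cleaned-up abstracts.
"""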
import os
from typing import ClassVar

# import dotenv
import gradio as gr
import lancedb
import srsly
from huggingface_hub import snapshot_download
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, ColbertReranker
from lancedb.util import attempt_import_or_raise

# dotenv.load_dotenv()  # uncomment for local development with a .env file


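# Custom LanceDB embedding function wrapping Cohere's embed-english-v3.0;
# @register makes it available under the "coherev3" key in the embedding registry.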
@register("coherev3")
class CohereEmbeddingFunction_2(TextEmbeddingFunction):
    name: str = "embed-english-v3.0"
    client: ClassVar = None

    def ndims(self):
        # embed-english-v3.0 returns 1024-dimensional vectors; this must match
        # the Vector(1024) field in ArxivModel below.
        return 1024

    def generate_embeddings(self, texts):
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        # TODO retry, rate limit, token limit
        self._init_client()
        rs = CohereEmbeddingFunction_2.client.embed(
            texts=texts, model=self.name, input_type="search_document"
        )

        return list(rs.embeddings)

    def _init_client(self):
        if CohereEmbeddingFunction_2.client is None:
            cohere = attempt_import_or_raise("cohere")
            CohereEmbeddingFunction_2.client = cohere.Client(
                os.environ["COHERE_API_KEY"]
            )


COHERE_EMBEDDER = CohereEmbeddingFunction_2.create()


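# LanceDB table schema: `text` is the source field embedded at ingestion time
# into `vector`; the remaining fields are plain metadata columns.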
class ArxivModel(LanceModel):
    text: str = COHERE_EMBEDDER.SourceField()
    vector: Vector(1024) = COHERE_EMBEDDER.VectorField()
    title: str
    paper_title: str
    content_type: str
    arxiv_id: str


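# Fetch the prebuilt LanceDB database and abstract lookup from the HF Hub;
# assumes HF_TOKEN is set in the environment.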
def download_data():
    snapshot_download(
        repo_id="rbiswasfc/zotero_db",
        repo_type="dataset",
        local_dir="./data",
        token=os.environ["HF_TOKEN"],
    )
    print("Data downloaded!")


download_data()

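# Module-level resources: local DB connection, arxiv_id -> abstract lookup,
# and the two available rerankers.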
VERSION = "0.0.0a"
DB = lancedb.connect("./data/.lancedb_zotero_v0")
ID_TO_ABSTRACT = srsly.read_json("./data/id_to_abstract.json")
RERANKERS = {"colbert": ColbertReranker(), "cohere": CohereReranker()}
TBL = DB.open_table("arxiv_zotero_v0")


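# Map {arxiv_id: paper_title} to a list of result dicts with cleaned abstracts.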
def _format_results(arxiv_refs):
    results = []
    for arx_id, paper_title in arxiv_refs.items():
        abstract = ID_TO_ABSTRACT.get(arx_id, "")
        # heuristic cleanup: strip leading "Abstract" headers, duplicated titles,
        # and stray newlines left by imperfect preprocessing (to be fixed upstream)
        if "Abstract\n\n" in abstract:
            abstract = abstract.split("Abstract\n\n")[-1]
        if paper_title in abstract:
            abstract = abstract.split(paper_title)[-1]
        if abstract.startswith("\n"):
            abstract = abstract[1:]
        if "\n\n" in abstract[:20]:
            abstract = "\n\n".join(abstract.split("\n\n")[1:])
        result = {
            "title": paper_title,
            "url": f"https://arxiv.org/abs/{arx_id}",
            "abstract": abstract,
        }
        results.append(result)

    return results


def query_db(query: str, k: int = 10, reranker: str | None = "cohere"):
    # Hybrid search mixes vector similarity with full-text matching.
    raw_results = TBL.search(query, query_type="hybrid").limit(k)
    if reranker is not None:
        ranked_results = raw_results.rerank(reranker=RERANKERS[reranker])
    else:
        ranked_results = raw_results

    ranked_results = ranked_results.to_pandas()
    # Several retrieved chunks can belong to the same paper; sum the per-chunk
    # relevance scores per arxiv_id and keep the three best papers.
    top_results = ranked_results.groupby("arxiv_id").agg({"_relevance_score": "sum"})
    top_results = top_results.sort_values(by="_relevance_score", ascending=False).head(3)
    top_results_dict = {
        row["arxiv_id"]: row["paper_title"]
        for _, row in ranked_results.iterrows()
        if row["arxiv_id"] in top_results.index
    }

    return _format_results(top_results_dict)


with gr.Blocks() as demo:
    with gr.Row():
        query = gr.Textbox(label="Query", placeholder="Enter your query...")
        submit_btn = gr.Button("Submit")
    output = gr.JSON(label="Search Results")

    # Callback: run the search when the button is clicked.
    submit_btn.click(
        fn=query_db,
        inputs=query,
        outputs=output,
    )


demo.launch()
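# Usage notes (assumes COHERE_API_KEY and HF_TOKEN are set in the environment):
# running this file starts the Gradio UI; query_db can also be called directly,
# e.g. query_db("hybrid search with rerankers", k=10, reranker="colbert").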