Spaces:
Running
Running
File size: 4,040 Bytes
7324658 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import os
from typing import ClassVar
# import dotenv
import gradio as gr
import lancedb
import srsly
from huggingface_hub import snapshot_download
from lancedb.embeddings.base import TextEmbeddingFunction
from lancedb.embeddings.registry import register
from lancedb.pydantic import LanceModel, Vector
from lancedb.rerankers import CohereReranker, ColbertReranker
from lancedb.util import attempt_import_or_raise
# dotenv.load_dotenv()
@register("coherev3")
class CohereEmbeddingFunction_2(TextEmbeddingFunction):
    """LanceDB text-embedding function backed by Cohere's embed-english-v3.0.

    The Cohere client is created lazily on first use and shared across all
    instances through a class-level attribute, so repeated embedding calls
    reuse one connection.
    """

    name: str = "embed-english-v3.0"
    client: ClassVar = None

    def ndims(self):
        # embed-english-v3.0 produces 1024-dimensional vectors; this must
        # agree with the Vector(1024) field declared on the LanceModel
        # schema (the previous 768 was inconsistent with that schema).
        return 1024

    def generate_embeddings(self, texts):
        """
        Get the embeddings for the given texts

        Parameters
        ----------
        texts: list[str] or np.ndarray (of str)
            The texts to embed
        """
        # TODO retry, rate limit, token limit
        self._init_client()
        rs = CohereEmbeddingFunction_2.client.embed(
            texts=texts, model=self.name, input_type="search_document"
        )
        return list(rs.embeddings)

    def _init_client(self):
        # Import lazily so the cohere package is only required when
        # embeddings are actually generated.
        cohere = attempt_import_or_raise("cohere")
        if CohereEmbeddingFunction_2.client is None:
            # Requires COHERE_API_KEY in the environment; raises KeyError
            # if it is missing.
            CohereEmbeddingFunction_2.client = cohere.Client(
                os.environ["COHERE_API_KEY"]
            )
# Single shared embedder instance; referenced by the LanceDB schema fields below.
COHERE_EMBEDDER = CohereEmbeddingFunction_2.create()
class ArxivModel(LanceModel):
    """LanceDB row schema for arxiv paper text chunks."""
    text: str = COHERE_EMBEDDER.SourceField()  # chunk text; embedded automatically on insert
    vector: Vector(1024) = COHERE_EMBEDDER.VectorField()  # 1024-d Cohere v3 embedding
    title: str  # chunk/section title — presumably distinct from paper_title; verify upstream
    paper_title: str
    content_type: str
    arxiv_id: str
def download_data():
    """Download the zotero dataset snapshot from the Hugging Face Hub into ./data.

    Requires HF_TOKEN in the environment (raises KeyError if absent).
    """
    snapshot_download(
        repo_type="dataset",
        repo_id="rbiswasfc/zotero_db",
        token=os.environ["HF_TOKEN"],
        local_dir="./data",
    )
    print("Data downloaded!")
# Module-level setup: runs on import. Downloads the dataset, then opens the
# vector database and supporting resources used by query_db().
download_data()
VERSION = "0.0.0a"
DB = lancedb.connect("./data/.lancedb_zotero_v0")
# Mapping of arxiv_id -> abstract text, used when formatting search results.
ID_TO_ABSTRACT = srsly.read_json("./data/id_to_abstract.json")
# Available rerankers, selected by name in query_db().
RERANKERS = {"colbert": ColbertReranker(), "cohere": CohereReranker()}
TBL = DB.open_table("arxiv_zotero_v0")
def _format_results(arxiv_refs):
    """Build display dicts (title, arxiv URL, cleaned abstract) for each paper.

    ``arxiv_refs`` maps an arxiv id to its paper title; abstracts come from
    the module-level ``ID_TO_ABSTRACT`` mapping (empty string if missing).
    """
    formatted = []
    for arxiv_id, title in arxiv_refs.items():
        text = ID_TO_ABSTRACT.get(arxiv_id, "")
        # these are all ugly hacks because the data preprocessing is poor. to be fixed v soon.
        if "Abstract\n\n" in text:
            text = text.split("Abstract\n\n")[-1]
        if title in text:
            text = text.split(title)[-1]
        if text.startswith("\n"):
            text = text[1:]
        if "\n\n" in text[:20]:
            text = "\n\n".join(text.split("\n\n")[1:])
        formatted.append(
            {
                "title": title,
                "url": f"https://arxiv.org/abs/{arxiv_id}",
                "abstract": text,
            }
        )
    return formatted
def query_db(query: str, k: int = 10, reranker: str = "cohere"):
    """Hybrid-search TBL for ``query`` and return the top-3 papers, formatted.

    Retrieves up to ``k`` chunks, optionally reranks them with the named
    reranker from RERANKERS, sums relevance scores per paper, and formats
    the three highest-scoring papers via _format_results.
    """
    search = TBL.search(query, query_type="hybrid").limit(k)
    if reranker is None:
        reranked = search
    else:
        reranked = search.rerank(reranker=RERANKERS[reranker])
    df = reranked.to_pandas()
    # Aggregate chunk-level scores to paper level and keep the best three.
    per_paper = df.groupby("arxiv_id").agg({"_relevance_score": "sum"})
    best = per_paper.sort_values(by="_relevance_score", ascending=False).head(
        3
    )
    # Recover each winning paper's title from the chunk rows.
    id_to_title = {
        row["arxiv_id"]: row["paper_title"]
        for _, row in df.iterrows()
        if row["arxiv_id"] in best.index
    }
    return _format_results(id_to_title)
# Gradio UI: a single query box whose submit button runs query_db and shows
# the formatted results as JSON. launch() blocks and serves the app.
with gr.Blocks() as demo:
    with gr.Row():
        query = gr.Textbox(label="Query", placeholder="Enter your query...")
        submit_btn = gr.Button("Submit")
    output = gr.JSON(label="Search Results")
    # # callback ---
    # Only the query is wired in; k and reranker keep their defaults.
    submit_btn.click(
        fn=query_db,
        inputs=query,
        outputs=output,
    )
demo.launch()
|