|
import gradio as gr |
|
from qdrant_client import QdrantClient |
|
from qdrant_client import models |
|
from sentence_transformers import SentenceTransformer |
|
from huggingface_hub import hf_hub_url |
|
from dotenv import load_dotenv |
|
import os |
|
from functools import lru_cache |
|
|
|
# Load variables from a .env file (when one exists) into the environment
# before the Qdrant settings are read below.
load_dotenv()

URL = os.getenv("QDRANT_URL")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")

# Sentence-embedding model used by search() to encode free-text queries.
sentence_embedding_model = SentenceTransformer("BAAI/bge-large-en")

print(URL)
# Never echo the API key itself: stdout usually ends up in shared logs.
# Report only whether a key was found.
print("QDRANT_API_KEY set:", QDRANT_API_KEY is not None)

# Name of the Qdrant collection this app queries.
collection_name = "dataset_cards"
client = QdrantClient(
    url=URL,
    api_key=QDRANT_API_KEY,
)
|
|
|
|
|
def format_results(results):
    """Render search hits as a single markdown string.

    Each hit contributes a level-2 heading that links its hub id to the
    dataset's ``README.md`` URL, followed by the hit's ``section_text``.

    Args:
        results: Iterable of objects exposing a ``payload`` mapping with
            ``"id"`` and ``"section_text"`` keys.

    Returns:
        The concatenated markdown, or an empty string when ``results`` is
        empty.
    """
    chunks = []
    for hit in results:
        payload = hit.payload
        dataset_id = payload["id"]
        readme_link = hf_hub_url(dataset_id, "README.md", repo_type="dataset")
        chunks.append(f"## [{dataset_id}]({readme_link})" + "\n")
        chunks.append(payload["section_text"] + "\n")
    return "".join(chunks)
|
|
|
|
|
@lru_cache()
def search(query: str):
    """Run a semantic search for ``query`` and return the hits as markdown.

    The query is embedded with ``sentence_embedding_model`` (prefixed with
    the retrieval instruction string below) and the resulting vector is used
    to fetch the 10 nearest points from the Qdrant collection.

    Results are memoized per query string by ``lru_cache``, so repeated
    identical queries do not hit the model or the database again.

    Args:
        query: Free-text search query. Note that inside this function the
            parameter shadows the module-level ``query`` dispatcher.

    Returns:
        Markdown produced by ``format_results`` for the matching points.
    """
    query_vector = sentence_embedding_model.encode(
        f"Represent this sentence for searching relevant passages:{query}"
    )
    results = client.search(
        # Use the shared module constant rather than repeating the literal,
        # so every lookup targets the same collection as recommend().
        collection_name=collection_name,
        query_vector=query_vector,
        limit=10,
    )
    return format_results(results)
|
|
|
|
|
@lru_cache()
def hub_id_qdrant_id(hub_id):
    """Return the Qdrant point id of the first point whose ``id`` payload
    field equals ``hub_id``.

    Successful lookups are memoized by ``lru_cache``; failed lookups raise
    and are therefore not cached.

    Args:
        hub_id: Hub dataset id to look up (must be hashable for the cache).

    Returns:
        The ``id`` attribute of the first matching record.

    Raises:
        gr.Error: If no point in the collection matches ``hub_id``.
    """
    matches = client.scroll(
        # Shared module constant, consistent with recommend().
        collection_name=collection_name,
        scroll_filter=models.Filter(
            must=[
                models.FieldCondition(key="id", match=models.MatchValue(value=hub_id)),
            ]
        ),
        limit=1,
        # Only the point id is read below, so skip transferring payloads
        # and vectors.
        with_payload=False,
        with_vectors=False,
    )
    try:
        # matches[0] holds the returned records; indexing its first element
        # raises IndexError when nothing matched.
        return matches[0][0].id
    except IndexError as e:
        raise gr.Error(
            f"Hub id {hub_id} not in our database. This could be because it is very new or because it doesn't have much documentation."
        ) from e
|
|
|
|
|
@lru_cache()
def recommend(hub_id):
    """Return markdown describing datasets similar to ``hub_id``.

    Resolves the hub id to its Qdrant point id, asks Qdrant for
    recommendations using that point as the single positive example, and
    formats the result with ``format_results``. Results are memoized per
    ``hub_id`` by ``lru_cache``.

    Args:
        hub_id: Hub dataset id used as the positive example.

    Returns:
        Markdown string for the recommended points.
    """
    point_id = hub_id_qdrant_id(hub_id)
    similar = client.recommend(
        collection_name=collection_name,
        positive=[point_id],
    )
    return format_results(similar)
|
|
|
|
|
def query(search_term, search_type):
    """Dispatch a UI request to the matching backend function.

    Args:
        search_term: Text entered by the user (a hub id or a free-text query).
        search_type: Selected mode; ``"Recommend similar datasets"`` routes
            to ``recommend``, any other value routes to ``search``.

    Returns:
        The markdown string returned by the chosen function.
    """
    handler = recommend if search_type == "Recommend similar datasets" else search
    return handler(search_term)
|
|
|
|
|
# Gradio UI: a title, a short description, a text input, a search button with
# a mode selector, and a markdown area that displays the results.
# NOTE(review): the original indentation was lost, so the nesting of the rows
# below is reconstructed (three sibling rows, the last one wrapping an inner
# row) -- confirm against the intended layout.
with gr.Blocks() as demo:
    gr.Markdown("## 🤗 Semantic Dataset Search")

    with gr.Row():
        gr.Markdown(
            "This Gradio app allows you to search for datasets based on their descriptions. You can either search for similar datasets to a given dataset or search for datasets based on a query."
        )

    with gr.Row():
        search_term = gr.Textbox(
            value="movie review sentiment",
            label="hub id i.e. IMDB or query i.e. movie review sentiment",
        )

    with gr.Row():
        with gr.Row():
            find_similar_btn = gr.Button("Search")
            search_type = gr.Radio(
                ["Recommend similar datasets", "Semantic Search"],
                label="Search type",
                value="Semantic Search",
                interactive=True,
            )

    results = gr.Markdown()

    # Clicking the button sends the text and selected mode to query() and
    # renders the returned markdown in the results component.
    find_similar_btn.click(
        fn=query,
        inputs=[search_term, search_type],
        outputs=results,
    )


demo.launch()
|
|