Commit
·
6ff6cb6
1
Parent(s):
b9da93e
chore: Update dependencies and remove dead code
Browse files
app.py
CHANGED
@@ -1,179 +1,154 @@
|
|
1 |
-
import
|
2 |
-
|
3 |
-
from typing import
|
4 |
|
5 |
import gradio as gr
|
|
|
6 |
from dotenv import load_dotenv
|
7 |
-
from
|
8 |
-
from
|
9 |
-
|
10 |
|
11 |
load_dotenv()
|
|
|
|
|
|
|
|
|
12 |
|
13 |
-
URL = os.getenv("QDRANT_URL")
|
14 |
-
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
|
15 |
-
sentence_embedding_model = SentenceTransformer("BAAI/bge-large-en")
|
16 |
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
url=URL,
|
22 |
-
api_key=QDRANT_API_KEY,
|
23 |
-
)
|
24 |
|
25 |
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
# else:
|
32 |
-
# return f"{bytes_size / (1024 ** 3):.2f} TB"
|
33 |
|
34 |
|
35 |
-
def
|
36 |
-
|
|
|
|
|
37 |
|
38 |
|
39 |
-
def format_results(
|
|
|
|
|
40 |
markdown = (
|
41 |
-
"<h1 style='text-align: center;'
|
42 |
-
" </h1> \n\n"
|
43 |
)
|
44 |
-
for result in results:
|
45 |
-
hub_id = result
|
46 |
-
|
47 |
-
lastModified = result.payload["lastModified"]
|
48 |
url = f"https://huggingface.co/datasets/{hub_id}"
|
49 |
-
|
|
|
|
|
|
|
|
|
|
|
50 |
markdown += header + "\n"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
)
|
67 |
-
markdown += "- " + "\n- ".join(linked_models) + "\n\n"
|
68 |
-
markdown += "</details>\n\n"
|
69 |
|
70 |
-
|
|
|
71 |
|
|
|
|
|
72 |
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
79 |
|
|
|
80 |
|
81 |
-
|
82 |
-
def search(query: str, limit: Optional[int] = 10, show_linked_models: bool = False):
|
83 |
-
query_ = sentence_embedding_model.encode(
|
84 |
-
f"Represent this sentence for searching relevant passages:{query}"
|
85 |
-
)
|
86 |
-
results = client.search(
|
87 |
-
collection_name="dataset_cards",
|
88 |
-
query_vector=query_,
|
89 |
-
limit=limit,
|
90 |
-
)
|
91 |
-
return format_results(results, show_associated_models=show_linked_models)
|
92 |
|
93 |
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
models.FieldCondition(key="id", match=models.MatchValue(value=hub_id)),
|
101 |
-
]
|
102 |
-
),
|
103 |
-
limit=1,
|
104 |
-
with_payload=True,
|
105 |
-
with_vectors=False,
|
106 |
)
|
107 |
-
|
108 |
-
|
109 |
-
except IndexError as e:
|
110 |
-
raise gr.Error(
|
111 |
-
f"Hub id {hub_id} not in the database. This could be because it is very new"
|
112 |
-
" or because it doesn't have much documentation."
|
113 |
-
) from e
|
114 |
-
|
115 |
-
|
116 |
-
@lru_cache()
|
117 |
-
def recommend(hub_id, limit: Optional[int] = 10, show_linked_models=False):
|
118 |
-
positive_id = hub_id_qdrant_id(hub_id)
|
119 |
-
results = client.recommend(
|
120 |
-
collection_name=collection_name, positive=[positive_id], limit=limit
|
121 |
)
|
122 |
-
return format_results(results, show_associated_models=show_linked_models)
|
123 |
|
124 |
-
|
125 |
-
def query(
|
126 |
-
search_term,
|
127 |
-
search_type,
|
128 |
-
limit: Optional[int] = 10,
|
129 |
-
show_linked_models: bool = False,
|
130 |
-
):
|
131 |
-
if search_type == "Recommend similar datasets":
|
132 |
-
return recommend(search_term, limit, show_linked_models)
|
133 |
-
else:
|
134 |
-
return search(search_term, limit, show_linked_models)
|
135 |
|
136 |
|
137 |
with gr.Blocks() as demo:
|
138 |
-
gr.Markdown("## 🤗
|
139 |
with gr.Row():
|
140 |
gr.Markdown(
|
141 |
-
"This Gradio app allows you to
|
142 |
-
"
|
143 |
-
" dataset or search for datasets based on a query. This is an early proof of concept. Feedback very welcome!"
|
144 |
)
|
145 |
with gr.Row():
|
146 |
-
|
147 |
-
value="
|
148 |
-
label="
|
149 |
)
|
150 |
|
151 |
with gr.Row():
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
|
159 |
-
|
160 |
-
with gr.Column():
|
161 |
-
max_results = gr.Slider(
|
162 |
-
minimum=1,
|
163 |
-
maximum=50,
|
164 |
-
step=1,
|
165 |
-
value=10,
|
166 |
-
label="Maximum number of results",
|
167 |
-
)
|
168 |
-
show_linked_models = gr.Checkbox(
|
169 |
-
label="Show associated models",
|
170 |
-
default=False,
|
171 |
-
)
|
172 |
|
173 |
results = gr.Markdown()
|
174 |
-
|
175 |
-
|
|
|
|
|
|
|
|
|
176 |
)
|
177 |
|
178 |
-
|
179 |
demo.launch()
|
|
|
1 |
+
import asyncio
|
2 |
+
import re
|
3 |
+
from typing import Dict, List
|
4 |
|
5 |
import gradio as gr
|
6 |
+
import httpx
|
7 |
from dotenv import load_dotenv
|
8 |
+
from huggingface_hub import ModelCard
|
9 |
+
from cashews import cache
|
10 |
+
|
11 |
|
12 |
load_dotenv()
|
13 |
+
cache.setup("mem://")
|
14 |
+
API_URL = "https://davanstrien-huggingface-datasets-search-v2.hf.space/similar"
|
15 |
+
HF_API_URL = "https://huggingface.co/api/datasets"
|
16 |
+
README_URL_TEMPLATE = "https://huggingface.co/datasets/{}/raw/main/README.md"
|
17 |
|
|
|
|
|
|
|
18 |
|
19 |
+
async def fetch_similar_datasets(dataset_id: str, limit: int = 10) -> List[Dict]:
    """Query the similarity API for datasets similar to *dataset_id*.

    Returns up to *limit* result dicts, skipping the first hit (presumably the
    query dataset itself — confirm against the API), or an empty list on any
    non-200 response.
    """
    request_url = f"{API_URL}?dataset_id={dataset_id}&n={limit}"
    async with httpx.AsyncClient() as http:
        resp = await http.get(request_url)
    if resp.status_code != 200:
        return []
    return resp.json()["results"][1:]
|
|
|
|
|
|
|
23 |
|
24 |
|
25 |
+
async def fetch_dataset_card(dataset_id: str) -> str:
    """Download the README for *dataset_id* and return its card text.

    Returns "" when the README cannot be fetched (non-200 response).
    """
    readme_url = README_URL_TEMPLATE.format(dataset_id)
    async with httpx.AsyncClient() as http:
        resp = await http.get(readme_url)
    if resp.status_code != 200:
        return ""
    # ModelCard(...).text — presumably the markdown body without the YAML
    # front matter; NOTE(review): confirm against huggingface_hub docs.
    return ModelCard(resp.text).text
|
|
|
|
|
30 |
|
31 |
|
32 |
+
async def fetch_dataset_info(dataset_id: str) -> Dict:
    """Fetch dataset metadata from the HF datasets API; {} on non-200."""
    async with httpx.AsyncClient() as http:
        resp = await http.get(f"{HF_API_URL}/{dataset_id}")
    return {} if resp.status_code != 200 else resp.json()
|
36 |
|
37 |
|
38 |
+
def format_results(
    results: List[Dict], dataset_cards: List[str], dataset_infos: List[Dict]
) -> str:
    """Render similarity-search results as a single Markdown document.

    Args:
        results: dicts with at least ``"dataset_id"`` and ``"similarity"``.
        dataset_cards: README/card text per result (may be ``""``).
        dataset_infos: HF API metadata per result (may be ``{}``).

    Returns:
        Markdown with one section per result: linked title, similarity score,
        download/like stats, a short text preview, and the full card inside a
        collapsible ``<details>`` block.
    """
    markdown = (
        "<h1 style='text-align: center;'>✨ Similar Datasets ✨</h1>\n\n"
    )
    for result, card, info in zip(results, dataset_cards, dataset_infos):
        hub_id = result["dataset_id"]
        similarity = result["similarity"]
        url = f"https://huggingface.co/datasets/{hub_id}"

        # Extract the first Markdown heading as the title. Use re.search,
        # not re.match: re.match only tries position 0, so a heading that is
        # not on the very first line (e.g. a card starting with a blank
        # line) would never be found and the re.MULTILINE flag was dead.
        title_match = re.search(r"^#\s*(.+)", card, re.MULTILINE)
        title = title_match[1] if title_match else hub_id

        header = f"## [{title}]({url})"
        markdown += header + "\n"
        markdown += f"**Similarity Score:** {similarity:.4f}\n\n"

        if info:
            downloads = info.get("downloads", 0)
            likes = info.get("likes", 0)
            last_modified = info.get("lastModified", "N/A")
            markdown += f"**Downloads:** {downloads} | **Likes:** {likes} | **Last Modified:** {last_modified}\n\n"

        if card:
            # Remove the title heading so it is not repeated in the preview.
            card_without_title = re.sub(
                r"^#.*\n", "", card, count=1, flags=re.MULTILINE
            )

            # Split the card into paragraphs
            paragraphs = card_without_title.split("\n\n")

            # Find the first non-empty text paragraph that's not just an image
            preview = next(
                (
                    p
                    for p in paragraphs
                    if p.strip()
                    and not p.strip().startswith("![")
                    and not p.strip().startswith("<img")
                ),
                "No preview available.",
            )

            # Limit the preview to a reasonable length (e.g., 300 characters)
            preview = f"{preview[:300]}..." if len(preview) > 300 else preview

            # Add the preview
            markdown += f"{preview}\n\n"

            # Constrain image size inside the collapsible full card: first
            # raw HTML <img> tags, then Markdown images rewritten as sized
            # <img> tags.
            full_card = re.sub(
                r'<img src="([^"]+)"',
                r'<img src="\1" style="max-width: 300px; max-height: 300px;"',
                card_without_title,
            )
            full_card = re.sub(
                r"!\[([^\]]*)\]\(([^\)]+)\)",
                r'<img src="\2" alt="\1" style="max-width: 300px; max-height: 300px;">',
                full_card,
            )
            markdown += f"<details><summary>Full Dataset Card</summary>\n\n{full_card}\n\n</details>\n\n"

        markdown += "---\n\n"

    return markdown
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
|
108 |
+
async def search_similar_datasets(dataset_id: str, limit: int = 10):
    """Find datasets similar to *dataset_id* and render them as Markdown."""
    similar = await fetch_similar_datasets(dataset_id, limit)
    ids = [entry["dataset_id"] for entry in similar]

    # Fan out the per-dataset lookups concurrently: all cards first, then
    # all metadata records.
    cards = await asyncio.gather(*(fetch_dataset_card(i) for i in ids))
    infos = await asyncio.gather(*(fetch_dataset_info(i) for i in ids))

    return format_results(similar, cards, infos)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
120 |
|
121 |
|
122 |
with gr.Blocks() as demo:
    gr.Markdown("## 🤗 Dataset Similarity Search")
    with gr.Row():
        gr.Markdown(
            "This Gradio app allows you to find similar datasets based on a given dataset ID. "
            "Enter a dataset ID (e.g., 'imdb') to find similar datasets with previews of their dataset cards."
        )
    with gr.Row():
        dataset_id = gr.Textbox(
            value="imdb",
            label="Dataset ID (e.g., imdb, squad, glue)",
        )

    with gr.Row():
        search_btn = gr.Button("Search Similar Datasets")
        max_results = gr.Slider(
            minimum=1,
            maximum=50,
            step=1,
            value=10,
            label="Maximum number of results",
        )

    results = gr.Markdown()
    # Gradio supports coroutine event handlers natively, so pass the async
    # function directly instead of wrapping it in a lambda + asyncio.run():
    # asyncio.run() created a brand-new event loop on every click and blocked
    # the worker thread for the whole request.
    search_btn.click(
        search_similar_datasets,
        inputs=[dataset_id, max_results],
        outputs=results,
    )

demo.launch()
|