davanstrien (HF staff) committed
Commit 6ff6cb6 · 1 Parent(s): b9da93e

chore: Update dependencies and remove dead code

Files changed (1):
1. app.py +111 -136
app.py CHANGED
@@ -1,179 +1,154 @@
- import os
- from functools import lru_cache
- from typing import Optional

  import gradio as gr
  from dotenv import load_dotenv
- from qdrant_client import QdrantClient, models
- from sentence_transformers import SentenceTransformer
- from huggingface_hub import list_models

  load_dotenv()

- URL = os.getenv("QDRANT_URL")
- QDRANT_API_KEY = os.getenv("QDRANT_API_KEY")
- sentence_embedding_model = SentenceTransformer("BAAI/bge-large-en")

- print(URL)
- print(QDRANT_API_KEY)
- collection_name = "dataset_cards"
- client = QdrantClient(
-     url=URL,
-     api_key=QDRANT_API_KEY,
- )


- # def convert_bytes_to_human_readable_size(bytes_size):
- #     if bytes_size < 1024**2:
- #         return f"{bytes_size / 1024:.2f} MB"
- #     elif bytes_size < 1024**3:
- #         return f"{bytes_size / (1024 ** 2):.2f} GB"
- #     else:
- #         return f"{bytes_size / (1024 ** 3):.2f} TB"


- def format_time_nicely(time_str):
-     return time_str.split("T")[0]


- def format_results(results, show_associated_models=True):
      markdown = (
-         "<h1 style='text-align: center;'> &#x2728; Dataset Search Results &#x2728;"
-         " </h1> \n\n"
      )
-     for result in results:
-         hub_id = result.payload["id"]
-         download_number = result.payload["downloads"]
-         lastModified = result.payload["lastModified"]
          url = f"https://huggingface.co/datasets/{hub_id}"
-         header = f"## [{hub_id}]({url})"
          markdown += header + "\n"

-         markdown += f"**30 Day Download:** {download_number}"
-         if lastModified:
-             markdown += f" | **Last Modified:** {format_time_nicely(lastModified)} \n\n"
-         else:
-             markdown += "\n\n"
-         markdown += f"{result.payload['section_text']} \n"
-         if show_associated_models:
-             if linked_models := get_models_for_dataset(hub_id):
-                 linked_models = [
-                     f"[{model}](https://huggingface.co/{model})"
-                     for model in linked_models
-                 ]
-                 markdown += (
-                     "<details><summary>Models trained on this dataset</summary>\n\n"
-                 )
-                 markdown += "- " + "\n- ".join(linked_models) + "\n\n"
-                 markdown += "</details>\n\n"

-     return markdown


- @lru_cache(maxsize=100_000)
- def get_models_for_dataset(id):
-     results = list(iter(list_models(filter=f"dataset:{id}")))
-     if results:
-         results = list({result.id for result in results})
-     return results


- @lru_cache(maxsize=200_000)
- def search(query: str, limit: Optional[int] = 10, show_linked_models: bool = False):
-     query_ = sentence_embedding_model.encode(
-         f"Represent this sentence for searching relevant passages:{query}"
-     )
-     results = client.search(
-         collection_name="dataset_cards",
-         query_vector=query_,
-         limit=limit,
-     )
-     return format_results(results, show_associated_models=show_linked_models)


- @lru_cache(maxsize=100_000)
- def hub_id_qdrant_id(hub_id):
-     matches = client.scroll(
-         collection_name="dataset_cards",
-         scroll_filter=models.Filter(
-             must=[
-                 models.FieldCondition(key="id", match=models.MatchValue(value=hub_id)),
-             ]
-         ),
-         limit=1,
-         with_payload=True,
-         with_vectors=False,
      )
-     try:
-         return matches[0][0].id
-     except IndexError as e:
-         raise gr.Error(
-             f"Hub id {hub_id} not in the database. This could be because it is very new"
-             " or because it doesn't have much documentation."
-         ) from e
-
-
- @lru_cache()
- def recommend(hub_id, limit: Optional[int] = 10, show_linked_models=False):
-     positive_id = hub_id_qdrant_id(hub_id)
-     results = client.recommend(
-         collection_name=collection_name, positive=[positive_id], limit=limit
      )
-     return format_results(results, show_associated_models=show_linked_models)

-
- def query(
-     search_term,
-     search_type,
-     limit: Optional[int] = 10,
-     show_linked_models: bool = False,
- ):
-     if search_type == "Recommend similar datasets":
-         return recommend(search_term, limit, show_linked_models)
-     else:
-         return search(search_term, limit, show_linked_models)


  with gr.Blocks() as demo:
-     gr.Markdown("## &#129303; Semantic Dataset Search")
      with gr.Row():
          gr.Markdown(
-             "This Gradio app allows you to search for datasets based on their"
-             " descriptions. You can either search for similar datasets to a given"
-             " dataset or search for datasets based on a query. This is an early proof of concept. Feedback very welcome!"
          )
      with gr.Row():
-         search_term = gr.Textbox(
-             value="movie review sentiment",
-             label="hub id i.e. IMDB or query i.e. movie review sentiment",
          )

      with gr.Row():
-         with gr.Row():
-             find_similar_btn = gr.Button("Search")
-             search_type = gr.Radio(
-                 ["Recommend similar datasets", "Semantic Search"],
-                 label="Search type",
-                 value="Semantic Search",
-                 interactive=True,
-             )
-         with gr.Column():
-             max_results = gr.Slider(
-                 minimum=1,
-                 maximum=50,
-                 step=1,
-                 value=10,
-                 label="Maximum number of results",
-             )
-             show_linked_models = gr.Checkbox(
-                 label="Show associated models",
-                 default=False,
-             )

      results = gr.Markdown()
-     find_similar_btn.click(
-         query, [search_term, search_type, max_results, show_linked_models], results
      )

-
  demo.launch()
+ import asyncio
+ import re
+ from typing import Dict, List

  import gradio as gr
+ import httpx
  from dotenv import load_dotenv
+ from huggingface_hub import ModelCard
+ from cashews import cache
+

  load_dotenv()
+ cache.setup("mem://")
+ API_URL = "https://davanstrien-huggingface-datasets-search-v2.hf.space/similar"
+ HF_API_URL = "https://huggingface.co/api/datasets"
+ README_URL_TEMPLATE = "https://huggingface.co/datasets/{}/raw/main/README.md"


+ async def fetch_similar_datasets(dataset_id: str, limit: int = 10) -> List[Dict]:
+     async with httpx.AsyncClient() as client:
+         response = await client.get(f"{API_URL}?dataset_id={dataset_id}&n={limit}")
+         return response.json()["results"][1:] if response.status_code == 200 else []


+ async def fetch_dataset_card(dataset_id: str) -> str:
+     url = README_URL_TEMPLATE.format(dataset_id)
+     async with httpx.AsyncClient() as client:
+         response = await client.get(url)
+         return ModelCard(response.text).text if response.status_code == 200 else ""


+ async def fetch_dataset_info(dataset_id: str) -> Dict:
+     async with httpx.AsyncClient() as client:
+         response = await client.get(f"{HF_API_URL}/{dataset_id}")
+         return response.json() if response.status_code == 200 else {}


+ def format_results(
+     results: List[Dict], dataset_cards: List[str], dataset_infos: List[Dict]
+ ) -> str:
      markdown = (
+         "<h1 style='text-align: center;'>&#x2728; Similar Datasets &#x2728;</h1>\n\n"
      )
+     for result, card, info in zip(results, dataset_cards, dataset_infos):
+         hub_id = result["dataset_id"]
+         similarity = result["similarity"]
          url = f"https://huggingface.co/datasets/{hub_id}"
+
+         # Extract title from the card
+         title_match = re.match(r"^#\s*(.+)", card, re.MULTILINE)
+         title = title_match[1] if title_match else hub_id
+
+         header = f"## [{title}]({url})"
          markdown += header + "\n"
+         markdown += f"**Similarity Score:** {similarity:.4f}\n\n"
+
+         if info:
+             downloads = info.get("downloads", 0)
+             likes = info.get("likes", 0)
+             last_modified = info.get("lastModified", "N/A")
+             markdown += f"**Downloads:** {downloads} | **Likes:** {likes} | **Last Modified:** {last_modified}\n\n"
+
+         if card:
+             # Remove the title from the card content
+             card_without_title = re.sub(
+                 r"^#.*\n", "", card, count=1, flags=re.MULTILINE
+             )

+             # Split the card into paragraphs
+             paragraphs = card_without_title.split("\n\n")
+
+             # Find the first non-empty text paragraph that's not just an image
+             preview = next(
+                 (
+                     p
+                     for p in paragraphs
+                     if p.strip()
+                     and not p.strip().startswith("![")
+                     and not p.strip().startswith("<img")
+                 ),
+                 "No preview available.",
+             )

+             # Limit the preview to a reasonable length (e.g., 300 characters)
+             preview = f"{preview[:300]}..." if len(preview) > 300 else preview

+             # Add the preview
+             markdown += f"{preview}\n\n"

+             # Limit image size in the full dataset card
+             full_card = re.sub(
+                 r'<img src="([^"]+)"',
+                 r'<img src="\1" style="max-width: 300px; max-height: 300px;"',
+                 card_without_title,
+             )
+             full_card = re.sub(
+                 r"!\[([^\]]*)\]\(([^\)]+)\)",
+                 r'<img src="\2" alt="\1" style="max-width: 300px; max-height: 300px;">',
+                 full_card,
+             )
+             markdown += f"<details><summary>Full Dataset Card</summary>\n\n{full_card}\n\n</details>\n\n"

+         markdown += "---\n\n"

+     return markdown


+ async def search_similar_datasets(dataset_id: str, limit: int = 10):
+     results = await fetch_similar_datasets(dataset_id, limit)
+
+     # Fetch dataset cards and info concurrently
+     dataset_cards = await asyncio.gather(
+         *[fetch_dataset_card(result["dataset_id"]) for result in results]
      )
+     dataset_infos = await asyncio.gather(
+         *[fetch_dataset_info(result["dataset_id"]) for result in results]
      )

+     return format_results(results, dataset_cards, dataset_infos)


  with gr.Blocks() as demo:
+     gr.Markdown("## &#129303; Dataset Similarity Search")
      with gr.Row():
          gr.Markdown(
+             "This Gradio app allows you to find similar datasets based on a given dataset ID. "
+             "Enter a dataset ID (e.g., 'imdb') to find similar datasets with previews of their dataset cards."
          )
      with gr.Row():
+         dataset_id = gr.Textbox(
+             value="imdb",
+             label="Dataset ID (e.g., imdb, squad, glue)",
          )

      with gr.Row():
+         search_btn = gr.Button("Search Similar Datasets")
+         max_results = gr.Slider(
+             minimum=1,
+             maximum=50,
+             step=1,
+             value=10,
+             label="Maximum number of results",
+         )

      results = gr.Markdown()
+     search_btn.click(
+         lambda dataset_id, limit: asyncio.run(
+             search_similar_datasets(dataset_id, limit)
+         ),
+         inputs=[dataset_id, max_results],
+         outputs=results,
+     )

  demo.launch()
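
For reference, the rewritten app no longer queries Qdrant directly: similarity search is delegated to the hosted /similar endpoint configured in API_URL. A minimal standalone sketch of calling that endpoint, assuming the query parameters and response shape used by fetch_similar_datasets in the diff above (the hosted Space itself may change):

import httpx

# Query the Space's /similar endpoint directly, mirroring fetch_similar_datasets().
response = httpx.get(
    "https://davanstrien-huggingface-datasets-search-v2.hf.space/similar",
    params={"dataset_id": "imdb", "n": 5},
    timeout=30,
)
response.raise_for_status()
# The app skips the first result, presumably the query dataset itself.
for result in response.json()["results"][1:]:
    print(result["dataset_id"], result["similarity"])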