hysts HF Staff commited on
Commit
c2aa37f
·
1 Parent(s): 496c8a5
Files changed (5) hide show
  1. app.py +22 -49
  2. app_mcp.py +129 -0
  3. search.py +30 -0
  4. semantic_search.py +0 -41
  5. table.py +1 -1
app.py CHANGED
@@ -4,8 +4,9 @@ import gradio as gr
4
  import polars as pl
5
  from gradio_modal import Modal
6
 
 
7
  from app_pr import demo as demo_pr
8
- from semantic_search import semantic_search
9
  from table import df_orig
10
 
11
  DESCRIPTION = "# ICLR 2025"
@@ -59,10 +60,7 @@ df_main = df_orig.select(
59
 
60
  df_main = df_main.with_columns(
61
  [
62
- pl.when(pl.col(col) == "").then(None).otherwise(pl.col(col))
63
- .cast(pl.Int64)
64
- .fill_null(0)
65
- .alias(col)
66
  for col in ["upvotes", "num_comments"]
67
  ]
68
  )
@@ -120,32 +118,25 @@ def update_num_papers(df: pl.DataFrame) -> str:
120
 
121
 
122
  def update_df(
123
- search_mode: str,
124
  search_query: str,
125
  candidate_pool_size: int,
126
- score_threshold: float,
127
  presentation_type: str,
128
  column_names: list[str],
129
- case_insensitive: bool = True,
130
  ) -> gr.Dataframe:
 
 
 
131
  df = df_main.clone()
132
  column_names = ["Title", *column_names]
133
 
134
  if search_query:
135
- if search_mode == "Title Search":
136
- if case_insensitive:
137
- search_query = f"(?i){search_query}"
138
- try:
139
- df = df.filter(pl.col("Title").str.contains(search_query))
140
- except pl.exceptions.ComputeError as e:
141
- raise gr.Error(str(e)) from e
142
  else:
143
- paper_ids, scores = semantic_search(search_query, candidate_pool_size, score_threshold)
144
- if not paper_ids:
145
- df = df.head(0)
146
- else:
147
- df = pl.DataFrame({"paper_id": paper_ids, "score": scores}).join(df, on="paper_id", how="inner")
148
- df = df.sort("score", descending=True).drop("score")
149
 
150
  if presentation_type != "(ALL)":
151
  df = df.filter(pl.col("Type").str.contains(presentation_type))
@@ -159,10 +150,6 @@ def update_df(
159
  )
160
 
161
 
162
- def update_search_mode(search_mode: str) -> gr.Accordion:
163
- return gr.Accordion(visible=search_mode == "Semantic Search")
164
-
165
-
166
  def df_row_selected(
167
  evt: gr.SelectData,
168
  ) -> tuple[
@@ -186,21 +173,11 @@ with gr.Blocks(css_paths="style.css") as demo:
186
  gr.Markdown(DESCRIPTION)
187
  with gr.Accordion(label="Tutorial", open=True):
188
  gr.Markdown(TUTORIAL)
189
- with gr.Group():
190
- search_mode = gr.Radio(
191
- label="Search Mode",
192
- choices=["Semantic Search", "Title Search"],
193
- value="Semantic Search",
194
- show_label=False,
195
- info="Note: Semantic search consumes your ZeroGPU quota.",
196
- )
197
- search_query = gr.Textbox(label="Search", submit_btn=True, show_label=False, placeholder="Enter query here")
198
- with gr.Accordion(label="Advanced Search Options", open=False) as advanced_search_options:
199
- with gr.Row():
200
- candidate_pool_size = gr.Slider(
201
- label="Candidate Pool Size", minimum=1, maximum=1000, step=1, value=300
202
- )
203
- score_threshold = gr.Slider(label="Score Threshold", minimum=0, maximum=1, step=0.01, value=0.5)
204
 
205
  presentation_type = gr.Radio(
206
  label="Presentation Type",
@@ -231,19 +208,12 @@ with gr.Blocks(css_paths="style.css") as demo:
231
  title = gr.Textbox(label="Title")
232
  abstract = gr.Textbox(label="Abstract")
233
 
234
- search_mode.change(
235
- fn=update_search_mode,
236
- inputs=search_mode,
237
- outputs=advanced_search_options,
238
- )
239
-
240
  df.select(fn=df_row_selected, outputs=[abstract_modal, title, abstract])
241
 
242
  inputs = [
243
- search_mode,
244
  search_query,
245
  candidate_pool_size,
246
- score_threshold,
247
  presentation_type,
248
  column_names,
249
  ]
@@ -277,10 +247,13 @@ with gr.Blocks(css_paths="style.css") as demo:
277
  api_name=False,
278
  )
279
 
 
 
 
280
 
281
  with demo.route("Open PR"):
282
  demo_pr.render()
283
 
284
 
285
  if __name__ == "__main__":
286
- demo.queue(api_open=False).launch(show_api=False)
 
4
  import polars as pl
5
  from gradio_modal import Modal
6
 
7
+ from app_mcp import demo as demo_mcp
8
  from app_pr import demo as demo_pr
9
+ from search import search
10
  from table import df_orig
11
 
12
  DESCRIPTION = "# ICLR 2025"
 
60
 
61
  df_main = df_main.with_columns(
62
  [
63
+ pl.when(pl.col(col) == "").then(None).otherwise(pl.col(col)).cast(pl.Int64).fill_null(0).alias(col)
 
 
 
64
  for col in ["upvotes", "num_comments"]
65
  ]
66
  )
 
118
 
119
 
120
  def update_df(
 
121
  search_query: str,
122
  candidate_pool_size: int,
123
+ num_results: int,
124
  presentation_type: str,
125
  column_names: list[str],
 
126
  ) -> gr.Dataframe:
127
+ if num_results > candidate_pool_size:
128
+ raise gr.Error("Number of results must be less than or equal to candidate pool size", print_exception=False)
129
+
130
  df = df_main.clone()
131
  column_names = ["Title", *column_names]
132
 
133
  if search_query:
134
+ results = search(search_query, candidate_pool_size, num_results)
135
+ if not results:
136
+ df = df.head(0)
 
 
 
 
137
  else:
138
+ df = pl.DataFrame(results).join(df, on="paper_id", how="inner")
139
+ df = df.sort("ce_score", descending=True).drop("ce_score")
 
 
 
 
140
 
141
  if presentation_type != "(ALL)":
142
  df = df.filter(pl.col("Type").str.contains(presentation_type))
 
150
  )
151
 
152
 
 
 
 
 
153
  def df_row_selected(
154
  evt: gr.SelectData,
155
  ) -> tuple[
 
173
  gr.Markdown(DESCRIPTION)
174
  with gr.Accordion(label="Tutorial", open=True):
175
  gr.Markdown(TUTORIAL)
176
+ search_query = gr.Textbox(label="Search", submit_btn=True, show_label=False, placeholder="Search...")
177
+ with gr.Accordion(label="Advanced Search Options", open=False) as advanced_search_options:
178
+ with gr.Row():
179
+ candidate_pool_size = gr.Slider(label="Candidate Pool Size", minimum=1, maximum=600, step=1, value=200)
180
+ num_results = gr.Slider(label="Number of Results", minimum=1, maximum=400, step=1, value=100)
 
 
 
 
 
 
 
 
 
 
181
 
182
  presentation_type = gr.Radio(
183
  label="Presentation Type",
 
208
  title = gr.Textbox(label="Title")
209
  abstract = gr.Textbox(label="Abstract")
210
 
 
 
 
 
 
 
211
  df.select(fn=df_row_selected, outputs=[abstract_modal, title, abstract])
212
 
213
  inputs = [
 
214
  search_query,
215
  candidate_pool_size,
216
+ num_results,
217
  presentation_type,
218
  column_names,
219
  ]
 
247
  api_name=False,
248
  )
249
 
250
+ with gr.Row(visible=False):
251
+ demo_mcp.render()
252
+
253
 
254
  with demo.route("Open PR"):
255
  demo_pr.render()
256
 
257
 
258
  if __name__ == "__main__":
259
+ demo.launch(mcp_server=True)
app_mcp.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import polars as pl
3
+
4
+ from search import search
5
+ from table import df_orig
6
+
7
+ COLUMNS_MCP = [
8
+ "title",
9
+ "authors",
10
+ "abstract",
11
+ "openreview_url",
12
+ "arxiv_id",
13
+ "paper_page",
14
+ "space_ids",
15
+ "model_ids",
16
+ "dataset_ids",
17
+ "upvotes",
18
+ "num_comments",
19
+ "project_page",
20
+ "github",
21
+ "row_index",
22
+ ]
23
+ DEFAULT_COLUMNS_MCP = [
24
+ "title",
25
+ "authors",
26
+ "abstract",
27
+ "openreview_url",
28
+ "arxiv_id",
29
+ "project_page",
30
+ "github",
31
+ "row_index",
32
+ ]
33
+
34
+ df_mcp = df_orig.rename({"openreview": "openreview_url", "paper_id": "row_index"}).select(COLUMNS_MCP)
35
+
36
+
37
+ def search_papers(
38
+ search_query: str,
39
+ candidate_pool_size: int,
40
+ num_results: int,
41
+ columns: list[str],
42
+ ) -> list[dict]:
43
+ """Searches ICLR 2025 papers relevant to a user query in English.
44
+
45
+ This function performs a semantic search over ICLR 2025 papers.
46
+ It uses a dual-stage retrieval process:
47
+ - First, it retrieves `candidate_pool_size` papers using dense vector similarity.
48
+ - Then, it re-ranks them with a cross-encoder model to select the top `num_results` most relevant papers.
49
+ - The search results are returned as a list of dictionaries.
50
+
51
+ Note:
52
+ The search query must be written in English. Queries in other languages are not supported.
53
+
54
+ Args:
55
+ search_query (str): The natural language query input by the user. Must be in English.
56
+ candidate_pool_size (int): Number of candidate papers to retrieve using the dense vector model.
57
+ num_results (int): Final number of top-ranked papers to return after re-ranking.
58
+ columns (list[str]): The columns to select from the DataFrame.
59
+
60
+ Returns:
61
+ list[dict]: A list of dictionaries of the top-ranked papers matching the query, sorted by relevance.
62
+ """
63
+ if not search_query:
64
+ raise ValueError("Search query cannot be empty")
65
+ if num_results > candidate_pool_size:
66
+ raise ValueError("Number of results must be less than or equal to candidate pool size")
67
+
68
+ df = df_mcp.clone()
69
+ results = search(search_query, candidate_pool_size, num_results)
70
+ df = pl.DataFrame(results).rename({"paper_id": "row_index"}).join(df, on="row_index", how="inner")
71
+ df = df.sort("ce_score", descending=True)
72
+ return df.select(columns).to_dicts()
73
+
74
+
75
+ def get_metadata(row_index: int) -> dict:
76
+ """Returns a dictionary of metadata for a ICLR 2025 paper at the given table row index.
77
+
78
+ Args:
79
+ row_index (int): The index of the paper in the internal paper list table.
80
+
81
+ Returns:
82
+ dict: A dictionary containing metadata for the corresponding paper.
83
+ """
84
+ return df_mcp.filter(pl.col("row_index") == row_index).to_dicts()[0]
85
+
86
+
87
+ def get_table(columns: list[str]) -> list[dict]:
88
+ """Returns a list of dictionaries of all ICLR 2025 papers.
89
+
90
+ Args:
91
+ columns (list[str]): The columns to select from the DataFrame.
92
+
93
+ Returns:
94
+ list[dict]: A list of dictionaries of all ICLR 2025 papers.
95
+ """
96
+ return df_mcp.select(columns).to_dicts()
97
+
98
+
99
+ with gr.Blocks() as demo:
100
+ search_query = gr.Textbox(label="Search", submit_btn=True)
101
+ candidate_pool_size = gr.Slider(label="Candidate Pool Size", minimum=1, maximum=500, step=1, value=200)
102
+ num_results = gr.Slider(label="Number of Results", minimum=1, maximum=400, step=1, value=100)
103
+ column_names = gr.CheckboxGroup(label="Columns", choices=COLUMNS_MCP, value=DEFAULT_COLUMNS_MCP)
104
+ row_index = gr.Slider(label="Row Index", minimum=0, maximum=len(df_mcp) - 1, step=1, value=0)
105
+
106
+ out = gr.JSON()
107
+
108
+ search_papers_btn = gr.Button("Search Papers")
109
+ get_metadata_btn = gr.Button("Get Metadata")
110
+ get_table_btn = gr.Button("Get Table")
111
+
112
+ search_papers_btn.click(
113
+ fn=search_papers,
114
+ inputs=[search_query, candidate_pool_size, num_results, column_names],
115
+ outputs=out,
116
+ )
117
+ get_metadata_btn.click(
118
+ fn=get_metadata,
119
+ inputs=row_index,
120
+ outputs=out,
121
+ )
122
+ get_table_btn.click(
123
+ fn=get_table,
124
+ inputs=column_names,
125
+ outputs=out,
126
+ )
127
+
128
+ if __name__ == "__main__":
129
+ demo.launch(mcp_server=True)
search.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import datasets
2
+ import numpy as np
3
+ import spaces
4
+ from sentence_transformers import CrossEncoder, SentenceTransformer
5
+
6
+ from table import BASE_REPO_ID
7
+
8
+ ds = datasets.load_dataset(BASE_REPO_ID, split="train")
9
+ ds.add_faiss_index(column="embedding")
10
+
11
+ bi_model = SentenceTransformer("BAAI/bge-base-en-v1.5")
12
+ ce_model = CrossEncoder("BAAI/bge-reranker-base")
13
+
14
+
15
+ @spaces.GPU(duration=10)
16
+ def search(query: str, candidate_pool_size: int = 100, retrieval_k: int = 50) -> list[dict]:
17
+ prefix = "Represent this sentence for searching relevant passages: "
18
+ q_vec = bi_model.encode(prefix + query, normalize_embeddings=True)
19
+
20
+ _, retrieved_ds = ds.get_nearest_examples("embedding", q_vec, k=candidate_pool_size)
21
+
22
+ ce_inputs = [
23
+ (query, f"{retrieved_ds['title'][i]} {retrieved_ds['abstract'][i]}") for i in range(len(retrieved_ds["title"]))
24
+ ]
25
+ ce_scores = ce_model.predict(ce_inputs, batch_size=16)
26
+
27
+ sorted_idx = np.argsort(ce_scores)[::-1]
28
+ return [
29
+ {"paper_id": retrieved_ds["paper_id"][i], "ce_score": float(ce_scores[i])} for i in sorted_idx[:retrieval_k]
30
+ ]
semantic_search.py DELETED
@@ -1,41 +0,0 @@
1
- import datasets
2
- import numpy as np
3
- import scipy.spatial
4
- import scipy.special
5
- import spaces
6
- from sentence_transformers import CrossEncoder, SentenceTransformer
7
-
8
- from table import BASE_REPO_ID
9
-
10
- ds = datasets.load_dataset(BASE_REPO_ID, split="train")
11
- ds = ds.rename_column("submission_number", "paper_id")
12
- ds.add_faiss_index(column="embedding")
13
-
14
- model = SentenceTransformer("all-MiniLM-L6-v2")
15
- reranker = CrossEncoder("cross-encoder/ms-marco-MiniLM-L-6-v2")
16
-
17
-
18
- @spaces.GPU(duration=5)
19
- def semantic_search(
20
- query: str, candidate_pool_size: int = 300, score_threshold: float = 0.5
21
- ) -> tuple[list[int], list[float]]:
22
- query_vec = model.encode(query)
23
- _, retrieved_data = ds.get_nearest_examples("embedding", query_vec, k=candidate_pool_size)
24
-
25
- rerank_inputs = [
26
- [query, f"{title}\n{abstract}"]
27
- for title, abstract in zip(retrieved_data["title"], retrieved_data["abstract"], strict=True)
28
- ]
29
- rerank_scores = reranker.predict(rerank_inputs)
30
- sorted_indices = np.argsort(rerank_scores)[::-1]
31
-
32
- paper_ids = []
33
- scores = []
34
- for i in sorted_indices:
35
- score = float(scipy.special.expit(rerank_scores[i]))
36
- if score < score_threshold:
37
- break
38
- paper_ids.append(retrieved_data["paper_id"][i])
39
- scores.append(score)
40
-
41
- return paper_ids, scores
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
table.py CHANGED
@@ -61,7 +61,7 @@ def format_author_claim_ratio(row: dict) -> str:
61
  df_orig = (
62
  datasets.load_dataset(BASE_REPO_ID, split="train")
63
  .to_polars()
64
- .rename({"paper_url": "openreview", "submission_number": "paper_id"})
65
  .with_columns(
66
  pl.lit([], dtype=pl.List(pl.Utf8)).alias(col_name) for col_name in ["space_ids", "model_ids", "dataset_ids"]
67
  )
 
61
  df_orig = (
62
  datasets.load_dataset(BASE_REPO_ID, split="train")
63
  .to_polars()
64
+ .rename({"paper_url": "openreview"})
65
  .with_columns(
66
  pl.lit([], dtype=pl.List(pl.Utf8)).alias(col_name) for col_name in ["space_ids", "model_ids", "dataset_ids"]
67
  )