|
import re |
|
import gradio as gr |
|
from scipy.sparse import load_npz |
|
import numpy as np |
|
import json |
|
from datasets import load_dataset |
|
import os |
|
# Debug output: log where the app is running from and which files are
# visible, so a missing data file (feature_names.txt, the .npz matrix)
# is easy to diagnose from the hosting environment's logs.
print("Current working directory:", os.getcwd())

print("Files:", os.listdir())
|
|
|
|
|
|
|
|
|
# Vocabulary learned by the TF-IDF vectorizer: one term per line.
# NOTE(review): assumed to be in the same order as the columns of the
# matrix below — verify against the notebook/script that produced both.
# The encoding is pinned so the read does not depend on the platform
# default (which is not UTF-8 on e.g. Windows).
with open("feature_names.txt", "r", encoding="utf-8") as f:
    feature_names = [line.strip() for line in f]

# Precomputed sparse TF-IDF matrix: rows = training documents,
# columns = terms from feature_names.
tfidf_matrix = load_npz("tfidf_matrix_train.npz")
|
|
|
|
|
dataset = load_dataset("ccdv/arxiv-classification", "no_ref")

# Parallel lists built from the training split: full document text,
# extracted title, and extracted arXiv id (None when no id was found).
documents = []
titles = []
arxiv_ids = []

for item in dataset["train"]:
    text = item["text"]
    # Skip empty or near-empty documents.
    if not text or len(text.strip()) < 10:
        continue

    lines = text.splitlines()
    title_lines = []
    found_arxiv = False
    arxiv_id = None

    # Documents begin with title lines, then an "arXiv:<id>" line, then
    # (eventually) an "Abstract" heading.  Collect title lines until the
    # arXiv line, capture the id, and stop scanning at the abstract.
    for line in lines:
        line_strip = line.strip()
        if not found_arxiv and line_strip.lower().startswith("arxiv:"):
            found_arxiv = True
            # Bug fix: the old pattern required exactly one version digit
            # (v\d), which truncated multi-digit versions ("v12" -> "v1")
            # and rejected ids published without a version suffix, leaving
            # arxiv_id = None so those documents were later dropped from
            # search results.  The version suffix is now optional and may
            # have multiple digits.
            match = re.search(r'arxiv:\d{4}\.\d{4,5}(?:v\d+)?', line_strip, flags=re.IGNORECASE)
            if match:
                arxiv_id = match.group(0).lower()
        elif not found_arxiv:
            title_lines.append(line_strip)
        else:
            # Past the arXiv line: stop at the abstract heading.
            if line_strip.lower().startswith("abstract"):
                break

    title = " ".join(title_lines).strip()
    documents.append(text.strip())
    titles.append(title)
    arxiv_ids.append(arxiv_id)
|
|
|
|
|
def keyword_match_ranking(query, top_n=5):
    """Rank documents by the summed TF-IDF weight of the query's terms.

    Splits *query* on whitespace (lowercased), finds which vocabulary
    terms appear in it, and scores every document as the sum of its
    TF-IDF weights for those terms.

    Returns a list of up to *top_n* ``(doc_index, score)`` pairs with
    score > 0, sorted by score descending (ties keep ascending doc
    index, matching the original stable sort).  Empty list when no
    query term is in the vocabulary.
    """
    # Set membership is O(1) per vocabulary term vs. O(len(query)) for
    # the previous list scan.
    query_terms = set(query.lower().split())
    query_indices = [i for i, term in enumerate(feature_names) if term in query_terms]
    if not query_indices:
        return []

    # Vectorized scoring: one sparse column-slice sum over all documents
    # instead of a Python loop over every row (previously O(n_docs)
    # per-row sparse indexing, which is very slow).
    col_scores = np.asarray(tfidf_matrix[:, query_indices].sum(axis=1)).ravel()
    matching = np.nonzero(col_scores > 0)[0]
    # Stable sort so equal scores keep ascending document order, exactly
    # as the original list.sort(reverse=True) did.
    order = matching[np.argsort(-col_scores[matching], kind="stable")]
    return [(int(idx), float(col_scores[idx])) for idx in order[:top_n]]
|
|
|
|
|
def snippet_before_abstract(text):
    """Return the part of *text* preceding the first 'Abstract' or
    'Introduction' heading (letters may be separated by whitespace, any
    case), stripped.  Falls back to the first 100 characters when no
    heading is found."""
    heading = re.compile(
        r'a\s*b\s*s\s*t\s*r\s*a\s*c\s*t|i\s*n\s*t\s*r\s*o\s*d\s*u\s*c\s*t\s*i\s*o\s*n',
        re.IGNORECASE,
    )
    hit = heading.search(text)
    cutoff = hit.start() if hit else 100
    return text[:cutoff].strip()
|
|
|
|
|
def search_function(query):
    """Run a keyword search and format the results as Markdown.

    Returns a Markdown string with one section per ranked document
    (arXiv link plus the pre-abstract snippet), or "No results found."
    when nothing matched.
    """
    results = keyword_match_ranking(query)
    if not results:
        return "No results found."

    output = ""
    display_rank = 1
    for idx, score in results:
        # Documents whose arXiv id could not be extracted cannot be
        # linked, so they are skipped (ranks are renumbered below).
        if not arxiv_ids[idx]:
            continue

        link = f"https://arxiv.org/abs/{arxiv_ids[idx].replace('arxiv:', '')}"
        snippet = snippet_before_abstract(documents[idx]).replace('\n', '<br>')
        output += f"### Document {display_rank}\n"
        output += f"[arXiv Link]({link})\n\n"
        output += f"<pre>{snippet}</pre>\n\n---\n"
        display_rank += 1

    # Bug fix: if every ranked document lacked an arXiv id, the loop
    # produced an empty string and the UI rendered a blank page instead
    # of the "no results" message.
    if not output:
        return "No results found."
    return output
|
|
|
|
|
# Wire the search function into a minimal Gradio UI: a single-line text
# input for the query, Markdown output for the formatted results.
iface = gr.Interface(
    fn=search_function,
    inputs=gr.Textbox(lines=1, placeholder="Enter your search query"),
    outputs=gr.Markdown(),
    title="arXiv Search Engine",
    description="Search TF-IDF encoded arXiv papers by keyword.",
)

# Start the web server (blocks until the process is stopped).
iface.launch()
|
|