Spaces:
Runtime error
Runtime error
| import json # For stringifying a dict | |
| import random # For selecting a search hint | |
| import gradio # GUI framework | |
| import datasets # Used to load publication dataset | |
| import numpy # For a few simple matrix operations | |
| import pandas # Needed for operating on dataset | |
| import sentence_transformers # Needed for query embedding | |
| import faiss # Needed for fast similarity search | |
# Load the publication dataset via `datasets` and convert it to a DataFrame.
full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()
# Base URL prepended to the relative Google Scholar links stored in the dataset.
SCHOLAR_URL = "https://scholar.google.com"


def _has_abstract(bib_dict):
    """Return True when a publication's ``bib_dict`` holds a non-null abstract.

    Looks the key up directly instead of searching a ``json.dumps`` rendering,
    which would raise on values JSON cannot serialize and would keep records
    whose ``abstract`` key is missing (or whose ``bib_dict`` is not a dict).
    """
    return isinstance(bib_dict, dict) and bib_dict.get("abstract") is not None


# Keep only publications with an abstract. reset_index() renumbers rows
# 0..len-1 (the old labels move to an "index" column), so positional FAISS
# ids can be used directly as .loc labels on `data`.
_abstract_mask = full_data["bib_dict"].map(_has_abstract).astype(bool)
data = full_data[_abstract_mask].reset_index()
# Build one exact (brute-force) FAISS index per (metric, normalization) pair.
# The list order is (inner product, normalized), (inner product, raw),
# (L2, normalized), (L2, raw); search() picks an index by its position here.
indices = []
metrics = [faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2]
normalization = [True, False]
# FAISS only accepts C-contiguous float32 matrices.
vectors = numpy.ascontiguousarray(
    numpy.stack(data["embedding"].tolist(), axis=0), dtype="float32"
)
for metric in metrics:
    for normal in normalization:
        if normal:
            # normalize_L2 works in place: normalize a copy, otherwise every
            # index built afterwards (including the "unnormalized" ones) would
            # silently receive normalized vectors too.
            index_vectors = vectors.copy()
            faiss.normalize_L2(index_vectors)
        else:
            index_vectors = vectors
        # Flat indexes need no training; build with the metric directly
        # rather than overriding metric_type on an IndexFlatL2.
        index = faiss.IndexFlat(vectors.shape[1], metric)
        index.add(index_vectors)
        indices.append(index)
# Sentence-embedding model used by search() to encode each incoming query.
# NOTE(review): presumably the same model that produced data["embedding"];
# query and corpus vectors must share an embedding space — confirm.
SPECTER_MODEL_NAME = "allenai-specter"
model = sentence_transformers.SentenceTransformer(SPECTER_MODEL_NAME)
| # Define the search function | |
| def search(query: str, k: int, n: int): | |
| query = numpy.expand_dims(model.encode(query), axis=0) | |
| faiss.normalize_L2(query) | |
| D, I = indices[n].search(query, k) | |
| top_five = data.loc[I[0]] | |
| search_results = "" | |
| for i in range(k): | |
| search_results += "### " + top_five["bib_dict"].values[i]["title"] + "\n\n" | |
| search_results += str(int(100*D[0][i])) + "% relevant " | |
| if top_five["author_pub_id"].values[i] is not None: | |
| search_results += "/ [Full Text](https://scholar.google.com/citations?view_op=view_citation&citation_for_view=" + top_five["author_pub_id"].values[i] + ") " | |
| if top_five["citedby_url"].values[i] is not None: | |
| search_results += ( | |
| "/ [Cited By](" + SCHOLAR_URL + top_five["citedby_url"].values[i] + ") " | |
| ) | |
| if top_five["url_related_articles"].values[i] is not None: | |
| search_results += ( | |
| "/ [Related Articles](" | |
| + SCHOLAR_URL | |
| + top_five["url_related_articles"].values[i] | |
| + ") " | |
| ) | |
| search_results += "\n\n```bibtex\n" | |
| search_results += ( | |
| json.dumps(top_five["bibtex"].values[i], indent=4) | |
| .replace("\\n", "\n") | |
| .replace("\\t", "\t") | |
| .strip('"') | |
| ) | |
| search_results += "```\n" | |
| return search_results | |
# Labels for the entries of `indices`, generated with the same
# metric-then-normalization nesting used to build that list, so a label's
# position equals the position of the index it describes.
_METRIC_NAMES = {
    faiss.METRIC_INNER_PRODUCT: "Inner product",
    faiss.METRIC_L2: "L2 distance",
}
index_choices = [
    f"{_METRIC_NAMES[metric]}, {'normalized' if normal else 'unnormalized'}"
    for metric in metrics
    for normal in normalization
]

with gradio.Blocks() as demo:
    with gradio.Group():
        query = gradio.Textbox(
            placeholder=random.choice([
                "design for additive manufacturing",
                "best practices for agent based modeling",
                "arctic environmental science",
                "analysis of student teamwork",
            ]),
            show_label=False,
            lines=1,
            max_lines=1,
        )
    with gradio.Accordion("Settings", open=False):
        k = gradio.Number(10.0, label="Number of results", precision=0)
        # Previously this Radio was bound to `k` (discarding the Number) and
        # `n` was never defined, so wiring the events raised NameError.
        # type="index" hands search() the selected position in `indices`.
        n = gradio.Radio(
            index_choices,
            value=index_choices[0],
            type="index",
            label="Similarity index",
        )
    results = gradio.Markdown()
    # Re-run the search whenever the query or any setting changes.
    for component in (query, k, n):
        component.change(fn=search, inputs=[query, k, n], outputs=results)

demo.launch(debug=True)