ccm's picture
Update main.py
f97030a verified
raw
history blame
3.73 kB
import json # For stringifying a dict
import random # For selecting a search hint
import gradio # GUI framework
import datasets # Used to load publication dataset
import numpy # For a few simple matrix operations
import pandas # Needed for operating on dataset
import sentence_transformers # Needed for query embedding
import faiss # Needed for fast similarity search
# Load the publication dataset ("train" split) as a pandas DataFrame.
full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()

# Base URL for Google Scholar; search() appends the dataset's relative
# "citedby_url" / "url_related_articles" paths to it.
SCHOLAR_URL = "https://scholar.google.com"

# Keep only publications whose bibliographic record has a non-null abstract.
# Checking the key directly (instead of searching the JSON-serialised record
# for the text '"abstract": null') is not fooled by serialisation details,
# avoids shadowing the builtin ``filter``, and also drops rows whose record
# is missing or lacks the key -- rows search() could not render anyway.
has_abstract = full_data["bib_dict"].map(
    lambda bib: isinstance(bib, dict) and bib.get("abstract") is not None
)
# Renumber the surviving rows 0..N-1 so FAISS row ids match ``data.loc``
# labels; the previous label is kept in an "index" column, as before.
data = full_data[has_abstract.astype(bool)].reset_index()
# Build one exact (brute-force) FAISS index per (metric, normalization) pair.
# search() selects an index by its position in ``indices``, so the order is
# fixed by the nested loops below:
#   0: inner product, normalized    1: inner product, raw
#   2: L2, normalized               3: L2, raw
metrics = [faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2]
normalization = [True, False]
# Stack the per-row embeddings into an (n_rows, dim) matrix; FAISS requires
# C-contiguous float32 input.
vectors = numpy.ascontiguousarray(
    numpy.stack(data["embedding"].tolist(), axis=0), dtype=numpy.float32
)
indices = []
for metric in metrics:
    for normal in normalization:
        # IndexFlat takes the metric directly rather than patching
        # metric_type on an IndexFlatL2; flat indexes need no training.
        index = faiss.IndexFlat(vectors.shape[1], metric)
        if normal:
            # faiss.normalize_L2 rescales rows in place. Normalize a copy so
            # the shared matrix stays raw -- otherwise the first normalization
            # would leak into every later "raw" index.
            normalized = vectors.copy()
            faiss.normalize_L2(normalized)
            index.add(normalized)
        else:
            index.add(vectors)
        indices.append(index)
# Load the sentence-embedding model that search() uses to encode queries.
model = sentence_transformers.SentenceTransformer("allenai-specter")
# Search helpers and the search entry point used by the UI.
def _format_hit(row, score):
    """Render one search hit as a Markdown section.

    Args:
        row: One row of ``data`` (a pandas Series).
        score: The raw FAISS score for the row, shown as a percentage.

    Returns:
        Markdown with the title, score, any available Scholar links and a
        fenced BibTeX block.
    """
    bib = row["bib_dict"]
    parts = [
        f"### {bib.get('title') or 'Untitled'}\n\n",
        f"{int(100 * score)}% relevant ",
    ]
    # Optional links: skip missing values (None/NaN) and empty strings, which
    # would otherwise break the concatenation or produce dead links.
    pub_id = row["author_pub_id"]
    if isinstance(pub_id, str) and pub_id:
        parts.append(
            f"/ [Full Text]({SCHOLAR_URL}/citations?view_op=view_citation"
            f"&citation_for_view={pub_id}) "
        )
    cited_by = row["citedby_url"]
    if isinstance(cited_by, str) and cited_by:
        parts.append(f"/ [Cited By]({SCHOLAR_URL}{cited_by}) ")
    related = row["url_related_articles"]
    if isinstance(related, str) and related:
        parts.append(f"/ [Related Articles]({SCHOLAR_URL}{related}) ")

    bibtex = row["bibtex"]
    if not isinstance(bibtex, str):
        # Non-string values (e.g. None) are serialised for display only.
        bibtex = json.dumps(bibtex, indent=4, ensure_ascii=False)
    # BibTeX strings are shown verbatim: routing them through json.dumps
    # escaped non-ASCII characters and doubled LaTeX backslashes.
    if not bibtex.endswith("\n"):
        # The closing fence must start on its own line to end the block.
        bibtex += "\n"
    parts.append(f"\n\n```bibtex\n{bibtex}```\n")
    return "".join(parts)


def search(query: str, k: int, n: int):
    """Return the top-``k`` publications for ``query`` as Markdown.

    Args:
        query: Free-text query, embedded with ``model`` and L2-normalized.
            A blank query yields an empty string.
        k: Number of results requested. ``None`` or non-positive values
            yield an empty string; the value is capped at the index size.
        n: Position in ``indices`` of the FAISS index to search
            (``None`` selects index 0).

    Returns:
        One Markdown section per hit, in FAISS result order.
    """
    if not (query and query.strip()) or k is None:
        return ""
    index = indices[0 if n is None else int(n)]
    # Never ask FAISS for more hits than the index holds.
    k = min(int(k), index.ntotal)
    if k <= 0:
        return ""
    embedding = numpy.ascontiguousarray(
        numpy.expand_dims(model.encode(query), axis=0), dtype=numpy.float32
    )
    faiss.normalize_L2(embedding)
    scores, ids = index.search(embedding, k)
    return "".join(
        _format_hit(data.loc[row_id], score)
        for score, row_id in zip(scores[0], ids[0])
        # FAISS pads with id -1 when fewer than k results are available.
        if row_id >= 0
    )
# Build the Gradio UI: a query box with collapsible settings, plus a Markdown
# pane that is re-rendered by search() whenever any input changes.
with gradio.Blocks() as demo:
    with gradio.Group():
        query = gradio.Textbox(
            placeholder=random.choice([
                "design for additive manufacturing",
                "best practices for agent based modeling",
                "arctic environmental science",
                "analysis of student teamwork",
            ]),
            show_label=False,
            lines=1,
            max_lines=1,
        )
        with gradio.Accordion("Settings", open=False):
            k = gradio.Number(10.0, label="Number of results", precision=0)
            # type="index" passes search() the position of the chosen option
            # (True -> 0, False -> 1): the normalized or raw inner-product
            # index. Passing the bool itself would invert the choice
            # (True == 1). A default is required because the accordion starts
            # closed and search() needs a valid index on the first query.
            n = gradio.Radio(
                [True, False], value=True, type="index", label="Normalized"
            )
    results = gradio.Markdown()
    for component in (query, k, n):
        component.change(fn=search, inputs=[query, k, n], outputs=results)
demo.launch(debug=True)