import json  # For stringifying a dict
import random # For selecting a search hint

import gradio  # GUI framework
import datasets  # Used to load publication dataset

import numpy  # For a few simple matrix operations
import pandas  # Needed for operating on dataset
import sentence_transformers  # Needed for query embedding
import faiss  # Needed for fast similarity search

# Load the dataset and convert to pandas
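# The code below assumes the dataset provides precomputed embeddings in an
# "embedding" column alongside Google Scholar metadata columns: "bib_dict",
# "bibtex", "author_pub_id", "citedby_url", and "url_related_articles".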
full_data = datasets.load_dataset("ccm/publications")["train"].to_pandas()

# Define the base URL for Google Scholar
SCHOLAR_URL = "https://scholar.google.com"

# Filter out any publications without an abstract
missing_abstract = [
    '"abstract": null' in json.dumps(bibdict)
    for bibdict in full_data["bib_dict"].values
]
data = full_data[~pandas.Series(missing_abstract)].reset_index(drop=True)

# Create FAISS indices for fast similarity search, one per metric/normalization combination
indices = []
metrics = [faiss.METRIC_INNER_PRODUCT, faiss.METRIC_L2]
normalization = [True, False]
vectors = numpy.stack(data["embedding"].tolist(), axis=0).astype("float32")
dimension = vectors.shape[1]
for metric in metrics:
    for normal in normalization:
        index = faiss.IndexFlat(dimension, metric)
        index_vectors = vectors.copy()  # Copy so in-place normalization does not leak into the other indices
        if normal:
            faiss.normalize_L2(index_vectors)
        index.add(index_vectors)
        indices.append(index)

# Load the model for later use in embeddings
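# Queries must be embedded with the same model used to precompute the dataset
# embeddings so that query and document vectors share the same space.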
model = sentence_transformers.SentenceTransformer("allenai-specter")

# Define the search function
def search(query: str, k: int, normalized: bool):
    """Return Markdown-formatted results for the k publications most similar to the query."""
    query_embedding = numpy.expand_dims(model.encode(query), axis=0)
    faiss.normalize_L2(query_embedding)
    # Only the two inner-product indices are reachable from the interface
    index = indices[0] if normalized else indices[1]
    scores, neighbors = index.search(query_embedding, int(k))
    top_results = data.loc[neighbors[0]]
    search_results = ""

    for i in range(int(k)):
        search_results += "### " + top_results["bib_dict"].values[i]["title"] + "\n\n"
        search_results += str(int(100 * scores[0][i])) + "% relevant  "
        if top_results["author_pub_id"].values[i] is not None:
            search_results += (
                "/  [Full Text]("
                + SCHOLAR_URL
                + "/citations?view_op=view_citation&citation_for_view="
                + top_results["author_pub_id"].values[i]
                + ")  "
            )
        if top_results["citedby_url"].values[i] is not None:
            search_results += (
                "/  [Cited By](" + SCHOLAR_URL + top_results["citedby_url"].values[i] + ")  "
            )
        if top_results["url_related_articles"].values[i] is not None:
            search_results += (
                "/  [Related Articles]("
                + SCHOLAR_URL
                + top_results["url_related_articles"].values[i]
                + ")  "
            )
        # Render the stored BibTeX entry as a fenced code block
        search_results += "\n\n```bibtex\n"
        search_results += (
            json.dumps(top_results["bibtex"].values[i], indent=4)
            .replace("\\n", "\n")
            .replace("\\t", "\t")
            .strip('"')
        )
        search_results += "```\n"
    return search_results
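
# Example (hypothetical query): search("design automation", 5, True) returns
# Markdown for the five nearest publications from the normalized inner-product index.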


with gradio.Blocks() as demo:
    with gradio.Group():
        query = gradio.Textbox(
            placeholder=random.choice([
                "design for additive manufacturing",
                "best practices for agent based modeling",
                "arctic environmental science",
                "analysis of student teamwork"
            ]),
            show_label=False,
            lines=1,
            max_lines=1
        )
        with gradio.Accordion("Settings", open=False):
            k = gradio.Number(10, label="Number of results", precision=0)
            n = gradio.Radio([True, False], value=True, label="Normalized")
    results = gradio.Markdown()
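    # Re-run the search whenever the query or either setting changes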
    query.change(fn=search, inputs=[query, k, n], outputs=results)
    k.change(fn=search, inputs=[query, k, n], outputs=results)
    n.change(fn=search, inputs=[query, k, n], outputs=results)

demo.launch(debug=True)