Jonas Leeb committed
Commit 7285400 · 1 Parent(s): da3c141

all other embeddings implemented, changed to class
BERT embeddings/bert_embedding.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:761d01d079ba768682ce1146f6f6405d45b3c84e4052a12b0372d774d02dc4ca
+size 81117464
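
The .npz additions in this commit are Git LFS pointers, so only the hash and size live in the repo; the arrays arrive with git lfs pull. A minimal sanity check for this one, using the same path and array key that app.py reads (the row count depends on the dataset; 768 columns follow from bert-base-uncased):

import numpy as np

# dense per-document BERT embeddings; key name matches app.py's load_model()
bert_embeddings = np.load("BERT embeddings/bert_embedding.npz")["bert_embedding"]
print(bert_embeddings.shape)  # (num_documents, 768) for bert-base-uncased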
TF-IDF embeddings/feature_names.txt ADDED
The diff for this file is too large to render. See raw diff
 
TF-IDF embeddings/tfidf_matrix_train.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:3171341038274665272e760905eab46b6358481041a6efa6ed6f6669fc31ec5b
+size 222218116
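
A matching sketch for the sparse TF-IDF artifacts, loaded the same way app.py does; the matrix columns should line up with the vocabulary in feature_names.txt:

from scipy.sparse import load_npz

# sparse document-term matrix plus its vocabulary
tfidf_matrix = load_npz("TF-IDF embeddings/tfidf_matrix_train.npz")
with open("TF-IDF embeddings/feature_names.txt", "r") as f:
    feature_names = [line.strip() for line in f]
assert tfidf_matrix.shape[1] == len(feature_names)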
Word2Vec embeddings/word2vec_embedding.npz ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:37ca6935e9edc41c12756eef5e62b4393c1b9bdb2c1cc4a5d1359236d1d03cd8
+size 65242631
app.py CHANGED
@@ -1,110 +1,186 @@
 import re
 import gradio as gr
 from scipy.sparse import load_npz
 import numpy as np
 import json
 from datasets import load_dataset
 import os
-print("Current working directory:", os.getcwd())
-print("Files:", os.listdir())
-
 
 
-# --- Load data and embeddings ---
-with open("feature_names.txt", "r") as f:
-    feature_names = [line.strip() for line in f]
-
-tfidf_matrix = load_npz("tfidf_matrix_train.npz")
-
-# Load dataset and initialize search engine
-dataset = load_dataset("ccdv/arxiv-classification", "no_ref")  # replace with your dataset
-
-documents = []
-titles = []
-arxiv_ids = []
-
-for item in dataset["train"]:
-    text = item["text"]
-    if not text or len(text.strip()) < 10:
-        continue
-
-    lines = text.splitlines()
-    title_lines = []
-    found_arxiv = False
-    arxiv_id = None
-
-    for line in lines:
-        line_strip = line.strip()
-        if not found_arxiv and line_strip.lower().startswith("arxiv:"):
-            found_arxiv = True
-            match = re.search(r'arxiv:\d{4}\.\d{4,5}v\d', line_strip, flags=re.IGNORECASE)
-            if match:
-                arxiv_id = match.group(0).lower()
-        elif not found_arxiv:
-            title_lines.append(line_strip)
-        else:
-            if line_strip.lower().startswith("abstract"):
-                break
-
-    title = " ".join(title_lines).strip()
-    documents.append(text.strip())
-    titles.append(title)
-    arxiv_ids.append(arxiv_id)
-
-
-def keyword_match_ranking(query, top_n=5):
-    query_terms = query.lower().split()
-    query_indices = [i for i, term in enumerate(feature_names) if term in query_terms]
-    if not query_indices:
-        return []
-
-    scores = []
-    for doc_idx in range(tfidf_matrix.shape[0]):
-        doc_vector = tfidf_matrix[doc_idx]
-        doc_score = sum(doc_vector[0, i] for i in query_indices)
-        if doc_score > 0:
-            scores.append((doc_idx, doc_score))
-
-    scores.sort(key=lambda x: x[1], reverse=True)
-    return scores[:top_n]
-
-
-def snippet_before_abstract(text):
-    pattern = re.compile(r'a\s*b\s*s\s*t\s*r\s*a\s*c\s*t|i\s*n\s*t\s*r\s*o\s*d\s*u\s*c\s*t\s*i\s*o\s*n', re.IGNORECASE)
-    match = pattern.search(text)
-    if match:
-        return text[:match.start()].strip()
-    else:
-        return text[:100].strip()
-
-
-def search_function(query):
-    results = keyword_match_ranking(query)
-    if not results:
-        return "No results found."
-
-    output = ""
-    display_rank = 1
-    for idx, score in results:
-        if not arxiv_ids[idx]:
-            continue
-
-        link = f"https://arxiv.org/abs/{arxiv_ids[idx].replace('arxiv:', '')}"
-        snippet = snippet_before_abstract(documents[idx]).replace('\n', '<br>')
-        output += f"### Document {display_rank}\n"
-        output += f"[arXiv Link]({link})\n\n"
-        output += f"<pre>{snippet}</pre>\n\n---\n"
-        display_rank += 1
-
-    return output
-
-
-iface = gr.Interface(
-    fn=search_function,
-    inputs=gr.Textbox(lines=1, placeholder="Enter your search query"),
-    outputs=gr.Markdown(),
-    title="arXiv Search Engine",
-    description="Search TF-IDF encoded arXiv papers by keyword.",
-)
-
-iface.launch()
 import re
 import gradio as gr
 from scipy.sparse import load_npz
+import torch
+from sklearn.metrics.pairwise import cosine_similarity
+from sklearn.preprocessing import normalize
+from transformers import BertTokenizer, BertModel
 import numpy as np
 import json
 from datasets import load_dataset
 import os
+from gensim.models import KeyedVectors
+
+
+class ArxivSearch:
+    def __init__(self, dataset, embedding="tfidf"):
+        self.dataset = dataset
+        self.embedding = embedding
+        self.documents = []
+        self.titles = []
+        self.raw_texts = []
+        self.arxiv_ids = []
+
+        self.embedding_dropdown = gr.Dropdown(
+            choices=["tfidf", "word2vec", "bert"],
+            value="tfidf",
+            label="Model"
+        )
+
+        self.iface = gr.Interface(
+            fn=self.search_function,
+            inputs=[
+                gr.Textbox(lines=1, placeholder="Enter your search query"),
+                self.embedding_dropdown
+            ],
+            outputs=gr.Markdown(),
+            title="arXiv Search Engine",
+            description="Search arXiv papers by keyword and embedding model.",
+        )
+
+        self.load_data(dataset)
+        self.load_model(embedding)
+        # the interface is launched once, from the __main__ guard below
+
+    def load_data(self, dataset):
+        # Parse each document: lines before the "arXiv:" line form the title,
+        # and parsing stops once the abstract begins.
+        train_data = dataset["train"]
+        for item in train_data.select(range(len(train_data))):
+            text = item["text"]
+            if not text or len(text.strip()) < 10:
+                continue
+
+            lines = text.splitlines()
+            title_lines = []
+            found_arxiv = False
+            arxiv_id = None
+
+            for line in lines:
+                line_strip = line.strip()
+                if not found_arxiv and line_strip.lower().startswith("arxiv:"):
+                    found_arxiv = True
+                    match = re.search(r'arxiv:\d{4}\.\d{4,5}v\d', line_strip, flags=re.IGNORECASE)
+                    if match:
+                        arxiv_id = match.group(0).lower()
+                elif not found_arxiv:
+                    title_lines.append(line_strip)
+                else:
+                    if line_strip.lower().startswith("abstract"):
+                        break
+
+            title = " ".join(title_lines).strip()
+
+            self.raw_texts.append(text.strip())
+            self.titles.append(title)
+            self.documents.append(text.strip())
+            self.arxiv_ids.append(arxiv_id)
+
+    def keyword_match_ranking(self, query, top_n=5):
+        # Score each document by the sum of its TF-IDF weights for the query terms.
+        query_terms = query.lower().split()
+        query_indices = [i for i, term in enumerate(self.feature_names) if term in query_terms]
+        if not query_indices:
+            return []
+        scores = []
+        for doc_idx in range(self.tfidf_matrix.shape[0]):
+            doc_vector = self.tfidf_matrix[doc_idx]
+            doc_score = sum(doc_vector[0, i] for i in query_indices)
+            if doc_score > 0:
+                scores.append((doc_idx, doc_score))
+        scores.sort(key=lambda x: x[1], reverse=True)
+        return scores[:top_n]
+
+    def word2vec_search(self, query, top_n=5):
+        # Average the word vectors of in-vocabulary query terms and rank
+        # documents by cosine similarity to the mean vector.
+        tokens = [word for word in query.split() if word in self.wv_model.key_to_index]
+        if not tokens:
+            return []
+        vectors = np.array([self.wv_model[word] for word in tokens])
+        query_vec = normalize(np.mean(vectors, axis=0).reshape(1, -1))
+        sims = cosine_similarity(query_vec, self.word2vec_embeddings).flatten()
+        top_indices = sims.argsort()[::-1][:top_n]
+        return [(i, sims[i]) for i in top_indices]
+
+    def bert_search(self, query, top_n=5):
+        # Embed the query with the [CLS] token of bert-base-uncased and rank
+        # documents by cosine similarity.
+        with torch.no_grad():
+            inputs = self.tokenizer(query, return_tensors="pt", truncation=True, padding=True)
+            outputs = self.model(**inputs)
+            query_vec = normalize(outputs.last_hidden_state[:, 0, :].numpy())
+        sims = cosine_similarity(query_vec, self.bert_embeddings).flatten()
+        top_indices = sims.argsort()[::-1][:top_n]
+        return [(i, sims[i]) for i in top_indices]
+
+    def load_model(self, embedding):
+        if embedding == "tfidf":
+            self.tfidf_matrix = load_npz("TF-IDF embeddings/tfidf_matrix_train.npz")
+            with open("TF-IDF embeddings/feature_names.txt", "r") as f:
+                self.feature_names = [line.strip() for line in f.readlines()]
+        elif embedding == "word2vec":
+            # Use trimmed model here
+            self.word2vec_embeddings = normalize(np.load("Word2Vec embeddings/word2vec_embedding.npz")["word2vec_embedding"])
+            self.wv_model = KeyedVectors.load("models/word2vec-trimmed.model")
+        elif embedding == "bert":
+            self.bert_embeddings = normalize(np.load("BERT embeddings/bert_embedding.npz")["bert_embedding"])
+            self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
+            self.model = BertModel.from_pretrained('bert-base-uncased')
+            self.model.eval()
+        else:
+            raise ValueError(f"Unsupported embedding type: {embedding}")
+
+    def on_model_change(self, change):
+        # Note: ipywidgets-style observer that is never registered with the
+        # Gradio dropdown; search_function() receives the dropdown value directly.
+        new_model = change["new"]
+        self.embedding = new_model
+        self.load_model(new_model)
+
+    def snippet_before_abstract(self, text):
+        # Return the header (title/authors) preceding the abstract or
+        # introduction, tolerating spaced-out letters like "A b s t r a c t".
+        pattern = re.compile(r'a\s*b\s*s\s*t\s*r\s*a\s*c\s*t|i\s*n\s*t\s*r\s*o\s*d\s*u\s*c\s*t\s*i\s*o\s*n', re.IGNORECASE)
+        match = pattern.search(text)
+        if match:
+            return text[:match.start()].strip()
+        else:
+            return text[:100].strip()
+
+    def search_function(self, query, embedding):
+        # Switch embedding models on demand: the dropdown value may differ
+        # from the model loaded at construction time.
+        if embedding != self.embedding:
+            self.embedding = embedding
+            self.load_model(embedding)
+
+        if embedding == "tfidf":
+            results = self.keyword_match_ranking(query)
+        elif embedding == "word2vec":
+            results = self.word2vec_search(query)
+        elif embedding == "bert":
+            results = self.bert_search(query)
+        else:
+            return "No results found."
+
+        if not results:
+            return "No results found."
+
+        output = ""
+        display_rank = 1
+        for idx, score in results:
+            if not self.arxiv_ids[idx]:
+                continue
+
+            link = f"https://arxiv.org/abs/{self.arxiv_ids[idx].replace('arxiv:', '')}"
+            snippet = self.snippet_before_abstract(self.documents[idx]).replace('\n', '<br>')
+            output += f"### Document {display_rank}\n"
+            output += f"[arXiv Link]({link})\n\n"
+            output += f"<pre>{snippet}</pre>\n\n---\n"
+            display_rank += 1
+
+        return output
+
+
+if __name__ == "__main__":
+    dataset = load_dataset("ccdv/arxiv-classification", "no_ref")  # replace with your dataset
+    search_engine = ArxivSearch(dataset, embedding="tfidf")  # initialize with tfidf or any other embedding
+    search_engine.iface.launch()
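
For quick testing, the class can also be queried without the Gradio UI, since search_function returns the rendered Markdown string directly (the query below is just an example):

dataset = load_dataset("ccdv/arxiv-classification", "no_ref")
search_engine = ArxivSearch(dataset, embedding="tfidf")
print(search_engine.search_function("graph neural networks", "tfidf"))  # example query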
models/word2vec-trimmed.model ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:785f477908089d8e1d5e1ce94f04ccbecb2bdb655f6cc468b7bacaac3e40d663
+size 3735368
models/word2vec-trimmed.model.vectors.npy ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:01c2c062175d68b6f745b6e798d91033e3d46c0e23571d5bb37b0450d2ff5293
+size 234224528
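
The .model/.vectors.npy pair is how gensim's KeyedVectors.save() persists a model whose vector array is large. One plausible way a trimmed model like this is produced (the source vector file and vocabulary limit are assumptions; only the output paths match this commit):

from gensim.models import KeyedVectors

# hypothetical source vectors; limit= keeps only the most frequent entries
kv = KeyedVectors.load_word2vec_format(
    "GoogleNews-vectors-negative300.bin", binary=True, limit=200_000
)
kv.save("models/word2vec-trimmed.model")  # also writes word2vec-trimmed.model.vectors.npy
# app.py reloads it with: KeyedVectors.load("models/word2vec-trimmed.model")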
requirements.txt CHANGED
@@ -1,4 +1,8 @@
 gradio
 scipy
 numpy
-datasets
+datasets
+torch
+gensim
+scikit-learn
+transformers
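
To mirror the Space's environment locally (note that scikit-learn, not the deprecated sklearn stub package, is the installable name behind the sklearn imports):

pip install -r requirements.txt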