Added an option to select which models to search for a word in
Browse files- app.py +16 -3
- word2vec.py +39 -4
app.py
CHANGED
|
@@ -20,6 +20,7 @@ if active_tab == "Nearest neighbours":
|
|
| 20 |
with col2:
|
| 21 |
time_slice = st.selectbox("Time slice", ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"])
|
| 22 |
|
|
|
|
| 23 |
n = st.slider("Number of neighbours", 1, 50, 15)
|
| 24 |
|
| 25 |
nearest_neighbours_button = st.button("Find nearest neighbours")
|
|
@@ -28,14 +29,26 @@ if active_tab == "Nearest neighbours":
|
|
| 28 |
if nearest_neighbours_button:
|
| 29 |
|
| 30 |
# Rewrite timeslices to model names: Archaic -> archaic_cbow
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
time_slice = time_slice.lower() + "_cbow"
|
| 32 |
-
|
|
|
|
| 33 |
|
| 34 |
# Check if all fields are filled in
|
| 35 |
-
if validate_nearest_neighbours(word, time_slice, n) == False:
|
| 36 |
st.error('Please fill in all fields')
|
| 37 |
else:
|
| 38 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 39 |
df = pd.DataFrame(nearest_neighbours, columns=["Word", "Time slice", "Similarity"])
|
| 40 |
st.table(df)
|
| 41 |
|
|
|
|
| 20 |
with col2:
|
| 21 |
time_slice = st.selectbox("Time slice", ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"])
|
| 22 |
|
| 23 |
+
models = st.multiselect("Select models to search for neighbours", ["Archaic", "Classical", "Hellenistic", "Early Roman", "Late Roman"])
|
| 24 |
n = st.slider("Number of neighbours", 1, 50, 15)
|
| 25 |
|
| 26 |
nearest_neighbours_button = st.button("Find nearest neighbours")
|
|
|
|
| 29 |
if nearest_neighbours_button:
|
| 30 |
|
| 31 |
# Rewrite timeslices to model names: Archaic -> archaic_cbow
|
| 32 |
+
if time_slice == 'Hellenistic':
|
| 33 |
+
time_slice = 'hellen'
|
| 34 |
+
elif time_slice == 'Early Roman':
|
| 35 |
+
time_slice = 'early_roman'
|
| 36 |
+
elif time_slice == 'Late Roman':
|
| 37 |
+
time_slice = 'late_roman'
|
| 38 |
+
|
| 39 |
time_slice = time_slice.lower() + "_cbow"
|
| 40 |
+
|
| 41 |
+
|
| 42 |
|
| 43 |
# Check if all fields are filled in
|
| 44 |
+
if validate_nearest_neighbours(word, time_slice, n, models) == False:
|
| 45 |
st.error('Please fill in all fields')
|
| 46 |
else:
|
| 47 |
+
# Rewrite models to list of all loaded models
|
| 48 |
+
models = load_selected_models(models)
|
| 49 |
+
|
| 50 |
+
nearest_neighbours = get_nearest_neighbours(word, time_slice, n, models)
|
| 51 |
+
|
| 52 |
df = pd.DataFrame(nearest_neighbours, columns=["Word", "Time slice", "Similarity"])
|
| 53 |
st.table(df)
|
| 54 |
|
word2vec.py
CHANGED
|
@@ -18,6 +18,24 @@ def load_all_models():
|
|
| 18 |
return [archaic, classical, early_roman, hellen, late_roman]
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
def load_word2vec_model(model_path):
|
| 22 |
'''
|
| 23 |
Load a word2vec model from a file
|
|
@@ -120,15 +138,31 @@ def get_cosine_similarity_one_word(word, time_slice1, time_slice2):
|
|
| 120 |
|
| 121 |
|
| 122 |
|
| 123 |
-
def validate_nearest_neighbours(word, time_slice_model, n):
    '''
    Validate the input of the nearest neighbours function

    Returns False when any field is missing and True otherwise. A field
    counts as missing when it is None, an empty or whitespace-only string,
    or an empty container (list, tuple, ...). Numbers are never treated as
    missing, so n == 0 is still accepted.
    '''
    def is_missing(value):
        if value is None:
            return True
        if isinstance(value, str):
            # Whitespace-only text is as unusable as an empty string
            return not value.strip()
        try:
            return len(value) == 0
        except TypeError:
            # Unsized values such as ints have no notion of "empty"
            return False

    return not any(is_missing(field) for field in (word, time_slice_model, n))
|
| 130 |
|
| 131 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models()):
|
| 133 |
'''
|
| 134 |
Return the nearest neighbours of a word
|
|
@@ -149,6 +183,7 @@ def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models(
|
|
| 149 |
# Iterate over all models
|
| 150 |
for model in models:
|
| 151 |
model_name = model[0]
|
|
|
|
| 152 |
model = model[1]
|
| 153 |
|
| 154 |
# Iterate over all words of the model
|
|
@@ -162,14 +197,14 @@ def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models(
|
|
| 162 |
|
| 163 |
# If the list of nearest neighbours is not full yet, add the current word
|
| 164 |
if len(nearest_neighbours) < n:
|
| 165 |
-
nearest_neighbours.append((word,
|
| 166 |
|
| 167 |
# If the list of nearest neighbours is full, replace the word with the smallest cosine similarity
|
| 168 |
else:
|
| 169 |
smallest_neighbour = min(nearest_neighbours, key=lambda x: x[2])
|
| 170 |
if cosine_similarity_vectors > smallest_neighbour[2]:
|
| 171 |
nearest_neighbours.remove(smallest_neighbour)
|
| 172 |
-
nearest_neighbours.append((word,
|
| 173 |
|
| 174 |
|
| 175 |
return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)
|
|
|
|
| 18 |
return [archaic, classical, early_roman, hellen, late_roman]
|
| 19 |
|
| 20 |
|
| 21 |
+
def load_selected_models(selected_models):
    '''
    Load the word2vec model behind every selected time-slice label.

    Each label is turned into a model name: "Early Roman", "Late Roman" and
    "Hellenistic" become "early_roman", "late_roman" and "hellen", any other
    label is simply lower-cased, and "_cbow" is appended. The path
    models/<model_name>.model is then handed to load_word2vec_model.

    Returns a list of [model_name, loaded_model] pairs in selection order.
    '''
    # Labels whose file stem is not just the lower-cased label
    special_stems = {
        "Early Roman": "early_roman",
        "Late Roman": "late_roman",
        "Hellenistic": "hellen",
    }

    loaded = []
    for label in selected_models:
        model_name = special_stems.get(label, label).lower() + "_cbow"
        loaded.append([model_name, load_word2vec_model(f'models/{model_name}.model')])
    return loaded
|
| 37 |
+
|
| 38 |
+
|
| 39 |
def load_word2vec_model(model_path):
|
| 40 |
'''
|
| 41 |
Load a word2vec model from a file
|
|
|
|
| 138 |
|
| 139 |
|
| 140 |
|
| 141 |
+
def validate_nearest_neighbours(word, time_slice_model, n, models):
    '''
    Validate the input of the nearest neighbours function

    Returns False when any field is missing and True otherwise. A field
    counts as missing when it is None, an empty or whitespace-only string,
    or an empty container (list, tuple, ...). Numbers are never treated as
    missing, so n == 0 is still accepted.
    '''
    def is_missing(value):
        if value is None:
            return True
        if isinstance(value, str):
            # Whitespace-only text is as unusable as an empty string
            return not value.strip()
        try:
            return len(value) == 0
        except TypeError:
            # Unsized values such as ints have no notion of "empty"
            return False

    return not any(is_missing(field) for field in (word, time_slice_model, n, models))
|
| 148 |
|
| 149 |
|
| 150 |
+
def convert_model_to_time_name(model_name):
    '''
    Convert the model name to the time slice name

    For example 'hellen_cbow' becomes 'Hellenistic'. A model name that is
    not one of the five known ones yields None.
    '''
    time_names = {
        'archaic_cbow': 'Archaic',
        'classical_cbow': 'Classical',
        'early_roman_cbow': 'Early Roman',
        'hellen_cbow': 'Hellenistic',
        'late_roman_cbow': 'Late Roman',
    }
    # dict.get returns None for unknown names
    return time_names.get(model_name)
|
| 164 |
+
|
| 165 |
+
|
| 166 |
def get_nearest_neighbours(word, time_slice_model, n=10, models=load_all_models()):
|
| 167 |
'''
|
| 168 |
Return the nearest neighbours of a word
|
|
|
|
| 183 |
# Iterate over all models
|
| 184 |
for model in models:
|
| 185 |
model_name = model[0]
|
| 186 |
+
time_name = convert_model_to_time_name(model_name)
|
| 187 |
model = model[1]
|
| 188 |
|
| 189 |
# Iterate over all words of the model
|
|
|
|
| 197 |
|
| 198 |
# If the list of nearest neighbours is not full yet, add the current word
|
| 199 |
if len(nearest_neighbours) < n:
|
| 200 |
+
nearest_neighbours.append((word, time_name, cosine_similarity_vectors))
|
| 201 |
|
| 202 |
# If the list of nearest neighbours is full, replace the word with the smallest cosine similarity
|
| 203 |
else:
|
| 204 |
smallest_neighbour = min(nearest_neighbours, key=lambda x: x[2])
|
| 205 |
if cosine_similarity_vectors > smallest_neighbour[2]:
|
| 206 |
nearest_neighbours.remove(smallest_neighbour)
|
| 207 |
+
nearest_neighbours.append((word, time_name, cosine_similarity_vectors))
|
| 208 |
|
| 209 |
|
| 210 |
return sorted(nearest_neighbours, key=lambda x: x[2], reverse=True)
|