Jan Mühlnikel
commited on
Commit
·
3d9250a
1
Parent(s):
b188b37
experiment
Browse files- functions/calc_matches.py +16 -4
functions/calc_matches.py
CHANGED
@@ -10,8 +10,8 @@ def calc_matches(filtered_df, project_df, similarity_matrix, top_x):
|
|
10 |
st.write(similarity_matrix.shape)
|
11 |
|
12 |
# Ensure the matrix is in a suitable format for manipulation
|
13 |
-
|
14 |
-
|
15 |
|
16 |
# Get indices from dataframes
|
17 |
filtered_df_indices = filtered_df.index.to_list()
|
@@ -38,10 +38,22 @@ def calc_matches(filtered_df, project_df, similarity_matrix, top_x):
|
|
38 |
# Get the corresponding similarity values
|
39 |
#top_values = match_matrix.data[linear_indices]
|
40 |
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
# Convert flat indices to 2D row and column indices
|
44 |
-
row_indices, col_indices =
|
|
|
|
|
|
|
|
|
|
|
|
|
45 |
|
46 |
# Get the values corresponding to the top k indices
|
47 |
top_values = match_matrix[row_indices, col_indices]
|
|
|
10 |
st.write(similarity_matrix.shape)
|
11 |
|
12 |
# Ensure the matrix is in a suitable format for manipulation
|
13 |
+
if not isinstance(similarity_matrix, csr_matrix):
|
14 |
+
similarity_matrix = csr_matrix(similarity_matrix)
|
15 |
|
16 |
# Get indices from dataframes
|
17 |
filtered_df_indices = filtered_df.index.to_list()
|
|
|
38 |
# Get the corresponding similarity values
|
39 |
#top_values = match_matrix.data[linear_indices]
|
40 |
|
41 |
+
flat_data = match_matrix.data
|
42 |
+
|
43 |
+
# Get the indices that would sort the data array in descending order
|
44 |
+
sorted_indices = np.argsort(flat_data)[::-1]
|
45 |
+
|
46 |
+
# Take the first k indices to get the top k maximum values
|
47 |
+
top_indices = sorted_indices[:top_x]
|
48 |
|
49 |
# Convert flat indices to 2D row and column indices
|
50 |
+
row_indices, col_indices = match_matrix.nonzero()
|
51 |
+
row_indices = row_indices[top_indices]
|
52 |
+
col_indices = col_indices[top_indices]
|
53 |
+
|
54 |
+
# Get the values corresponding to the top k indices
|
55 |
+
top_values = flat_data[top_indices]
|
56 |
+
|
57 |
|
58 |
# Get the values corresponding to the top k indices
|
59 |
top_values = match_matrix[row_indices, col_indices]
|