jorgemarcc committed on
Commit
3fffb69
·
verified ·
1 Parent(s): 32898cf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +155 -211
app.py CHANGED
@@ -1,211 +1,155 @@
1
- # -*- coding: utf-8 -*-
2
- """
3
- [Martinez-Gil2024] Augmenting the Interpretability of GraphCodeBERT for Code Similarity Tasks, arXiv preprint arXiv:2410.05275, 2024
4
-
5
- @author: Jorge Martinez-Gil
6
- """
7
-
8
- import os
9
- from transformers import RobertaTokenizer, RobertaModel
10
- from sklearn.decomposition import PCA
11
- import matplotlib.pyplot as plt
12
- import numpy as np
13
- import itertools
14
-
15
- # Initialize GraphCodeBERT
16
- tokenizer = RobertaTokenizer.from_pretrained("microsoft/graphcodebert-base")
17
- model = RobertaModel.from_pretrained("microsoft/graphcodebert-base")
18
-
19
- # Define the classical sorting algorithms
20
- sorting_algorithms = {
21
- "Bubble_Sort": """
22
- def bubble_sort(arr):
23
- n = len(arr)
24
- for i in range(n):
25
- for j in range(0, n-i-1):
26
- if arr[j] > arr[j+1]:
27
- arr[j], arr[j+1] = arr[j+1], arr[j]
28
- return arr
29
- """,
30
-
31
- "Selection_Sort": """
32
- def selection_sort(arr):
33
- for i in range(len(arr)):
34
- min_idx = i
35
- for j in range(i+1, len(arr)):
36
- if arr[j] < arr[min_idx]:
37
- min_idx = j
38
- arr[i], arr[min_idx] = arr[min_idx], arr[i]
39
- return arr
40
- """,
41
-
42
- "Insertion_Sort": """
43
- def insertion_sort(arr):
44
- for i in range(1, len(arr)):
45
- key = arr[i]
46
- j = i-1
47
- while j >=0 and key < arr[j]:
48
- arr[j + 1] = arr[j]
49
- j -= 1
50
- arr[j + 1] = key
51
- return arr
52
- """,
53
-
54
- "Merge_Sort": """
55
- def merge_sort(arr):
56
- if len(arr) > 1:
57
- mid = len(arr)//2
58
- L = arr[:mid]
59
- R = arr[mid:]
60
-
61
- merge_sort(L)
62
- merge_sort(R)
63
-
64
- i = j = k = 0
65
-
66
- while i < len(L) and j < len(R):
67
- if L[i] < R[j]:
68
- arr[k] = L[i]
69
- i += 1
70
- else:
71
- arr[k] = R[j]
72
- j += 1
73
- k += 1
74
-
75
- while i < len(L):
76
- arr[k] = L[i]
77
- i += 1
78
- k += 1
79
-
80
- while j < len(R):
81
- arr[k] = R[j]
82
- j += 1
83
- k += 1
84
- return arr
85
- """,
86
-
87
- "Quick_Sort": """
88
- def partition(arr, low, high):
89
- i = (low-1)
90
- pivot = arr[high]
91
-
92
- for j in range(low, high):
93
- if arr[j] <= pivot:
94
- i = i+1
95
- arr[i], arr[j] = arr[j], arr[i]
96
- arr[i+1], arr[high] = arr[high], arr[i+1]
97
- return (i+1)
98
-
99
- def quick_sort(arr, low, high):
100
- if low < high:
101
- pi = partition(arr, low, high)
102
- quick_sort(arr, low, pi-1)
103
- quick_sort(arr, pi+1, high)
104
- return arr
105
- """
106
- }
107
-
108
- # Function to get token embeddings for a code snippet
109
- def get_token_embeddings(code):
110
- inputs = tokenizer(code, return_tensors="pt", max_length=512, truncation=True, padding=True)
111
- outputs = model(**inputs)
112
- token_embeddings = outputs.last_hidden_state.squeeze().detach().numpy()
113
- tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze())
114
- return token_embeddings, tokens
115
-
116
- # Directory to save images
117
- output_dir = "pca_pairwise_comparisons"
118
- os.makedirs(output_dir, exist_ok=True)
119
-
120
- # Generate all possible pairs of sorting algorithms
121
- algorithm_pairs = list(itertools.combinations(sorting_algorithms.keys(), 2))
122
-
123
- # Loop over each pair and generate the visualizations
124
- for (algo1_name, algo2_name) in algorithm_pairs:
125
- algo1_code = sorting_algorithms[algo1_name]
126
- algo2_code = sorting_algorithms[algo2_name]
127
-
128
- # Get token embeddings for both algorithms
129
- algo1_embeddings, algo1_tokens = get_token_embeddings(algo1_code)
130
- algo2_embeddings, algo2_tokens = get_token_embeddings(algo2_code)
131
-
132
- # Combine embeddings
133
- all_embeddings = np.concatenate((algo1_embeddings, algo2_embeddings), axis=0)
134
-
135
- # Reduce dimensionality to 2D using PCA
136
- pca = PCA(n_components=2)
137
- embeddings_2d = pca.fit_transform(all_embeddings)
138
-
139
- # Plotting the token embeddings in 2D
140
- plt.figure(figsize=(10, 8), dpi=300)
141
-
142
- # Scatter plot for the first algorithm tokens
143
- plt.scatter(embeddings_2d[:len(algo1_tokens), 0],
144
- embeddings_2d[:len(algo1_tokens), 1],
145
- color='red', s=50, label=algo1_name, alpha=0.8)
146
-
147
- # Scatter plot for the second algorithm tokens
148
- plt.scatter(embeddings_2d[len(algo1_tokens):, 0],
149
- embeddings_2d[len(algo1_tokens):, 1],
150
- color='blue', s=50, label=algo2_name, alpha=0.8)
151
-
152
- # Make the visualization more professional
153
- plt.xticks([])
154
- plt.yticks([])
155
- plt.xlabel('')
156
- plt.ylabel('')
157
- plt.grid(False)
158
- plt.legend()
159
-
160
- # Save the figure as a high-quality PNG file
161
- output_file = os.path.join(output_dir, f"{algo1_name}_vs_{algo2_name}_tokens_2d_pca.png")
162
- plt.savefig(output_file, format='png', dpi=300, bbox_inches='tight')
163
-
164
- # Show the plot
165
- plt.close()
166
-
167
- print("All pairwise comparison images have been generated.")
168
-
169
-
170
- import gradio as gr
171
- from io import BytesIO
172
- from PIL import Image
173
-
174
- def compare_algorithms(algo1_name, algo2_name):
175
- algo1_code = sorting_algorithms[algo1_name]
176
- algo2_code = sorting_algorithms[algo2_name]
177
-
178
- # Get token embeddings
179
- algo1_embeddings, algo1_tokens = get_token_embeddings(algo1_code)
180
- algo2_embeddings, algo2_tokens = get_token_embeddings(algo2_code)
181
-
182
- # Combine and reduce
183
- all_embeddings = np.concatenate((algo1_embeddings, algo2_embeddings), axis=0)
184
- pca = PCA(n_components=2)
185
- embeddings_2d = pca.fit_transform(all_embeddings)
186
-
187
- # Plot
188
- plt.figure(figsize=(6, 5), dpi=150)
189
- plt.scatter(embeddings_2d[:len(algo1_tokens), 0], embeddings_2d[:len(algo1_tokens), 1], color='red', s=20, label=algo1_name)
190
- plt.scatter(embeddings_2d[len(algo1_tokens):, 0], embeddings_2d[len(algo1_tokens):, 1], color='blue', s=20, label=algo2_name)
191
- plt.xticks([]); plt.yticks([]); plt.grid(False); plt.legend()
192
-
193
- # Save to BytesIO
194
- buf = BytesIO()
195
- plt.savefig(buf, format='png', bbox_inches='tight')
196
- plt.close()
197
- buf.seek(0)
198
- return Image.open(buf)
199
-
200
- interface = gr.Interface(
201
- fn=compare_algorithms,
202
- inputs=[
203
- gr.Dropdown(choices=list(sorting_algorithms.keys()), label="Algorithm 1"),
204
- gr.Dropdown(choices=list(sorting_algorithms.keys()), label="Algorithm 2")
205
- ],
206
- outputs=gr.Image(type="pil", label="Token PCA Plot"),
207
- title="Code Similarity Visualization with GraphCodeBERT"
208
- )
209
-
210
- if __name__ == "__main__":
211
- interface.launch()
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ [Martinez-Gil2024] Augmenting the Interpretability of GraphCodeBERT for Code Similarity Tasks, arXiv preprint arXiv:2410.05275, 2024
4
+
5
+ @author: Jorge Martinez-Gil
6
+ """
7
+
8
+ import numpy as np
9
+ import matplotlib.pyplot as plt
10
+ from sklearn.decomposition import PCA
11
+ from transformers import RobertaTokenizer, RobertaModel
12
+ import torch
13
+ import gradio as gr
14
+ from io import BytesIO
15
+ from PIL import Image
16
+
17
+ # Load GraphCodeBERT model
18
+ tokenizer = RobertaTokenizer.from_pretrained("microsoft/graphcodebert-base")
19
+ model = RobertaModel.from_pretrained("microsoft/graphcodebert-base")
20
+
21
+ # Define sorting algorithms as strings
22
+ sorting_algorithms = {
23
+ "Bubble_Sort": """
24
+ def bubble_sort(arr):
25
+ n = len(arr)
26
+ for i in range(n):
27
+ for j in range(0, n-i-1):
28
+ if arr[j] > arr[j+1]:
29
+ arr[j], arr[j+1] = arr[j+1], arr[j]
30
+ return arr
31
+ """,
32
+
33
+ "Selection_Sort": """
34
+ def selection_sort(arr):
35
+ for i in range(len(arr)):
36
+ min_idx = i
37
+ for j in range(i+1, len(arr)):
38
+ if arr[j] < arr[min_idx]:
39
+ min_idx = j
40
+ arr[i], arr[min_idx] = arr[min_idx], arr[i]
41
+ return arr
42
+ """,
43
+
44
+ "Insertion_Sort": """
45
+ def insertion_sort(arr):
46
+ for i in range(1, len(arr)):
47
+ key = arr[i]
48
+ j = i-1
49
+ while j >= 0 and key < arr[j]:
50
+ arr[j + 1] = arr[j]
51
+ j -= 1
52
+ arr[j + 1] = key
53
+ return arr
54
+ """,
55
+
56
+ "Merge_Sort": """
57
+ def merge_sort(arr):
58
+ if len(arr) > 1:
59
+ mid = len(arr) // 2
60
+ L = arr[:mid]
61
+ R = arr[mid:]
62
+
63
+ merge_sort(L)
64
+ merge_sort(R)
65
+
66
+ i = j = k = 0
67
+ while i < len(L) and j < len(R):
68
+ if L[i] < R[j]:
69
+ arr[k] = L[i]
70
+ i += 1
71
+ else:
72
+ arr[k] = R[j]
73
+ j += 1
74
+ k += 1
75
+
76
+ while i < len(L):
77
+ arr[k] = L[i]
78
+ i += 1
79
+ k += 1
80
+
81
+ while j < len(R):
82
+ arr[k] = R[j]
83
+ j += 1
84
+ k += 1
85
+ return arr
86
+ """,
87
+
88
+ "Quick_Sort": """
89
+ def partition(arr, low, high):
90
+ i = (low - 1)
91
+ pivot = arr[high]
92
+ for j in range(low, high):
93
+ if arr[j] <= pivot:
94
+ i += 1
95
+ arr[i], arr[j] = arr[j], arr[i]
96
+ arr[i+1], arr[high] = arr[high], arr[i+1]
97
+ return (i + 1)
98
+
99
+ def quick_sort(arr, low, high):
100
+ if low < high:
101
+ pi = partition(arr, low, high)
102
+ quick_sort(arr, low, pi - 1)
103
+ quick_sort(arr, pi + 1, high)
104
+ return arr
105
+ """
106
+ }
107
+
108
+ # Get token embeddings for a code snippet
109
def get_token_embeddings(code):
    """Return per-token GraphCodeBERT embeddings and the matching tokens.

    Args:
        code: Source code snippet to embed, given as a single string.

    Returns:
        A ``(token_embeddings, tokens)`` pair. ``token_embeddings`` is the
        model's ``last_hidden_state`` for the single input sequence as a
        NumPy array (one row per token position); ``tokens`` is the list of
        token strings for the same positions, in the same order.
    """
    # Inputs longer than 512 tokens are truncated by the tokenizer.
    inputs = tokenizer(code, return_tensors="pt", max_length=512, truncation=True, padding=True)

    # Inference only: no autograd graph is needed, which also makes the
    # resulting tensor directly convertible to NumPy.
    with torch.no_grad():
        outputs = model(**inputs)

    # Exactly one sequence is encoded, so select batch element 0 explicitly
    # for both the embeddings and the ids; a bare squeeze() would also drop
    # any other size-1 dimension and desynchronise the two results.
    token_embeddings = outputs.last_hidden_state[0].cpu().numpy()
    token_ids = inputs["input_ids"][0].tolist()
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    return token_embeddings, tokens
116
+
117
+ # Compare two algorithms and return PCA scatter plot
118
def compare_algorithms(algo1_name, algo2_name):
    """Render a 2-D PCA scatter plot of two algorithms' token embeddings.

    Both snippets are embedded with ``get_token_embeddings``; the two
    embedding matrices are stacked and a single PCA is fitted on the stack,
    so both point clouds are projected onto the same pair of axes.

    Args:
        algo1_name: Key into ``sorting_algorithms``; plotted in red.
        algo2_name: Key into ``sorting_algorithms``; plotted in blue.

    Returns:
        A PIL image holding the scatter plot rendered as PNG.

    Raises:
        gr.Error: If either name is not a key of ``sorting_algorithms``
            (for example when a dropdown was left unselected).
    """
    # Fail with a readable UI message instead of a bare KeyError.
    for name in (algo1_name, algo2_name):
        if name not in sorting_algorithms:
            raise gr.Error(f"Unknown algorithm {name!r}: please select both algorithms.")

    emb1, _ = get_token_embeddings(sorting_algorithms[algo1_name])
    emb2, _ = get_token_embeddings(sorting_algorithms[algo2_name])

    combined = np.concatenate([emb1, emb2], axis=0)
    coords = PCA(n_components=2).fit_transform(combined)

    # Rows [0, split) of the stacked projection belong to the first snippet.
    split = len(emb1)

    # Draw on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", and always close it so a failure while plotting or
    # saving cannot leave figures accumulating in a long-running server.
    fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
    try:
        ax.scatter(coords[:split, 0], coords[:split, 1], color='red', label=algo1_name, s=20)
        ax.scatter(coords[split:, 0], coords[split:, 1], color='blue', label=algo2_name, s=20)
        ax.legend()
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)

        buf = BytesIO()
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)

    buf.seek(0)
    return Image.open(buf)
140
+
141
+ # Gradio interface
142
# Gradio front end: one dropdown per algorithm slot, both offering every
# key of ``sorting_algorithms``; the result is shown as a PIL image.
interface = gr.Interface(
    fn=compare_algorithms,
    inputs=[
        gr.Dropdown(choices=list(sorting_algorithms), label=slot_label)
        for slot_label in ("Algorithm 1", "Algorithm 2")
    ],
    outputs=gr.Image(type="pil", label="Token Embedding PCA"),
    title="GraphCodeBERT Token Embedding Comparison",
    description="Visual comparison of token-level embeddings from GraphCodeBERT for classical sorting algorithms.",
)
152
+
153
+ if __name__ == "__main__":
154
+ interface.launch()
155
+