# -*- coding: utf-8 -*-
"""
[Martinez-Gil2024] Augmenting the Interpretability of GraphCodeBERT for Code
Similarity Tasks, arXiv preprint arXiv:2410.05275, 2024

@author: Jorge Martinez-Gil
"""

import numpy as np
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so figures render inside Gradio's worker threads
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from transformers import RobertaTokenizer, RobertaModel
import torch
import gradio as gr
from io import BytesIO
from PIL import Image

# Load GraphCodeBERT model
tokenizer = RobertaTokenizer.from_pretrained("microsoft/graphcodebert-base")
model = RobertaModel.from_pretrained("microsoft/graphcodebert-base")

# Define sorting algorithms as strings
sorting_algorithms = {
    "Bubble_Sort": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
    return arr
""",
    "Selection_Sort": """
def selection_sort(arr):
    for i in range(len(arr)):
        min_idx = i
        for j in range(i+1, len(arr)):
            if arr[j] < arr[min_idx]:
                min_idx = j
        arr[i], arr[min_idx] = arr[min_idx], arr[i]
    return arr
""",
    "Insertion_Sort": """
def insertion_sort(arr):
    for i in range(1, len(arr)):
        key = arr[i]
        j = i-1
        while j >= 0 and key < arr[j]:
            arr[j + 1] = arr[j]
            j -= 1
        arr[j + 1] = key
    return arr
""",
    "Merge_Sort": """
def merge_sort(arr):
    if len(arr) > 1:
        mid = len(arr) // 2
        L = arr[:mid]
        R = arr[mid:]
        merge_sort(L)
        merge_sort(R)
        i = j = k = 0
        while i < len(L) and j < len(R):
            if L[i] < R[j]:
                arr[k] = L[i]
                i += 1
            else:
                arr[k] = R[j]
                j += 1
            k += 1
        while i < len(L):
            arr[k] = L[i]
            i += 1
            k += 1
        while j < len(R):
            arr[k] = R[j]
            j += 1
            k += 1
    return arr
""",
    "Quick_Sort": """
def partition(arr, low, high):
    i = (low - 1)
    pivot = arr[high]
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]
    arr[i+1], arr[high] = arr[high], arr[i+1]
    return (i + 1)

def quick_sort(arr, low, high):
    if low < high:
        pi = partition(arr, low, high)
        quick_sort(arr, low, pi - 1)
        quick_sort(arr, pi + 1, high)
    return arr
"""
}

# Get token embeddings for a code snippet
def get_token_embeddings(code):
    inputs = tokenizer(code, return_tensors="pt", max_length=512, truncation=True, padding=True)
    with torch.no_grad():
        outputs = model(**inputs)
    # One vector per token from the last hidden layer; includes the <s> and </s> special tokens
    token_embeddings = outputs.last_hidden_state.squeeze(0).cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'].squeeze())
    return token_embeddings, tokens

# Compare two algorithms and return PCA scatter plot
def compare_algorithms(algo1_name, algo2_name):
    code1 = sorting_algorithms[algo1_name]
    code2 = sorting_algorithms[algo2_name]

    emb1, tokens1 = get_token_embeddings(code1)
    emb2, tokens2 = get_token_embeddings(code2)

    # Fit PCA on both token sets together so the two snippets share one 2-D space
    combined = np.concatenate([emb1, emb2], axis=0)
    pca = PCA(n_components=2)
    coords = pca.fit_transform(combined)

    plt.figure(figsize=(6, 5), dpi=150)
    plt.scatter(coords[:len(tokens1), 0], coords[:len(tokens1), 1],
                color='red', label=algo1_name, s=20)
    plt.scatter(coords[len(tokens1):, 0], coords[len(tokens1):, 1],
                color='blue', label=algo2_name, s=20)
    plt.legend()
    plt.xticks([]); plt.yticks([]); plt.grid(False)

    # Render the figure to an in-memory PNG and return it as a PIL image for Gradio
    buf = BytesIO()
    plt.savefig(buf, format='png', bbox_inches='tight')
    plt.close()
    buf.seek(0)
    return Image.open(buf)
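# --- Illustrative addition, not part of the original script ---
# A minimal sketch of reducing the same token embeddings that the PCA plot
# visualizes to a single scalar similarity, assuming mean pooling over tokens
# followed by cosine similarity. The name snippet_similarity is hypothetical;
# this is just one simple way to turn token-level embeddings into a
# snippet-level code-similarity score.
def snippet_similarity(code1, code2):
    emb1, _ = get_token_embeddings(code1)
    emb2, _ = get_token_embeddings(code2)
    # Mean-pool the token vectors into one vector per snippet
    v1, v2 = emb1.mean(axis=0), emb2.mean(axis=0)
    # Cosine similarity in [-1, 1]; values near 1 indicate similar embeddings
    return float(np.dot(v1, v2) / (np.linalg.norm(v1) * np.linalg.norm(v2)))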
# Gradio interface
interface = gr.Interface(
    fn=compare_algorithms,
    inputs=[
        gr.Dropdown(choices=list(sorting_algorithms.keys()), label="Algorithm 1"),
        gr.Dropdown(choices=list(sorting_algorithms.keys()), label="Algorithm 2")
    ],
    outputs=gr.Image(type="pil", label="Token Embedding PCA"),
    title="GraphCodeBERT Token Embedding Comparison",
    description="Visual comparison of token-level embeddings from GraphCodeBERT for classical sorting algorithms."
)

if __name__ == "__main__":
    interface.launch()
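# Example of driving the comparison programmatically instead of through the UI
# (a sketch; compare_algorithms returns a PIL image that can be saved directly):
#
#   img = compare_algorithms("Bubble_Sort", "Quick_Sort")
#   img.save("bubble_vs_quick.png")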