jorgemarcc's picture
Update app.py
3fffb69 verified
raw
history blame
4.33 kB
# -*- coding: utf-8 -*-
"""
[Martinez-Gil2024] Augmenting the Interpretability of GraphCodeBERT for Code Similarity Tasks, arXiv preprint arXiv:2410.05275, 2024
@author: Jorge Martinez-Gil
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from transformers import RobertaTokenizer, RobertaModel
import torch
import gradio as gr
from io import BytesIO
from PIL import Image
# Load the GraphCodeBERT tokenizer and encoder once, at import time; every call
# to get_token_embeddings below reuses these module-level objects.
_GRAPHCODEBERT_CHECKPOINT = "microsoft/graphcodebert-base"
tokenizer = RobertaTokenizer.from_pretrained(_GRAPHCODEBERT_CHECKPOINT)
model = RobertaModel.from_pretrained(_GRAPHCODEBERT_CHECKPOINT)

# Source code of classical sorting algorithms, keyed by display name.
# The keys populate the Gradio dropdowns and the plot legend; the values are
# tokenized verbatim by GraphCodeBERT, so each snippet is kept as real,
# correctly indented Python (every one of them compiles and sorts a list).
sorting_algorithms = {
    "Bubble_Sort": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
    return arr
""",
    "Selection_Sort": """
def selection_sort(arr):
    for i in range(len(arr)):
        min_idx = i
        for j in range(i+1, len(arr)):
            if arr[j] < arr[min_idx]:
                min_idx = j
        arr[i], arr[min_idx] = arr[min_idx], arr[i]
    return arr
""",
    "Insertion_Sort": """
def insertion_sort(arr):
    for i in range(1, len(arr)):
        key = arr[i]
        j = i-1
        while j >= 0 and key < arr[j]:
            arr[j + 1] = arr[j]
            j -= 1
        arr[j + 1] = key
    return arr
""",
    "Merge_Sort": """
def merge_sort(arr):
    if len(arr) > 1:
        mid = len(arr) // 2
        L = arr[:mid]
        R = arr[mid:]
        merge_sort(L)
        merge_sort(R)
        i = j = k = 0
        while i < len(L) and j < len(R):
            if L[i] < R[j]:
                arr[k] = L[i]
                i += 1
            else:
                arr[k] = R[j]
                j += 1
            k += 1
        while i < len(L):
            arr[k] = L[i]
            i += 1
            k += 1
        while j < len(R):
            arr[k] = R[j]
            j += 1
            k += 1
    return arr
""",
    "Quick_Sort": """
def partition(arr, low, high):
    i = (low - 1)
    pivot = arr[high]
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]
    arr[i+1], arr[high] = arr[high], arr[i+1]
    return (i + 1)
def quick_sort(arr, low, high):
    if low < high:
        pi = partition(arr, low, high)
        quick_sort(arr, low, pi - 1)
        quick_sort(arr, pi + 1, high)
    return arr
""",
}

def get_token_embeddings(code):
    """Run *code* through GraphCodeBERT and return its per-token vectors.

    The snippet is tokenized (truncated to at most 512 tokens) as a batch of
    one and passed through the encoder with gradient tracking disabled.

    Returns:
        ``(embeddings, tokens)``: ``embeddings`` is a NumPy array holding the
        encoder's last hidden state with the batch axis removed, and ``tokens``
        is the list of token strings for the same input ids.
    """
    encoded = tokenizer(
        code,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )
    token_strings = tokenizer.convert_ids_to_tokens(encoded["input_ids"].squeeze())

    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state
    vectors = hidden.squeeze(0).cpu().numpy()

    return vectors, token_strings

def compare_algorithms(algo1_name, algo2_name):
    """Plot two snippets' token embeddings in one shared 2-D PCA projection.

    Both snippets are embedded with ``get_token_embeddings``. Their token
    vectors are stacked and a single PCA is fitted on the combined set, so the
    two point clouds are drawn on the same axes and can be compared directly.

    Args:
        algo1_name: Key into ``sorting_algorithms``; its tokens are drawn red.
        algo2_name: Key into ``sorting_algorithms``; its tokens are drawn blue.

    Returns:
        A PIL image (decoded from an in-memory PNG) of the scatter plot.

    Raises:
        KeyError: if either name is not a key of ``sorting_algorithms``.
    """
    # Local import: build a standalone Figure instead of using pyplot, whose
    # global "current figure" state (and bare plt.close()) is unsafe if the
    # web server runs handlers concurrently. Matplotlib recommends the Figure
    # API for web-server use.
    from matplotlib.figure import Figure

    code1 = sorting_algorithms[algo1_name]
    code2 = sorting_algorithms[algo2_name]
    emb1, _ = get_token_embeddings(code1)
    emb2, _ = get_token_embeddings(code2)

    # Fit one PCA over both token sets so the projections share a basis.
    combined = np.concatenate([emb1, emb2], axis=0)
    coords = PCA(n_components=2).fit_transform(combined)
    split = len(emb1)  # rows [0, split) come from the first snippet

    fig = Figure(figsize=(6, 5), dpi=150)
    ax = fig.subplots()
    ax.scatter(coords[:split, 0], coords[:split, 1], color='red', label=algo1_name, s=20)
    ax.scatter(coords[split:, 0], coords[split:, 1], color='blue', label=algo2_name, s=20)
    ax.legend()
    # The numeric PCA coordinates are not meaningful to the viewer; hide ticks.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    image = Image.open(buf)
    image.load()  # decode now instead of lazily after this function returns
    return image

# Gradio interface
_algorithm_names = list(sorting_algorithms.keys())

interface = gr.Interface(
    fn=compare_algorithms,
    inputs=[
        # Preselect two different algorithms. Without an explicit value, some
        # Gradio versions submit None for an untouched dropdown (KeyError in
        # compare_algorithms), and others default both to the same entry.
        gr.Dropdown(choices=_algorithm_names, value=_algorithm_names[0], label="Algorithm 1"),
        gr.Dropdown(choices=_algorithm_names, value=_algorithm_names[1], label="Algorithm 2"),
    ],
    outputs=gr.Image(type="pil", label="Token Embedding PCA"),
    title="GraphCodeBERT Token Embedding Comparison",
    description="Visual comparison of token-level embeddings from GraphCodeBERT for classical sorting algorithms."
)

if __name__ == "__main__":
    # Start the Gradio app only when this file is run as a script.
    interface.launch()