|
|
|
""" |
|
Martinez-Gil, J. (2025). Augmenting the Interpretability of GraphCodeBERT for Code Similarity Tasks. |
|
International Journal of Software Engineering and Knowledge Engineering, 35(05), 657–678. |
|
""" |
|
|
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
from sklearn.decomposition import PCA |
|
from transformers import RobertaTokenizer, RobertaModel |
|
import torch |
|
import gradio as gr |
|
from io import BytesIO |
|
from PIL import Image |
|
|
|
|
|
# GraphCodeBERT checkpoint shared by the tokenizer and the encoder; downloaded
# files are stored under the given cache directory.
_CHECKPOINT = "microsoft/graphcodebert-base"
_CACHE_DIR = "models/"

tokenizer = RobertaTokenizer.from_pretrained(_CHECKPOINT, cache_dir=_CACHE_DIR)
model = RobertaModel.from_pretrained(_CHECKPOINT, cache_dir=_CACHE_DIR)
|
|
|
|
|
# Reference sorting-algorithm snippets, keyed by display name.
# The values are plain source strings: they are never executed by this script,
# only tokenized and embedded by GraphCodeBERT (see get_token_embeddings), so
# every character inside them -- including whitespace -- is model input and
# affects the resulting embeddings.
# The keys are also used as the dropdown choices and the plot legend labels.
sorting_algorithms = {
    "Bubble_Sort": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
    return arr
""",

    "Selection_Sort": """
def selection_sort(arr):
    for i in range(len(arr)):
        min_idx = i
        for j in range(i+1, len(arr)):
            if arr[j] < arr[min_idx]:
                min_idx = j
        arr[i], arr[min_idx] = arr[min_idx], arr[i]
    return arr
""",

    "Insertion_Sort": """
def insertion_sort(arr):
    for i in range(1, len(arr)):
        key = arr[i]
        j = i-1
        while j >= 0 and key < arr[j]:
            arr[j + 1] = arr[j]
            j -= 1
        arr[j + 1] = key
    return arr
""",

    "Merge_Sort": """
def merge_sort(arr):
    if len(arr) > 1:
        mid = len(arr) // 2
        L = arr[:mid]
        R = arr[mid:]
        merge_sort(L)
        merge_sort(R)
        i = j = k = 0
        while i < len(L) and j < len(R):
            if L[i] < R[j]:
                arr[k] = L[i]
                i += 1
            else:
                arr[k] = R[j]
                j += 1
            k += 1
        while i < len(L):
            arr[k] = L[i]
            i += 1
            k += 1
        while j < len(R):
            arr[k] = R[j]
            j += 1
            k += 1
    return arr
""",

    # This snippet contains two functions: the partition helper and quick_sort.
    "Quick_Sort": """
def partition(arr, low, high):
    i = (low - 1)
    pivot = arr[high]
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]
    arr[i+1], arr[high] = arr[high], arr[i+1]
    return (i + 1)
def quick_sort(arr, low, high):
    if low < high:
        pi = partition(arr, low, high)
        quick_sort(arr, low, pi - 1)
        quick_sort(arr, pi + 1, high)
    return arr
"""
}
|
|
|
|
|
def get_token_embeddings(code):
    """Encode ``code`` with GraphCodeBERT and return one vector per token.

    Args:
        code: Source text to encode. Anything beyond 512 tokenizer tokens
            (special tokens included) is truncated.

    Returns:
        A ``(token_embeddings, tokens)`` pair. ``token_embeddings`` is a NumPy
        array of the model's last hidden state with the batch dimension
        removed, so it has one row per token. ``tokens`` lists the token
        strings in the same order as the rows.
    """
    encoded = tokenizer(
        code,
        return_tensors="pt",
        max_length=512,
        truncation=True,
        padding=True,
    )

    # Recover the token strings that line up with the embedding rows.
    token_ids = encoded["input_ids"].squeeze()
    tokens = tokenizer.convert_ids_to_tokens(token_ids)

    # Pure inference: gradients are never needed.
    with torch.no_grad():
        hidden = model(**encoded).last_hidden_state

    return hidden.squeeze(0).cpu().numpy(), tokens
|
|
|
|
|
def compare_algorithms(algo1_name, algo2_name):
    """Render a 2-D PCA scatter plot of two algorithms' token embeddings.

    Both snippets are embedded token by token with GraphCodeBERT. The two sets
    of vectors are stacked and projected jointly onto their first two
    principal components. The first algorithm's tokens are drawn in red and
    the second's in blue.

    Args:
        algo1_name: Key into ``sorting_algorithms`` for the first snippet.
        algo2_name: Key into ``sorting_algorithms`` for the second snippet.

    Returns:
        A ``PIL.Image.Image`` holding the rendered PNG.

    Raises:
        KeyError: If either name is not a key of ``sorting_algorithms``.
    """
    # Build the figure with the object-oriented API rather than pyplot. A bare
    # Figure is not registered with pyplot's global figure manager and never
    # starts a GUI backend. That matters because Gradio calls this function
    # from a worker thread. It also means nothing leaks if rendering raises,
    # since no plt.close() is needed.
    from matplotlib.figure import Figure

    # Look up both snippets first so a bad name fails before any model work.
    code1 = sorting_algorithms[algo1_name]
    code2 = sorting_algorithms[algo2_name]

    emb1, _ = get_token_embeddings(code1)
    emb2, _ = get_token_embeddings(code2)

    # Fit PCA on both sets together so the two point clouds share one
    # coordinate system and are directly comparable.
    combined = np.concatenate([emb1, emb2], axis=0)
    coords = PCA(n_components=2).fit_transform(combined)
    n_first = emb1.shape[0]  # rows [0, n_first) come from algo1, the rest from algo2

    fig = Figure(figsize=(6, 5), dpi=150)
    ax = fig.subplots()
    ax.scatter(coords[:n_first, 0], coords[:n_first, 1], color='red', label=algo1_name, s=20)
    ax.scatter(coords[n_first:, 0], coords[n_first:, 1], color='blue', label=algo2_name, s=20)
    ax.legend()
    # PCA axes have no intrinsic units, so hide ticks and grid.
    ax.set_xticks([])
    ax.set_yticks([])
    ax.grid(False)

    buf = BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight')
    buf.seek(0)
    return Image.open(buf)
|
|
|
|
|
# Dropdown choices, in the dict's insertion order. Both dropdowns get a
# pre-selected default (the first and last entries, distinct whenever there
# are at least two). Otherwise submitting without choosing would pass None to
# compare_algorithms, whose sorting_algorithms[None] lookup raises KeyError.
_algorithm_names = list(sorting_algorithms)

interface = gr.Interface(
    fn=compare_algorithms,
    inputs=[
        gr.Dropdown(choices=_algorithm_names, value=_algorithm_names[0], label="Algorithm 1"),
        gr.Dropdown(choices=_algorithm_names, value=_algorithm_names[-1], label="Algorithm 2"),
    ],
    outputs=gr.Image(type="pil", label="Token Embedding PCA"),
    title="GraphCodeBERT Token Embedding Comparison",
    description="Visual comparison of token-level embeddings from GraphCodeBERT for classical sorting algorithms.",
)

if __name__ == "__main__":
    # Start the Gradio app only when run as a script, not on import.
    interface.launch()
|
|
|
|
|
|