File size: 4,409 Bytes
3fffb69 3d716d5 3fffb69 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
# -*- coding: utf-8 -*-
"""
[Martinez-Gil2025] Martinez-Gil, J. (2025).
Augmenting the Interpretability of GraphCodeBERT for Code Similarity Tasks.
International Journal of Software Engineering and Knowledge Engineering, 35(05), 657-678.
@author: Jorge Martinez-Gil
"""
import numpy as np
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
from transformers import RobertaTokenizer, RobertaModel
import torch
import gradio as gr
from io import BytesIO
from PIL import Image
# Load the GraphCodeBERT tokenizer and encoder once, at import time; both
# module-level objects are reused by get_token_embeddings() on every call.
# The checkpoint id lives in one constant so the tokenizer and the model can
# never be loaded from two different checkpoints by accident.
MODEL_NAME = "microsoft/graphcodebert-base"
tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)
model = RobertaModel.from_pretrained(MODEL_NAME)
# Classical sorting algorithms, stored as Python source text keyed by the
# name shown in the UI.  The snippets are never executed by this app; they are
# only tokenized and embedded.  They are kept as valid, conventionally
# indented Python so that the model sees well-formed code.
sorting_algorithms = {
    "Bubble_Sort": """
def bubble_sort(arr):
    n = len(arr)
    for i in range(n):
        for j in range(0, n-i-1):
            if arr[j] > arr[j+1]:
                arr[j], arr[j+1] = arr[j+1], arr[j]
    return arr
""",
    "Selection_Sort": """
def selection_sort(arr):
    for i in range(len(arr)):
        min_idx = i
        for j in range(i+1, len(arr)):
            if arr[j] < arr[min_idx]:
                min_idx = j
        arr[i], arr[min_idx] = arr[min_idx], arr[i]
    return arr
""",
    "Insertion_Sort": """
def insertion_sort(arr):
    for i in range(1, len(arr)):
        key = arr[i]
        j = i-1
        while j >= 0 and key < arr[j]:
            arr[j + 1] = arr[j]
            j -= 1
        arr[j + 1] = key
    return arr
""",
    "Merge_Sort": """
def merge_sort(arr):
    if len(arr) > 1:
        mid = len(arr) // 2
        L = arr[:mid]
        R = arr[mid:]
        merge_sort(L)
        merge_sort(R)
        i = j = k = 0
        while i < len(L) and j < len(R):
            if L[i] < R[j]:
                arr[k] = L[i]
                i += 1
            else:
                arr[k] = R[j]
                j += 1
            k += 1
        while i < len(L):
            arr[k] = L[i]
            i += 1
            k += 1
        while j < len(R):
            arr[k] = R[j]
            j += 1
            k += 1
    return arr
""",
    "Quick_Sort": """
def partition(arr, low, high):
    i = (low - 1)
    pivot = arr[high]
    for j in range(low, high):
        if arr[j] <= pivot:
            i += 1
            arr[i], arr[j] = arr[j], arr[i]
    arr[i+1], arr[high] = arr[high], arr[i+1]
    return (i + 1)
def quick_sort(arr, low, high):
    if low < high:
        pi = partition(arr, low, high)
        quick_sort(arr, low, pi - 1)
        quick_sort(arr, pi + 1, high)
    return arr
""",
}
# Get token embeddings for a code snippet
def get_token_embeddings(code):
    """Return GraphCodeBERT's last-layer embedding for every token of ``code``.

    Args:
        code: Source code snippet as a string.  The encoded input is
            truncated to at most 512 tokens.

    Returns:
        A ``(token_embeddings, tokens)`` tuple.  ``token_embeddings`` is a
        NumPy array taken from the model's ``last_hidden_state`` for the
        single encoded snippet; ``tokens`` is the list of token strings
        produced by ``tokenizer.convert_ids_to_tokens`` for the same input
        ids, so row ``i`` of the array belongs to ``tokens[i]``.
    """
    inputs = tokenizer(code, return_tensors="pt", max_length=512, truncation=True, padding=True)
    with torch.no_grad():  # inference only; no gradients needed
        outputs = model(**inputs)
    # Exactly one snippet is encoded, so the batch axis has length 1.  Select
    # it by index for both tensors instead of mixing squeeze(0) with an
    # axis-less squeeze(), which would also drop a length-1 sequence axis and
    # break the row/token correspondence.
    token_embeddings = outputs.last_hidden_state[0].cpu().numpy()
    tokens = tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist())
    return token_embeddings, tokens
# Compare two algorithms and return PCA scatter plot
def compare_algorithms(algo1_name, algo2_name):
    """Plot a joint 2-D PCA projection of two snippets' token embeddings.

    Args:
        algo1_name: Key into ``sorting_algorithms`` for the first snippet;
            its tokens are drawn in red.
        algo2_name: Key into ``sorting_algorithms`` for the second snippet;
            its tokens are drawn in blue.

    Returns:
        A PIL image holding the PNG-encoded scatter plot.

    Raises:
        gr.Error: If either name is not a key of ``sorting_algorithms``
            (for example when a dropdown was left empty), so the UI shows a
            readable message instead of a bare ``KeyError``.
    """
    for name in (algo1_name, algo2_name):
        if name not in sorting_algorithms:
            raise gr.Error(f"Unknown or missing algorithm selection: {name!r}")

    emb1, _ = get_token_embeddings(sorting_algorithms[algo1_name])
    emb2, _ = get_token_embeddings(sorting_algorithms[algo2_name])

    # Fit a single PCA on both snippets together so that the two point clouds
    # share the same axes and are directly comparable.
    combined = np.concatenate([emb1, emb2], axis=0)
    coords = PCA(n_components=2).fit_transform(combined)
    split = len(emb1)  # rows [0, split) belong to the first snippet

    # Draw on an explicit Figure/Axes rather than pyplot's implicit "current
    # figure", and always close that specific figure, so a failure while
    # plotting or saving cannot leak it and the result does not depend on
    # which figure happens to be current.
    fig, ax = plt.subplots(figsize=(6, 5), dpi=150)
    buf = BytesIO()
    try:
        ax.scatter(coords[:split, 0], coords[:split, 1], color='red', label=algo1_name, s=20)
        ax.scatter(coords[split:, 0], coords[split:, 1], color='blue', label=algo2_name, s=20)
        ax.legend()
        # PCA axes carry no interpretable unit, so hide ticks and grid.
        ax.set_xticks([])
        ax.set_yticks([])
        ax.grid(False)
        fig.savefig(buf, format='png', bbox_inches='tight')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# Gradio UI: two dropdowns pick the snippets to compare and the result of
# compare_algorithms() is displayed as an image.
interface = gr.Interface(
    title="GraphCodeBERT Token Embedding Comparison",
    description="Visual comparison of token-level embeddings from GraphCodeBERT for classical sorting algorithms.",
    fn=compare_algorithms,
    inputs=[
        gr.Dropdown(label="Algorithm 1", choices=list(sorting_algorithms)),
        gr.Dropdown(label="Algorithm 2", choices=list(sorting_algorithms)),
    ],
    outputs=gr.Image(label="Token Embedding PCA", type="pil"),
)

# Launch the app only when this file is run as a script, not when imported.
if __name__ == "__main__":
    interface.launch()
|