martynattakit's picture
Update app.py
8a621d3 verified
raw
history blame
3.89 kB
import torch
from transformers import RobertaTokenizer, RobertaModel
import numpy as np
from scipy.special import softmax
import gradio as gr
import re
from huggingface_hub import hf_hub_download
class CodeClassifier(torch.nn.Module):
    """Encoder followed by a linear dimension reduction and a linear classification head.

    Args:
        base_model: Encoder module called as
            ``base_model(input_ids=..., attention_mask=...)`` whose output exposes
            ``pooler_output``.
        num_labels: Number of output classes.
        hidden_size: Width of ``pooler_output`` (768 for ``microsoft/codebert-base``).
        reduced_size: Width of the intermediate reduction layer.

    The submodule attribute names (``base``, ``reduction``, ``classifier``) are the
    state-dict key prefixes, so they must stay stable for saved checkpoints to load.
    """

    def __init__(self, base_model, num_labels=6, hidden_size=768, reduced_size=512):
        super().__init__()
        self.base = base_model
        # NOTE(review): there is no non-linearity between the two Linear layers.
        # It is left as-is because the saved reduction/classifier weights presumably
        # correspond to exactly this stack — confirm before changing the architecture.
        self.reduction = torch.nn.Linear(hidden_size, reduced_size)
        self.classifier = torch.nn.Linear(reduced_size, num_labels)

    def forward(self, input_ids, attention_mask):
        """Return unnormalised class logits (last dimension is ``num_labels``)."""
        outputs = self.base(input_ids=input_ids, attention_mask=attention_mask)
        reduced = self.reduction(outputs.pooler_output)
        return self.classifier(reduced)
# Load the pretrained CodeBERT encoder and tokenizer from the Hugging Face Model Hub.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
base_model = RobertaModel.from_pretrained('microsoft/codebert-base')

# Wrap the encoder; the reduction/classifier heads are randomly initialised until
# the fine-tuned checkpoint below is loaded into them.
model = CodeClassifier(base_model)

# Download the fine-tuned checkpoint from the Hugging Face Model Hub.
checkpoint_path = hf_hub_download(repo_id="martynattakit/CodeSentinel-Model", filename="best_model.pt")
# SECURITY NOTE(review): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted repo; consider weights_only=True once
# the checkpoint is confirmed to hold only tensors and plain containers.
checkpoint = torch.load(checkpoint_path, map_location=device)

# Accept either a training dict with a 'model_state_dict' entry or a bare state dict.
model_state = checkpoint.get('model_state_dict', checkpoint)
# strict=False tolerates key mismatches silently, so report them explicitly:
# any missing key keeps its random initialisation.
load_result = model.load_state_dict(model_state, strict=False)
print("Missing keys:", load_result.missing_keys)
print("Unexpected keys:", load_result.unexpected_keys)
head_missing = [
    key for key in load_result.missing_keys
    if key.startswith(("reduction.", "classifier."))
]
if head_missing:
    print("WARNING: classification head weights not found in checkpoint:", head_missing)
print("Classifier weight shape:", model.classifier.weight.shape)
model.eval()
model.to(device)
# Class index -> (CWE identifier, human-readable description).
# Position in this list is the classifier output index looked up by evaluate_code.
label_map = dict(enumerate([
    ('none', 'No Vulnerability Detected'),
    ('cwe-121', 'Stack-based Buffer Overflow'),
    ('cwe-78', 'OS Command Injection'),
    ('cwe-190', 'Integer Overflow or Wraparound'),
    ('cwe-191', 'Integer Underflow'),
    ('cwe-122', 'Heap-based Buffer Overflow'),
]))
def load_c_file(file):
    """Return the text of an uploaded C/C++ file for the code textbox.

    Args:
        file: The value ``gr.File`` hands to its change handler. Depending on the
            Gradio version this is a filepath string or an object exposing the path
            as ``.name``; it is ``None`` when the upload is cleared.

    Returns:
        The file contents (undecodable bytes replaced with U+FFFD), ``""`` for
        ``None``, or an ``"Error reading file: ..."`` message if reading fails.
    """
    import os  # local import: the module header does not import os

    if file is None:
        return ""
    try:
        # A path-like value is used directly; otherwise fall back to ``.name``
        # (checked first because pathlib.Path also has a ``.name`` = basename).
        path = file if isinstance(file, (str, os.PathLike)) else file.name
        with open(path, 'r', encoding='utf-8', errors='replace') as f:
            return f.read()
    except Exception as e:
        # Shown in the textbox instead of raising so the UI stays usable.
        return f"Error reading file: {e}"
# One left-to-right scan that matches either a comment (to drop) or a string/char
# literal (to keep verbatim), so comment markers inside literals such as
# "http://..." or "/*" are not mistaken for comments.
_C_COMMENT_OR_LITERAL = re.compile(
    r'//[^\n]*'                  # line comment
    r'|/\*.*?\*/'                # block comment (non-greedy, may span lines)
    r'|"(?:\\.|[^"\\\n])*"'      # string literal, honouring backslash escapes
    r"|'(?:\\.|[^'\\\n])*'",     # character literal, honouring backslash escapes
    re.DOTALL,
)


def _strip_comment(match):
    """Replace a matched comment with one space; return a matched literal unchanged."""
    text = match.group(0)
    # C treats a comment as a single space, so ``a/*x*/b`` must not become ``ab``.
    return " " if text.startswith("/") else text


def clean_code(code):
    """Strip C/C++ comments and collapse every whitespace run to a single space.

    String and character literals are preserved verbatim, including any ``//``
    or ``/*`` sequences they contain.

    >>> clean_code('printf("http://x"); // log')
    'printf("http://x");'
    """
    code = _C_COMMENT_OR_LITERAL.sub(_strip_comment, code)
    return ' '.join(code.split())
def evaluate_code(code):
    """Classify a C/C++ snippet and return ``"<cwe-id> <description>"``.

    Args:
        code: Source text from the input textbox; ``None`` is treated as empty.

    Returns:
        The predicted label string. If the input is too large or empty after
        comment stripping, or if prediction raises, a human-readable message is
        returned instead (errors are shown in the UI rather than raised).
    """
    try:
        if code is None:
            code = ""
        if len(code) >= 1500000:
            return "Code too large"
        cleaned_code = clean_code(code)
        # Nothing left after removing comments/whitespace: don't ask the model
        # to classify an empty sequence.
        if not cleaned_code:
            return "No code provided"
        # truncation/max_length: the model only sees the first 256 tokens.
        inputs = tokenizer(cleaned_code, return_tensors="pt", truncation=True, padding=True, max_length=256).to(device)
        print("Input shape:", inputs['input_ids'].shape)
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.cpu().numpy()
        print("Raw logits:", logits)
        probs = softmax(logits, axis=1)
        # int() so the label_map lookup uses a plain Python int, not a numpy scalar.
        pred = int(np.argmax(probs, axis=1)[0])
        cwe, description = label_map[pred]
        return f"{cwe} {description}"
    except Exception as e:
        return f"Error during prediction: {e}"
# Gradio UI: a code textbox and a file-upload column side by side, a "Check"
# button, and a read-only result box. Inside each `with` context, components
# appear in the order they are created, so statement order defines the layout.
with gr.Blocks() as web:
    with gr.Row():
        with gr.Column(scale=1):
            # Main input; also overwritten by load_c_file when a file is uploaded.
            code_box = gr.Textbox(lines=20, label="** C/C++ Code", placeholder="Paste your C or C++ code here...")
        with gr.Column(scale=1):
            cc_file = gr.File(label="Upload C/C++ File (.c or .cpp)", file_types=[".c", ".cpp"])
            check_btn = gr.Button("Check")
    with gr.Row():
        gr.Markdown("### Result:")
    with gr.Row():
        with gr.Column(scale=1):
            label_box = gr.Textbox(label="Vulnerability", interactive=False)
    # Uploading (or clearing) a file replaces the textbox contents; a cleared
    # upload reaches load_c_file as None, which returns "".
    cc_file.change(fn=load_c_file, inputs=cc_file, outputs=code_box)
    # Classify whatever is currently in the textbox and show the label.
    check_btn.click(fn=evaluate_code, inputs=code_box, outputs=[label_box])
# Launched at module level (no __main__ guard), so importing this file starts the app.
web.launch()