venky2k1 commited on
Commit
4f36dfb
·
1 Parent(s): 63fb23d

Added detect_bug using CodeBERT

Browse files
Files changed (1) hide show
  1. bug_detector.py +8 -4
bug_detector.py CHANGED
@@ -5,13 +5,17 @@ from transformers import AutoModelForSequenceClassification, AutoTokenizer
5
  model = AutoModelForSequenceClassification.from_pretrained("microsoft/codebert-base", num_labels=2)
6
  tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
7
 
8
- def classify_code(code):
 
 
 
 
9
  inputs = tokenizer(code, return_tensors="pt", truncation=True, padding=True)
10
  outputs = model(**inputs)
11
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
12
  return "buggy" if probabilities[0][1] > probabilities[0][0] else "correct"
13
 
14
- # Test example
15
  if __name__ == "__main__":
16
- test_code = "def add(a, b): return a * b" # Buggy function
17
- print(classify_code(test_code))
 
5
  model = AutoModelForSequenceClassification.from_pretrained("microsoft/codebert-base", num_labels=2)
6
  tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
7
 
8
+ def detect_bug(code):
9
+ """
10
+ Detects whether the provided code is buggy or correct using CodeBERT.
11
+ Returns "buggy" or "correct".
12
+ """
13
  inputs = tokenizer(code, return_tensors="pt", truncation=True, padding=True)
14
  outputs = model(**inputs)
15
  probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
16
  return "buggy" if probabilities[0][1] > probabilities[0][0] else "correct"
17
 
18
+ # Optional: test this locally
19
  if __name__ == "__main__":
20
+ test_code = "def add(a, b): return a * b" # Example of buggy code
21
+ print(detect_bug(test_code))