venky2k1 committed
Commit bf718e9 · 1 Parent(s): 86f2f38

Initial commit

__pycache__/bug_detector.cpython-313.pyc ADDED
Binary file (1.9 kB)
 
bug_detector.py CHANGED
@@ -1,36 +1,14 @@
-<<<<<<< HEAD
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import torch
 
-# Load CodeT5+ model for code fixing
-model_ckpt = "Salesforce/codet5p-220m"
-tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
-model = AutoModelForSeq2SeqLM.from_pretrained(model_ckpt)
-
-def fix_code(code):
-    prompt = f"fix: {code}"
-    inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True)
-    output = model.generate(inputs["input_ids"], max_length=256)
-    fixed_code = tokenizer.decode(output[0], skip_special_tokens=True)
-    return fixed_code.strip()
-=======
-import torch
-from transformers import AutoModelForSequenceClassification, AutoTokenizer
-
-# Load pre-trained CodeBERT model
-model = AutoModelForSequenceClassification.from_pretrained("microsoft/codebert-base", num_labels=2)
-tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
-
-def detect_bug(code):
-    inputs = tokenizer(code, return_tensors="pt", truncation=True, padding=True)
-    outputs = model(**inputs)
-    probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
-    return "buggy" if probabilities[0][1] > probabilities[0][0] else "correct"
-
-# Optional test
-if __name__ == "__main__":
-    sample = "def multiply(a, b): return a + b"
-    print(detect_bug(sample))
-    #detects if there's a bug in code
-
->>>>>>> 22b22edd4386cff48f5ad4c4325e1f8524238b52
+# Load CodeT5 or similar model
+model_name = "Salesforce/codet5-base"
+tokenizer = AutoTokenizer.from_pretrained(model_name)
+model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+def fix_code(code: str) -> str:
+    input_text = f"fix: {code}"
+    inputs = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
+    with torch.no_grad():
+        output = model.generate(**inputs, max_length=512)
+    return tokenizer.decode(output[0], skip_special_tokens=True)
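
For reference, a minimal usage sketch of the resolved module — assuming bug_detector.py is importable as committed and the Salesforce/codet5-base checkpoint can be downloaded; the sample input reuses the snippet from the removed test block, and the model's output is illustrative, not guaranteed to be a correct fix:

# Hypothetical usage; not part of this commit.
from bug_detector import fix_code  # loads the tokenizer and model at import time

buggy = "def multiply(a, b): return a + b"  # multiply implemented with + (from the old test)
print(fix_code(buggy))  # prints the model's suggested rewrite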
requirements.txt DELETED
@@ -1,5 +0,0 @@
1
- flask
2
- flask-cors
3
- torch
4
- transformers
5
- gradio
 
 
 
 
 
 
requirments.txt CHANGED
@@ -1,4 +1,3 @@
 transformers
 torch
 gradio
-flask