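"""Generate Chinese comments for Python source files with Qwen-7B-Chat.

Given a file (or a directory, walked recursively), the script asks the chat model
to comment each chunk of code, writes the raw model output to a sibling
<name>_comments.py file, and then merges the generated comments back into a copy
of the original code stored in that same file. Pass --regenerate to rebuild an
existing *_comments.py instead of reusing it.

Example invocation (the script name below is illustrative; use whatever this
file is saved as):

    python gen_code_comments.py --path Qwen-7B/eval/evaluate_ceval.py
"""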
import argparse
import os

from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation import GenerationConfig
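
# MaxLine: largest number of source lines sent to the model in a single request.
# SplitKey: preferred split markers (top-level "def ") used to chunk long files.
# CodeFileType: extensions of the files that get processed.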
MaxLine = 50
SplitKey = ["\ndef "]
CodeFileType = ["py"]


def parse_args():
    parser = argparse.ArgumentParser()
    # Path to a single source file, or to a directory that is walked recursively.
    parser.add_argument('--path', type=str, default='Qwen-7B/eval/evaluate_ceval.py')
    # Regenerate comments even if a cached *_comments file already exists.
    parser.add_argument('--regenerate', action='store_true', default=False)
    args = parser.parse_args()
    return args


class QWenChat:
    def __init__(self):
        # Load the tokenizer, model and generation config from the Hugging Face Hub.
        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen-7B-Chat", device_map="auto", trust_remote_code=True).eval()
        self.model.generation_config = GenerationConfig.from_pretrained("Qwen/Qwen-7B-Chat", trust_remote_code=True)
        self.history = None

    def chat(self, query, system=""):
        # Each code chunk is commented independently, so the previous history is
        # not fed back to the model; `system` is accepted but currently unused.
        response, history = self.model.chat(self.tokenizer, query, history=None)
        self.history = history
        return response
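

# Minimal usage sketch for QWenChat (illustrative only, not part of the pipeline;
# it assumes the Qwen/Qwen-7B-Chat weights can be downloaded from the Hub):
#
#     bot = QWenChat()
#     print(bot.chat("def add(a, b):\n    return a + b"))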


def gen_code_comments(context, model=None, **kwargs):
    # The prompt (in Chinese) asks the model to generate detailed Chinese comments
    # for the code above, add a summary comment at the start of every function,
    # keep the original code unchanged, and return nothing but code and comments.
    prompt = "\n为以上代码生成细致的中文注释,注意使用合适的语法。要求必须在每个函数开头生成一段统一的函数功能注释。\n除了注释,请保证原始代码内容不变。不要返回除了注释和代码以外的其余信息,不要生成额外代码。\n"
    return model.chat(context + prompt)


def read_file(path):
    with open(path, "r", encoding='utf-8') as f:
        return f.read()


def write_file(path, context):
    with open(path, 'w', encoding='utf-8') as f:
        f.write(context)


def split_context_by_maxline(text):
    # Split the text into consecutive chunks of at most MaxLine lines each.
    lines = text.split("\n")
    res = []
    for i in range(0, len(lines), MaxLine):
        res.append("\n".join(lines[i:i + MaxLine]))
    return res
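
# Note: split_context_by_maxline cuts purely by line count (e.g. with MaxLine = 50,
# a 120-line file becomes chunks of 50, 50 and 20 lines), while
# split_context_by_splitkey below cuts at top-level function definitions so that
# each chunk stays self-contained.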


def split_context_by_splitkey(text):
    # Split at every top-level "def " so that each chunk starts with a complete
    # function; the first chunk keeps whatever precedes the first definition.
    blocks = text.split(SplitKey[0])
    return [blocks[0]] + [SplitKey[0] + x for x in blocks[1:]]


def merge_code_and_comments(original_file, comments_path):
    # Walk the original file line by line, find the matching line in the model
    # output, and copy over any comments or docstrings the model placed directly
    # above it, plus any inline comment it appended to the line itself.
    res = []
    with open(original_file, "r", encoding='utf-8') as ori_f:
        ori_lines = ori_f.readlines()
    with open(comments_path, "r", encoding='utf-8') as com_f:
        com_lines = com_f.readlines()

    len_com_lines = len(com_lines)
    p = 0
    j = 0
    for line in ori_lines:
        if line.isspace():
            continue
        if line.strip()[0] == '#':
            res.append(line)
            continue
        # Advance j until the model output contains the current original line.
        while j < len_com_lines and line[:-1] not in com_lines[j]:
            j += 1
        if j < len_com_lines:
            # Scan backwards from the match and collect the comment block above it.
            p = j - 1
            up_comments = []
            triple_dot_flag = 0  # toggled while inside a multi-line docstring
            while p < j:
                if p < 0 or (res and res[-1] and com_lines[p] == res[-1]):
                    break
                stripped = com_lines[p].strip()
                # Single-line docstring such as """...""" or '''...'''.
                if stripped and ((len(stripped) > 3 and stripped[-3:] == '"""' and stripped[:3] == '"""') or (len(stripped) > 3 and stripped[-3:] == "'''" and stripped[:3] == "'''")):
                    up_comments.append(com_lines[p])
                    p -= 1
                    continue
                # Opening or closing line of a multi-line docstring.
                if stripped and (stripped[-3:] == '"""' or stripped[:3] == '"""' or stripped[-3:] == "'''" or stripped[:3] == "'''"):
                    triple_dot_flag = (triple_dot_flag + 1) % 2
                    up_comments.append(com_lines[p])
                    p -= 1
                    continue
                # Everything inside a multi-line docstring is kept.
                if triple_dot_flag:
                    up_comments.append(com_lines[p])
                    p -= 1
                    continue
                # Keep blank lines and '#' comments, but drop placeholder comments
                # ("省略部分内容" means the model omitted part of the code).
                if (stripped == "") or (stripped and stripped[0] == '#' and "省略部分内容" not in com_lines[p]):
                    up_comments.append(com_lines[p])
                else:
                    break
                p -= 1
            if up_comments:
                res.extend(reversed(up_comments))
            # Keep an inline comment if the model added one to this line.
            if "#" in com_lines[j] and "#" not in line:
                in_line_comments = " #" + com_lines[j].split("#")[-1]
                res.append(line[:-1] + in_line_comments)
            else:
                res.append(line)
            p = j + 1
        else:
            res.append(line)
        j = p

    write_file(comments_path, "".join(res))


def deal_one_file(model, path, args):
    context = read_file(path)

    # The commented copy is written next to the original as <name>_comments.<ext>.
    fname = path.split("/")[-1]
    fpath = "/".join(path.split("/")[:-1])
    outfname = fname.split(".")[0] + "_comments." + fname.split(".")[-1]

    comments_path = os.path.join(fpath, outfname)
    if (not args.regenerate) and os.path.exists(comments_path):
        print("use cache: ", comments_path)
        return

    # Short files go to the model in one request; longer files are split at
    # function boundaries when possible, otherwise by line count.
    context_line = len(context.split("\n"))
    if context_line < MaxLine:
        res = gen_code_comments(context, model=model)
    elif SplitKey[0] not in context:
        context_list = split_context_by_maxline(context)
        res = "\n".join([gen_code_comments(context_block, model=model) for context_block in context_list])
    else:
        context_list = split_context_by_splitkey(context)
        res = "\n".join([gen_code_comments(context_block, model=model) for context_block in context_list])

    write_file(comments_path, res)
    merge_code_and_comments(path, comments_path)
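
# Note: deal_one_file first stores the raw model output in the *_comments file,
# then merge_code_and_comments overwrites that file with the original code plus
# whichever generated comments survived the merge.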


def deal_folder(model, path, args):
    # Recursively process every matching code file in the directory tree,
    # skipping files that are themselves generated "_comments" outputs.
    for fl in os.listdir(path):
        now_path = os.path.join(path, fl)
        if os.path.isfile(now_path):
            if (now_path.split(".")[-1] in CodeFileType) and ("_comments" not in now_path):
                deal_one_file(model, now_path, args)
        elif os.path.isdir(now_path):
            deal_folder(model, now_path, args)
        else:
            print("Please specify a correct path!")


def transfer(args):
    # Load the model once and reuse it for every file that gets processed.
    model = QWenChat()

    if os.path.isfile(args.path):
        if (args.path.split(".")[-1] in CodeFileType) and ("_comments" not in args.path):
            deal_one_file(model, args.path, args)
    elif os.path.isdir(args.path):
        deal_folder(model, args.path, args)
    else:
        print("Please specify a correct path!")


if __name__ == '__main__':
    args = parse_args()
    print(args)
    transfer(args)