import jsonlines
import argparse
import pprint
import sys
import os
import re
from tqdm import tqdm
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig
from human_eval.data import write_jsonl, read_problems, stream_jsonl

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

try:
    if torch.backends.mps.is_available():
        device = "mps"
except Exception:
    pass


def read_mbpp(path):
    mbpp_problems = {}
    with jsonlines.open(path, "r") as fin:
        for obj in fin:
            mbpp_problems[obj["task_id"]] = obj
    return mbpp_problems


def extract_text(prompt, remove_lines=True):
    token = '"""'
    start = token
    end = '>>>'
    start_idx = prompt.find(start) + len(start)
    end_idx = prompt.find(end)
    output = prompt[start_idx: end_idx]
    if remove_lines:
        output = output.replace('\n', ' ')
    output = re.sub(r"\s+", " ", output).strip()
    return output


def generate_prompt(input):
    INSTRUCTION = f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
Create a Python script for this problem:
{input}

### Response:"""
    return INSTRUCTION


def get_model(
    load_8bit: bool = False,
    base_model: str = "bigcode/starcoder",
):
    assert base_model, (
        "Please specify a --base_model, e.g. --base_model='bigcode/starcoder'"
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    elif device == "mps":
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        # CPU fallback; without this branch `model` is undefined when neither
        # CUDA nor MPS is available.
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            low_cpu_mem_usage=True,
        )
    model.config.pad_token_id = tokenizer.pad_token_id

    if not load_8bit:
        model.half()  # seems to fix bugs for some users.

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)

    return tokenizer, model


def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('--model', type=str, default='bigcode/starcoder', help="")
    parser.add_argument('--output_path', type=str, help="")
    parser.add_argument('--start_index', type=int, default=0, help="")
    parser.add_argument('--end_index', type=int, default=164, help="")
    parser.add_argument('--temperature', type=float, default=0.8, help="")
    parser.add_argument('--N', type=int, default=200, help="")
    parser.add_argument('--max_len', type=int, default=512, help="")
    parser.add_argument('--decoding_style', type=str, default='sampling', help="")
    parser.add_argument('--num_seqs_per_iter', type=int, default=50, help='')
    parser.add_argument('--overwrite', action='store_true', help='')
    parser.add_argument('--mbpp_path', type=str, help="")

    args = parser.parse_args()

    argsdict = vars(args)
    print(pprint.pformat(argsdict))

    STOP_SEQS = ['\nclass', '\ndef', '\n#', '\nif', '\nprint']

    problems = read_mbpp(args.mbpp_path)
    task_ids = sorted(problems.keys())[args.start_index: args.end_index]
    prompts = []
    for task_id in task_ids:
        prompt = f"\n{problems[task_id]['text']}\nTest examples:"
        if task_id == 493:
            # The test examples are too long. We choose to only include the function name.
            test_example = problems[task_id]['test_list'][0]
            prompt += "\ncalculate_polygons(startx, starty, endx, endy, radius)"
        else:
            for test_example in problems[task_id]['test_list']:
                prompt += f"\n{test_example}"
        prompts.append(prompt)

    num_samples = len(prompts)
    print("Number of samples: {}".format(num_samples))

    tokenizer, model = get_model(base_model=args.model)
    generation_config = GenerationConfig(
        pad_token_id=tokenizer.pad_token_id,
        do_sample=True,
        temperature=args.temperature,
        max_length=args.max_len,
        num_return_sequences=args.num_seqs_per_iter,
        eos_token_id=tokenizer.eos_token_id,
        top_p=0.95,
    )

    print(f"Loaded {args.model}.")
    for i in tqdm(range(num_samples), ncols=0, total=num_samples):
        output_file = args.output_path + '/{}.jsonl'.format(args.start_index + i)

        if os.path.exists(output_file) and not args.overwrite:
            print(f'Skip {output_file} as it already exists')
            continue

        # Replace 4-space indentation with tabs to keep the prompt short.
        prompt = prompts[i].replace('    ', '\t')
        prompt_batch = [generate_prompt(prompt)]

        ids_batch = [task_ids[i]]
        completion_seqs = []

        encoding = tokenizer(prompt_batch, return_tensors="pt", truncation=True,
                             max_length=args.max_len).to(device)

        if args.decoding_style == 'sampling':
            loops = int(args.N / args.num_seqs_per_iter)
        else:
            loops = 1

        for _ in tqdm(range(loops), total=loops, leave=False, ncols=0):

            with torch.no_grad():
                if args.decoding_style == 'sampling':
                    gen_tokens = model.generate(
                        **encoding,
                        generation_config=generation_config,
                    )
                else:
                    # No other decoding style is implemented; avoid a NameError below.
                    gen_tokens = None

            if gen_tokens is not None:
                gen_seqs = tokenizer.batch_decode(gen_tokens, skip_special_tokens=True)
            else:
                gen_seqs = None

            if gen_seqs is not None:
                assert len(ids_batch) == 1
                task_id = ids_batch[0]

                for seq_idx, gen_seq in enumerate(gen_seqs):
                    # Keep only the model's answer and convert tabs back to 4 spaces.
                    completion_seq = gen_seq.split("### Response:")[-1]
                    completion_seq = completion_seq.replace('\t', '    ')
                    all_code = gen_seq.replace('\t', '    ')

                    completion_seqs.append(
                        {'task_id': task_id,
                         'completion': completion_seq,
                         'all_code': all_code,
                         }
                    )

        print("Saving results to {}".format(output_file))
        write_jsonl(output_file, completion_seqs)


if __name__ == '__main__':
    main()
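# Example invocation (illustrative only; the script filename, data path, and
# output directory below are placeholders, not part of this repository):
#
#   python mbpp_gen.py \
#       --model bigcode/starcoder \
#       --mbpp_path mbpp.jsonl \
#       --output_path preds/mbpp \
#       --start_index 0 --end_index 500 \
#       --temperature 0.8 --N 200 --num_seqs_per_iter 50
#
# Each task writes one <start_index + i>.jsonl file under --output_path,
# containing N sampled completions for that problem.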