# WizardLM-1.6/WizardCoder/src/inference_wizardcoder.py
import sys

import fire
import torch
import jsonlines
from transformers import AutoTokenizer, AutoModelForCausalLM, GenerationConfig

# Pick the best available device: CUDA, then Apple MPS, then CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
try:
    if torch.backends.mps.is_available():
        device = "mps"
except AttributeError:
    # Older torch builds have no torch.backends.mps; keep the CUDA/CPU choice.
    pass

def evaluate(
    batch_data,
    tokenizer,
    model,
    input=None,
    temperature=1,
    top_p=0.9,
    top_k=40,
    num_beams=1,
    max_new_tokens=2048,
    **kwargs,
):
    prompts = generate_prompt(batch_data, input)
    # Prompts longer than 256 tokens are truncated before generation.
    inputs = tokenizer(prompts, return_tensors="pt", max_length=256, truncation=True, padding=True)
    input_ids = inputs["input_ids"].to(device)
    # Pass the attention mask explicitly so padded batches are handled correctly.
    attention_mask = inputs["attention_mask"].to(device)
    # temperature/top_p/top_k only affect generation when sampling is enabled
    # (pass do_sample=True via **kwargs); otherwise decoding is greedy/beam.
    generation_config = GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        **kwargs,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
        )
    sequences = generation_output.sequences
    output = tokenizer.batch_decode(sequences, skip_special_tokens=True)
    return output

def generate_prompt(instruction, input=None):
    # `input` is accepted for interface compatibility but is unused by this template.
    return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

### Instruction:
{instruction}

### Response:"""
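
# Usage sketch (illustrative, not part of the pipeline below): with a loaded
# tokenizer and model, a single instruction can be run as
#   outputs = evaluate("Write a hello-world program in Python.", tokenizer, model)
# `outputs` is a list with one decoded string per sequence; each string contains
# the prompt followed by the model's completion after "### Response:".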

def main(
    load_8bit: bool = False,
    base_model: str = "Model_Path",
    input_data_path: str = "Input.jsonl",
    output_data_path: str = "Output.jsonl",
):
    assert base_model, (
        "Please specify a --base_model, e.g. --base_model='bigcode/starcoder'"
    )

    tokenizer = AutoTokenizer.from_pretrained(base_model)
    # Some code models ship without a pad token; fall back to EOS so padding works.
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    if device == "cuda":
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",
        )
    elif device == "mps":
        model = AutoModelForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            torch_dtype=torch.float16,
        )
    else:
        # CPU fallback, kept in float32, so `model` is always defined.
        model = AutoModelForCausalLM.from_pretrained(base_model)
    model.config.pad_token_id = tokenizer.pad_token_id

    if not load_8bit and device != "cpu":
        model.half()

    model.eval()
    if torch.__version__ >= "2" and sys.platform != "win32":
        model = torch.compile(model)
    # Stream instructions in and write one annotated result per line.
    with jsonlines.open(input_data_path, mode="r") as input_data, \
         jsonlines.open(output_data_path, mode="w") as output_data:
        for one_data in input_data:
            idx = one_data["idx"]
            instruction = one_data["Instruction"]
            print(instruction)
            _output = evaluate(instruction, tokenizer, model)
            # Keep only the text that follows the "### Response:" marker.
            final_output = _output[0].split("### Response:")[1].strip()
            new_data = {
                "id": idx,
                "instruction": instruction,
                "wizardcoder": final_output,
            }
            output_data.write(new_data)

if __name__ == "__main__":
    fire.Fire(main)
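
# Invocation sketch: fire exposes the arguments of main() as CLI flags. The
# input file is assumed to hold one JSON object per line with the keys read
# above, e.g. {"idx": 0, "Instruction": "..."}; the model name is illustrative.
#
#   python inference_wizardcoder.py \
#       --base_model "bigcode/starcoder" \
#       --input_data_path Input.jsonl \
#       --output_data_path Output.jsonl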