English
Generated from Trainer
zephyr-7b-beta / handler.py
Michael Brunzel
Add files for the custom inference of the zephyr beta model
8c88b7a
raw
history blame
1.99 kB
from typing import Dict, Any, List
from transformers import pipeline
import torch
class EndpointHandler:
def __init__(self, path=""):
# load model and processor from path
self.pipe = pipeline("text-generation", model="HuggingFaceH4/zephyr-7b-beta", torch_dtype=torch.bfloat16, device_map="auto")
def __call__(self, data: Dict[str, Any]) -> List[List[Dict[str, str]]]:
"""
Args:
data (:dict:):
The payload with the text prompt and generation parameters.
"""
# process input
inputs = data.pop("inputs", data)
parameters = data.pop("parameters", None)
if isinstance(inputs, list) and isinstance(inputs[0], list) or isinstance(inputs[0], dict):
if isinstance(inputs[0], dict):
inputs = [inputs]
messages = inputs
else:
if isinstance(inputs, str):
messages = [[
{
"role": "system",
"content": "You are a helpful AI assistant",
},
{"role": "user", "content": inputs},
]]
else:
messages = [[
{
"role": "system",
"content": "You are a helpful AI assistant",
},
{"role": "user", "content": input},
] for input in inputs]
prompts = []
for message in messages:
prompts += [self.pipe.tokenizer.apply_chat_template(message, tokenize=False, add_generation_prompt=True)]
# pass inputs with all kwargs in data
if parameters is not None:
outputs = self.pipe(
prompts,
**parameters)
else:
outputs = self.pipe(
prompts, max_new_tokens=32,
)
return [{"generated_text": outputs}]