# lemur-7B: utils/inference.py

from typing import Iterator

import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizer

from variables import AI, HUMAN, SYSTEM


def load_tokenizer_and_model(base_model, adapter_model, load_8bit=True):
    """
    Loads the tokenizer and chatbot model.

    Args:
        base_model (str): The base model to use (path to the model).
        adapter_model (str): The LoRA adapter to use (path to the LoRA weights), or None.
        load_8bit (bool): Whether to load the model in 8-bit mode (CUDA only).

    Returns:
        tuple: (tokenizer, model, device)
    """
    # Pick the best available device: CUDA, then Apple MPS, then CPU.
    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"
    try:
        if torch.backends.mps.is_available():
            device = "mps"
    except Exception:  # torch.backends.mps is unavailable on older torch builds
        pass

    tokenizer = LlamaTokenizer.from_pretrained(base_model)
    if device == "cuda":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            load_in_8bit=load_8bit,
            torch_dtype=torch.float16,
            device_map="auto",  # 8-bit loading requires a device map
        )
        if adapter_model is not None:
            model = PeftModel.from_pretrained(
                model,
                adapter_model,
                torch_dtype=torch.float16,
            )
    elif device == "mps":
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
        )
        if adapter_model is not None:
            model = PeftModel.from_pretrained(
                model,
                adapter_model,
                device_map={"": device},
                torch_dtype=torch.float16,
            )
    else:
        # CPU fallback: bfloat16 weights with low memory usage and disk offload.
        model = LlamaForCausalLM.from_pretrained(
            base_model,
            device_map={"": device},
            low_cpu_mem_usage=True,
            torch_dtype=torch.bfloat16,
            offload_folder=".",
        )
        if adapter_model is not None:
            model = PeftModel.from_pretrained(
                model,
                adapter_model,
                torch_dtype=torch.bfloat16,
                offload_folder=".",
            )

    model.eval()
    return tokenizer, model, device
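

def _example_load(base_model_path="path/to/llama-7b", adapter_path="path/to/lora-adapter"):
    """Illustrative sketch (not part of the original module).

    Shows a typical call to the loader above. Both default paths are
    hypothetical placeholders; substitute any local or Hub path to a
    LLaMA-style base model and an optional LoRA adapter, or pass
    adapter_path=None to skip PEFT entirely.
    """
    tokenizer, model, device = load_tokenizer_and_model(
        base_model=base_model_path,
        adapter_model=adapter_path,
        load_8bit=True,
    )
    return tokenizer, model, device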


class State:
    """Shared flag that lets the UI ask a running generation loop to stop."""

    interrupted = False

    def interrupt(self):
        self.interrupted = True

    def recover(self):
        self.interrupted = False


shared_state = State()
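

def _example_interrupt_aware_stream(stream):
    """Illustrative sketch (not part of the original module).

    `decode` itself never reads the flag; a caller that iterates over its
    yields can poll `shared_state.interrupted` (set from another thread, e.g. a
    UI "Stop" handler calling `shared_state.interrupt()`; hypothetical here)
    and break out early, clearing the flag for the next request.
    """
    last_text = ""
    for partial_text in stream:
        if shared_state.interrupted:
            shared_state.recover()
            break
        last_text = partial_text
    return last_text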


def decode(
    input_ids: torch.Tensor,
    model: PeftModel,
    tokenizer: LlamaTokenizer,
    stop_words: list,
    max_length: int,
    temperature: float = 1.0,
    top_p: float = 1.0,
) -> Iterator[str]:
    """Sample up to `max_length` new tokens with temperature and top-p (nucleus)
    sampling, yielding the full decoded continuation after each token and
    stopping once any stop word appears.
    """
    generated_tokens = []
    past_key_values = None
    for _ in range(max_length):
        with torch.no_grad():
            if past_key_values is None:
                outputs = model(input_ids)
            else:
                # Reuse the KV cache: only the newest token needs a forward pass.
                outputs = model(input_ids[:, -1:], past_key_values=past_key_values)
            logits = outputs.logits[:, -1, :]
            past_key_values = outputs.past_key_values

        # Apply temperature scaling.
        logits /= temperature
        probs = torch.softmax(logits, dim=-1)

        # Apply top-p filtering: keep the smallest set of tokens whose cumulative
        # probability covers top_p, then renormalize.
        probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
        probs_sum = torch.cumsum(probs_sort, dim=-1)
        mask = probs_sum - probs_sort > top_p
        probs_sort[mask] = 0.0
        probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))

        # Sample in the sorted space, then map back to vocabulary indices.
        next_token = torch.multinomial(probs_sort, num_samples=1)
        next_token = torch.gather(probs_idx, -1, next_token)

        input_ids = torch.cat((input_ids, next_token), dim=-1)
        generated_tokens.append(next_token[0].item())
        text = tokenizer.decode(generated_tokens)
        yield text
        if any(x in text for x in stop_words):
            return
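

def _example_stream_generation(tokenizer, model, device):
    """Illustrative sketch (not part of the original module).

    Shows how `decode` is typically driven: tokenize a prompt built from the
    SYSTEM/HUMAN/AI markers, move the ids to the selected device, and print the
    partial completion as it streams. The stop-word list below is an assumption
    based on those markers, not something this module defines.
    """
    prompt = f"{SYSTEM}\n{HUMAN} Hello!\n{AI}"
    input_ids = tokenizer(prompt, return_tensors="pt")["input_ids"].to(device)
    for partial_text in decode(
        input_ids,
        model,
        tokenizer,
        stop_words=[HUMAN, AI],  # assumed stop words
        max_length=256,
        temperature=0.7,
        top_p=0.9,
    ):
        print(partial_text, end="\r")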


def get_prompt_with_history(text, history, tokenizer, max_length=2048):
    """Build the prompt from the newest turns of `history` that fit in `max_length` tokens.

    `history` is a list of (user_message, ai_reply) pairs. Turns are added from
    newest to oldest until adding another would exceed the token budget.
    Returns (prompt_text, tokenized_prompt), or None if even the latest turn
    does not fit.
    """
    prompt = SYSTEM
    history = [f"\n{HUMAN} {x[0]}\n{AI} {x[1]}" for x in history]
    history.append(f"\n{HUMAN} {text}\n{AI}")
    history_text = ""
    flag = False
    for x in history[::-1]:
        if (
            tokenizer(prompt + history_text + x, return_tensors="pt")["input_ids"].size(-1)
            <= max_length
        ):
            history_text = x + history_text
            flag = True
        else:
            break
    if flag:
        return prompt + history_text, tokenizer(prompt + history_text, return_tensors="pt")
    else:
        return None
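

def _example_build_prompt(tokenizer):
    """Illustrative sketch (not part of the original module).

    Demonstrates the expected shape of `history` (a list of
    (user_message, ai_reply) pairs) and the two-value result returned when the
    prompt fits within the token budget.
    """
    history = [("What is LoRA?", "LoRA is a parameter-efficient fine-tuning method.")]
    result = get_prompt_with_history("Can you give an example?", history, tokenizer, max_length=2048)
    if result is None:
        return None  # even the newest turn did not fit in the budget
    prompt_text, inputs = result
    return prompt_text, inputs["input_ids"]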


def is_stop_word_or_prefix(s: str, stop_words: list) -> bool:
    """Return True if `s` ends with a stop word or with any proper prefix of one.

    The prefix check lets a streaming caller hold back text that might turn out
    to be the beginning of a stop word that has not finished generating yet.
    """
    for stop_word in stop_words:
        if s.endswith(stop_word):
            return True
        for i in range(1, len(stop_word)):
            if s.endswith(stop_word[:i]):
                return True
    return False
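

def _example_trim_stream(partial_text, stop_words=None):
    """Illustrative sketch (not part of the original module).

    Shows how a streaming caller can use `is_stop_word_or_prefix` so that text
    which might be the start of a stop word is held back until more tokens
    arrive. The default stop words are an assumption based on the HUMAN/AI
    turn markers.
    """
    if stop_words is None:
        stop_words = [HUMAN, AI]  # assumed stop words
    for word in stop_words:
        if partial_text.endswith(word):
            # A complete stop word ended the reply: strip it and finish.
            return partial_text[: -len(word)]
    if is_stop_word_or_prefix(partial_text, stop_words):
        # Possibly the beginning of a stop word: hold the text back for now.
        return None
    return partial_text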