Joaoffg-ELM / app.py
import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.prompter import Prompter
class CustomPrompter(Prompter):
    def get_response(self, output: str) -> str:
        # Safely split on '### Response:'
        split_output = output.split(self.template["response_split"], maxsplit=1)
        if len(split_output) < 2:
            return output.strip()
        response_part = split_output[1].strip()
        # Optionally strip out any subsequent '### Instruction:'
        end_index = response_part.find("### Instruction:")
        if end_index != -1:
            response_part = response_part[:end_index].strip()
        return response_part
prompt_template_name = "alpaca"
prompter = CustomPrompter(prompt_template_name)
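# The functions below rely on a loaded tokenizer/model and a few training-style
# globals that are never defined elsewhere in this file. The block below is a
# minimal sketch to make the script self-contained: the checkpoint id and the
# hyperparameter values are assumptions (typical Alpaca-style defaults), not
# values confirmed by the original app.
BASE_MODEL = "Joaoffg/ELM"  # assumed checkpoint id; point this at the actual ELM weights

tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL)
model = AutoModelForCausalLM.from_pretrained(
    BASE_MODEL,
    torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
)
model.to("cuda" if torch.cuda.is_available() else "cpu")
model.eval()

# Settings used by tokenize() / generate_and_tokenize_prompt(); assumed defaults.
cutoff_len = 512
train_on_inputs = False
add_eos_token = True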
def tokenize(prompt, add_eos_token=True):
    # Tokenize a prompt, truncating to cutoff_len and appending EOS when there is room.
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < cutoff_len
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)
    result["labels"] = result["input_ids"].copy()
    return result
def generate_and_tokenize_prompt(data_point):
    full_prompt = prompter.generate_prompt(
        data_point["instruction"],
        data_point["input"],
        data_point["output"],
    )
    tokenized_full_prompt = tokenize(full_prompt)
    if not train_on_inputs:
        # Mask the prompt tokens with -100 so the loss is only computed on the response.
        user_prompt = prompter.generate_prompt(
            data_point["instruction"], data_point["input"]
        )
        tokenized_user_prompt = tokenize(
            user_prompt, add_eos_token=add_eos_token
        )
        user_prompt_len = len(tokenized_user_prompt["input_ids"])
        if add_eos_token:
            user_prompt_len -= 1
        tokenized_full_prompt["labels"] = [
            -100
        ] * user_prompt_len + tokenized_full_prompt["labels"][
            user_prompt_len:
        ]  # could be sped up, probably
    return tokenized_full_prompt
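# tokenize() and generate_and_tokenize_prompt() are fine-tuning helpers and are
# not called by the Gradio app below. A minimal usage sketch, assuming an
# instruction dataset in Alpaca format (instruction/input/output columns) and
# the Hugging Face `datasets` library; the file name is hypothetical:
#
#     from datasets import load_dataset
#     data = load_dataset("json", data_files="elm_instructions.json")
#     train_data = data["train"].shuffle().map(generate_and_tokenize_prompt)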
def evaluate(instruction):
    input_text = None  # this demo takes instruction-only prompts
    prompt = prompter.generate_prompt(instruction, input_text)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    # Example generation config
    temperature = 0.2
    top_p = 0.95
    top_k = 25
    num_beams = 1
    max_new_tokens = 256
    repetition_penalty = 2.0
    do_sample = True
    generation_config = transformers.GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        min_new_tokens=32,
        num_return_sequences=1,
        pad_token_id=0,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            # Optionally stop generation when the model starts a new '### Instruction:'
            # block, e.g. with the StopOnTokens sketch further below:
            # stopping_criteria=transformers.StoppingCriteriaList(
            #     [StopOnTokens(tokenizer.encode("### Instruction:", add_special_tokens=False))]
            # ),
        )
    # For the demo, decode only the first returned sequence.
    output = tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True)
    return prompter.get_response(output)
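# A sketch of the StopOnTokens criterion referenced in the commented-out
# stopping_criteria above. It is an assumption about how such a criterion could
# look, not part of the original app: it ends generation once the most recent
# tokens match a given stop sequence (e.g. "### Instruction:").
class StopOnTokens(transformers.StoppingCriteria):
    def __init__(self, stop_token_ids):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids, scores, **kwargs):
        if not self.stop_token_ids:
            return False
        # Compare the tail of the generated sequence against the stop sequence.
        return input_ids[0][-len(self.stop_token_ids):].tolist() == self.stop_token_ids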
interface = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(
            lines=2,
            label="Instruction",
            placeholder="Explain economic growth.",
        ),
    ],
    outputs=[
        gr.components.Textbox(
            lines=5,
            label="Output",
        )
    ],
    title="🌲 ELM - Erasmian Language Model",
    description=(
        "ELM is a 900M parameter language model fine-tuned to follow instructions. "
        "It is trained on Erasmus University academic outputs and the "
        "[Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset. "
        "For more information, please visit [the GitHub repository](https://github.com/Joaoffg/ELM)."
    ),
)

interface.queue().launch()