import gradio as gr
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer
from utils.prompter import Prompter
class CustomPrompter(Prompter):
    def get_response(self, output: str) -> str:
        # Safely split on '### Response:'
        split_output = output.split(self.template["response_split"], maxsplit=1)
        if len(split_output) < 2:
            return output.strip()
        response_part = split_output[1].strip()
        # Optionally strip out any subsequent '### Instruction:'
        end_index = response_part.find("### Instruction:")
        if end_index != -1:
            response_part = response_part[:end_index].strip()
        return response_part
prompt_template_name = "alpaca"
prompter = CustomPrompter(prompt_template_name)
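# The helpers below rely on a loaded tokenizer/model and a few globals
# (cutoff_len, train_on_inputs, add_eos_token) that are not defined in this
# excerpt. A minimal sketch follows; the checkpoint id and values are
# assumptions, not the Space's actual configuration.
base_model = "Joaoffg/ELM"  # assumed Hugging Face model id
tokenizer = AutoTokenizer.from_pretrained(base_model)
model = AutoModelForCausalLM.from_pretrained(base_model)
model.eval()

cutoff_len = 512         # assumed maximum tokenized prompt length
train_on_inputs = False  # assumed: mask prompt tokens in the labels
add_eos_token = True     # assumed: append EOS when tokenizing prompts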
def tokenize(prompt, add_eos_token=True):
    # Tokenize the prompt, truncating to cutoff_len; no padding so lengths vary
    result = tokenizer(
        prompt,
        truncation=True,
        max_length=cutoff_len,
        padding=False,
        return_tensors=None,
    )
    # Append the EOS token if it was not produced and the sequence is not at the limit
    if (
        result["input_ids"][-1] != tokenizer.eos_token_id
        and len(result["input_ids"]) < cutoff_len
        and add_eos_token
    ):
        result["input_ids"].append(tokenizer.eos_token_id)
        result["attention_mask"].append(1)
    result["labels"] = result["input_ids"].copy()
    return result
def generate_and_tokenize_prompt(data_point):
    # Build the full Alpaca-style prompt (instruction + input + output) and tokenize it
    full_prompt = prompter.generate_prompt(
        data_point["instruction"],
        data_point["input"],
        data_point["output"],
    )
    tokenized_full_prompt = tokenize(full_prompt)
    if not train_on_inputs:
        # Mask the prompt tokens with -100 so the loss is computed only on the response
        user_prompt = prompter.generate_prompt(
            data_point["instruction"], data_point["input"]
        )
        tokenized_user_prompt = tokenize(
            user_prompt, add_eos_token=add_eos_token
        )
        user_prompt_len = len(tokenized_user_prompt["input_ids"])
        if add_eos_token:
            user_prompt_len -= 1
        tokenized_full_prompt["labels"] = [
            -100
        ] * user_prompt_len + tokenized_full_prompt["labels"][
            user_prompt_len:
        ]  # could be sped up, probably
    return tokenized_full_prompt
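# The commented-out stopping_criteria in evaluate() refers to a StopOnTokens
# helper that is not defined in this file; note that stopping_criteria is an
# argument of model.generate(), not of GenerationConfig. A minimal sketch,
# assuming it should halt generation once the given token sequence appears:
from transformers import StoppingCriteria, StoppingCriteriaList

class StopOnTokens(StoppingCriteria):
    def __init__(self, stop_token_ids):
        self.stop_token_ids = stop_token_ids

    def __call__(self, input_ids, scores, **kwargs):
        # Compare the tail of the generated sequence with the stop sequence
        n = len(self.stop_token_ids)
        return input_ids[0, -n:].tolist() == self.stop_token_ids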
def evaluate(instruction):
    # This demo takes only an instruction; no separate input field is used
    input_text = None
    prompt = prompter.generate_prompt(instruction, input_text)
    inputs = tokenizer(prompt, return_tensors="pt")
    input_ids = inputs["input_ids"].to(model.device)
    # Example generation config
    temperature = 0.2
    top_p = 0.95
    top_k = 25
    num_beams = 1
    max_new_tokens = 256
    repetition_penalty = 2.0
    do_sample = True
    generation_config = transformers.GenerationConfig(
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
        do_sample=do_sample,
        min_new_tokens=32,
        num_return_sequences=1,
        pad_token_id=0,
    )
    with torch.no_grad():
        generation_output = model.generate(
            input_ids=input_ids,
            generation_config=generation_config,
            return_dict_in_generate=True,
            output_scores=True,
            max_new_tokens=max_new_tokens,
            # Optionally define a stopping criterion to stop at '### Instruction:'
            # stopping_criteria=StoppingCriteriaList([StopOnTokens(tokenizer.encode("### Instruction:", add_special_tokens=False))]),
        )
    # For demo, just take the first sequence
    output = tokenizer.decode(generation_output.sequences[0], skip_special_tokens=True)
    return prompter.get_response(output)
interface = gr.Interface(
    fn=evaluate,
    inputs=[
        gr.components.Textbox(
            lines=2,
            label="Instruction",
            placeholder="Explain economic growth.",
        ),
    ],
    outputs=[
        gr.components.Textbox(
            lines=5,
            label="Output",
        )
    ],
    title="🌲 ELM - Erasmian Language Model",
    description=(
        "ELM is a 900M parameter language model finetuned to follow instructions. "
        "It is trained on Erasmus University academic outputs and the "
        "[Stanford Alpaca](https://github.com/tatsu-lab/stanford_alpaca) dataset. "
        "For more information, please visit [the GitHub repository](https://github.com/Joaoffg/ELM)."
    ),
)
interface.queue().launch()