# -*- coding: utf-8 -*-
# BioMed-LLaMa-3 / app.py, Nour Eddine Zekaoui et al.
import os
import torch
import spaces
import gradio as gr
from peft import PeftModel
from transformers import (
AutoTokenizer,
BitsAndBytesConfig,
AutoModelForCausalLM)
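
# `spaces` provides the @spaces.GPU decorator used on Hugging Face ZeroGPU
# Spaces: the decorated function gets a GPU allocated only while it runs.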
# Read the Hugging Face access token from the environment; the gated
# meta-llama base model requires it.
HF_TOKEN = os.environ.get("HF_TOKEN")
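
# Alpaca-style prompt template; the LoRA adapter loaded below was presumably
# fine-tuned on this exact format, so inference prompts must reproduce it.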
def generate_prompt(instruction, input=None):
if input:
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. # noqa: E501
### Instruction:
{instruction}
### Input:
{input}
### Response:
"""
else:
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request. # noqa: E501
### Instruction:
{instruction}
### Response:
"""
@spaces.GPU
def models():
    base_model_path = "meta-llama/Meta-Llama-3-8B"
lora_weights = "NouRed/BioMed-Tuned-Llama-3-8b"
load_in_4bit=True
bnb_4bit_use_double_quant=True
bnb_4bit_quant_type="nf4"
bnb_4bit_compute_dtype=torch.bfloat16
    tokenizer = AutoTokenizer.from_pretrained(
        base_model_path,
        token=HF_TOKEN,
    )
    tokenizer.padding_side = 'right'
    tokenizer.pad_token = tokenizer.eos_token
    tokenizer.add_eos_token = False  # append EOS during training only; an EOS at the end of an inference prompt derails generation
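    # QLoRA-style 4-bit loading: NF4 weights plus double quantization with
    # bfloat16 compute keeps the 8B base model within a single GPU's memory.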
quantization_config = BitsAndBytesConfig(
load_in_4bit=load_in_4bit,
bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
bnb_4bit_quant_type=bnb_4bit_quant_type,
bnb_4bit_compute_dtype=bnb_4bit_compute_dtype
)
    base_model = AutoModelForCausalLM.from_pretrained(
        base_model_path,
        token=HF_TOKEN,
        device_map="auto",
        attn_implementation="flash_attention_2",  # needs flash-attn installed; I have an A100 GPU with 40GB of VRAM 😎
        quantization_config=quantization_config,
    )
    model = PeftModel.from_pretrained(
        base_model,
        lora_weights,
        torch_dtype=torch.bfloat16,  # match bnb_4bit_compute_dtype above
    )
return model, tokenizer
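
# Load the model and tokenizer once at startup; the Gradio handler below
# reuses them for every request.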
model, tokenizer = models()
@spaces.GPU
def generate(
instruction,
input=None,
temperature=0.1,
top_p=0.9,
top_k=40,
num_beams=4,
max_new_tokens=128,
do_sample=True,
**kwargs):
prompt = generate_prompt(instruction, input)
    # device_map="auto" may shard the model across devices, so move inputs to
    # the model's device instead of a module-level `device` variable.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    with torch.no_grad():
        generated_ids = model.generate(
            **inputs,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            num_beams=num_beams,
            do_sample=do_sample,
            max_new_tokens=max_new_tokens,
        )
output = tokenizer.decode(
generated_ids[0],
skip_special_tokens=True,
clean_up_tokenization_spaces=True
)
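    # The decoded output echoes the full prompt, so keep only the text that
    # follows the "### Response:" marker.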
response = output.split("### Response:")[1].strip()
return response
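
# A minimal local smoke test (hypothetical; in the Space, the Gradio interface
# below is the entry point):
# if __name__ == "__main__":
#     print(generate("Suggest treatment for pneumonia"))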
description = """
<div style="justify-content: center; text-align: center;">
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<h2>
<p> BioMed-LLaMa-3: Efficient Instruction Fine-Tuning in Biomedical Language</p>
</h2>
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://huggingface.co/NouRed/BioMed-Tuned-Llama-3-8b" target="_blank"><img src="https://img.shields.io/badge/πŸ€—_Hugging_Face-BioMedLLaMa3-orange" alt="HF HUB"></a> &nbsp;&nbsp;
<a href="https://colab.research.google.com/drive/1PDa8b5TqpAYxDVlF0Elv32KOM2kFaXJh" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Inference Notebook"></a>
</div>
</div>
"""
gr.Interface(
fn=generate,
inputs=[
gr.components.Textbox(
lines=2,
label="Instruction",
placeholder="Tell me about Covid-19?",
),
gr.components.Textbox(lines=2, label="Input", placeholder="none"),
gr.components.Slider(
            minimum=0.01, maximum=1, value=0.1, label="Temperature"  # must be > 0 when sampling
),
gr.components.Slider(
minimum=0, maximum=1, value=0.9, label="Top p"
),
gr.components.Slider(
minimum=0, maximum=100, step=1, value=40, label="Top k"
),
gr.components.Slider(
minimum=1, maximum=4, step=1, value=4, label="Beams"
),
gr.components.Slider(
minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
),
gr.components.Checkbox(
value=True, label="Do Sample", info="Do you want to use sampling during text generation?"
),
],
outputs=[
gr.components.Textbox(
lines=5,
label="Output",
)
],
examples=[
["Suggest treatment for pneumonia", "", 0.1, 0.9, 40, 4, 128, True],
["I have a sore throat, slight cough, tiredness. should i get tested fro covid 19?", "", 0.1, 0.9, 40, 4, 128, True],
["Husband of this patient asked me how to treat premature ejaculation and how to increase her libido.", "", 0.1, 0.9, 40, 4, 128, True],
],
theme="soft",
description=description, # noqa: E501
).launch()