NouRed committed on
Commit
e7209b2
·
verified ·
1 Parent(s): ca382da

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +165 -0
app.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*- Nour Eddine Zekaoui et al.
2
+
3
+ import torch
4
+ import gradio as gr
5
+
6
+ from peft import PeftModel
7
+ from transformers import (
8
+ AutoTokenizer,
9
+ BitsAndBytesConfig,
10
+ AutoModelForCausalLM)
11
+
12
+
13
def generate_prompt(instruction, input=None):
    """Build the instruction-style prompt that is sent to the model.

    Args:
        instruction: Text describing the task to perform.
        input: Optional extra context. Any falsy value (``None`` or an
            empty string) selects the template without an "Input" section.

    Returns:
        The prompt string, ending with the "### Response:" header and a
        trailing newline.
    """
    # NOTE: the text "# noqa: E501" is part of the string literal itself, so
    # it is sent to the model as prompt text. It is reproduced verbatim here
    # to keep the prompt unchanged.
    if not input:
        return (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request. "
            "# noqa: E501\n"
            "\n"
            "### Instruction:\n"
            f"{instruction}\n"
            "\n"
            "### Response:\n"
        )

    return (
        "Below is an instruction that describes a task, "
        "paired with an input that provides further context. "
        "Write a response that appropriately completes the request. "
        "# noqa: E501\n"
        "\n"
        "### Instruction:\n"
        f"{instruction}\n"
        "\n"
        "### Input:\n"
        f"{input}\n"
        "\n"
        "### Response:\n"
    )
33
+
34
+
35
# Hub identifiers of the base checkpoint and of the adapter weights that
# PeftModel loads on top of it.
based_model_path = "meta-llama/Meta-Llama-3-8B"
lora_weights = "NouRed/BioMed-Tuned-Llama-3-8b"

# Settings forwarded to BitsAndBytesConfig below (4-bit NF4 quantization with
# double quantization and bfloat16 compute).
load_in_4bit = True
bnb_4bit_use_double_quant = True
bnb_4bit_quant_type = "nf4"
bnb_4bit_compute_dtype = torch.bfloat16

# Device the tokenized prompt is moved to before generation. torch.device
# only accepts lowercase device type names, so the fallback must be "cpu";
# the previous "CPU" raised as soon as CUDA was unavailable.
# NOTE(review): the 4-bit / flash-attention settings below presumably still
# require a CUDA GPU for the model itself - confirm before relying on the
# CPU fallback.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


tokenizer = AutoTokenizer.from_pretrained(
    based_model_path,
)

# Reuse the end-of-sequence token for padding and pad on the right.
tokenizer.padding_side = 'right'
tokenizer.pad_token = tokenizer.eos_token
tokenizer.add_eos_token = True


quantization_config = BitsAndBytesConfig(
    load_in_4bit=load_in_4bit,
    bnb_4bit_use_double_quant=bnb_4bit_use_double_quant,
    bnb_4bit_quant_type=bnb_4bit_quant_type,
    bnb_4bit_compute_dtype=bnb_4bit_compute_dtype
)

base_model = AutoModelForCausalLM.from_pretrained(
    based_model_path,
    device_map="auto",
    attn_implementation="flash_attention_2",  # I have an A100 GPU with 40GB of RAM 😎
    quantization_config=quantization_config,
)

# Wrap the quantized base model with the fine-tuned adapter weights.
model = PeftModel.from_pretrained(
    base_model,
    lora_weights,
    torch_dtype=torch.float16,
)
73
+
74
+
75
def generate(
        instruction,
        input=None,
        temperature=0.1,
        top_p=0.9,
        top_k=40,
        num_beams=4,
        max_new_tokens=128,
        do_sample=True,
        **kwargs):
    """Generate the model's response to an instruction.

    Args:
        instruction: Task description inserted into the prompt.
        input: Optional extra context for the prompt.
        temperature: Sampling temperature. A value of 0 or below is not
            valid for sampling, so it disables sampling for this call.
        top_p: Nucleus-sampling threshold (used only when sampling).
        top_k: Top-k sampling cutoff (used only when sampling).
        num_beams: Number of beams passed to ``model.generate``.
        max_new_tokens: Upper bound on the number of generated tokens.
        do_sample: Whether to sample instead of decoding deterministically.
        **kwargs: Accepted for call compatibility and ignored.

    Returns:
        The text that follows the first "### Response:" marker in the
        decoded output, stripped of surrounding whitespace.
    """
    prompt = generate_prompt(instruction, input)
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    # The UI sliders may deliver numbers as floats; generation expects
    # integers for counts.
    generation_args = {
        "num_beams": int(num_beams),
        "max_new_tokens": int(max_new_tokens),
    }

    # The Temperature slider can reach 0, which sampling rejects; treat a
    # non-positive temperature as a request for non-sampling decoding.
    use_sampling = bool(do_sample) and temperature > 0
    generation_args["do_sample"] = use_sampling
    if use_sampling:
        # Sampling-only knobs are passed only when they take effect.
        generation_args.update(
            temperature=float(temperature),
            top_p=float(top_p),
            top_k=int(top_k),
        )

    with torch.no_grad():
        generated_ids = model.generate(**inputs, **generation_args)

    output = tokenizer.decode(
        generated_ids[0],
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True
    )

    # The decoded text includes the prompt; keep what follows its
    # "### Response:" header.
    response = output.split("### Response:")[1].strip()

    return response
107
+
108
+
109
# HTML shown above the form: the project title and two badge links.
description = """
<div style="justify-content: center; text-align: center;">
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<h2>
<p> BioMed-LLaMa-3: Effecient Intruction Fine-Tuning in Biomedical Language</p>
</h2>
</div>
<div style="display: flex; justify-content: center; align-items: center; text-align: center;">
<a href="https://huggingface.co/NouRed/BioMed-Tuned-Llama-3-8b" target="_blank"><img src="https://img.shields.io/badge/🤗_Hugging_Face-BioMedLLaMa3-orange" alt="HF HUB"></a> &nbsp;&nbsp;
<a href="https://colab.research.google.com/drive/1PDa8b5TqpAYxDVlF0Elv32KOM2kFaXJh" target="_blank"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Inference Notebook"></a>
</div>
</div>
"""  # noqa: E501


# Input controls, listed in the positional order of generate()'s parameters.
_instruction_box = gr.Textbox(
    lines=2,
    label="Instruction",
    placeholder="Tell me about Covid-19?",
)
_input_box = gr.Textbox(lines=2, label="Input", placeholder="none")
_temperature_slider = gr.Slider(
    minimum=0, maximum=1, value=0.1, label="Temperature"
)
_top_p_slider = gr.Slider(minimum=0, maximum=1, value=0.9, label="Top p")
_top_k_slider = gr.Slider(
    minimum=0, maximum=100, step=1, value=40, label="Top k"
)
_beams_slider = gr.Slider(
    minimum=1, maximum=4, step=1, value=4, label="Beams"
)
_max_tokens_slider = gr.Slider(
    minimum=1, maximum=2000, step=1, value=128, label="Max tokens"
)
_do_sample_box = gr.Checkbox(
    value=True,
    label="Do Sample",
    info="Do you want to use sampling during text generation?",
)

_output_box = gr.Textbox(lines=5, label="Output")

# Every example row uses the same empty input and default control values;
# only the instruction text differs.
_example_defaults = ["", 0.1, 0.9, 40, 4, 128, True]
_example_instructions = (
    "Suggest treatment for pneumonia",
    "I have a sore throat, slight cough, tiredness. should i get tested fro covid 19?",  # noqa: E501
    "Husband of this patient asked me how to treat premature ejaculation and how to increase her libido.",  # noqa: E501
)
_examples = [
    [instruction_text, *_example_defaults]
    for instruction_text in _example_instructions
]

# Build the web form around generate() and start serving it.
demo = gr.Interface(
    fn=generate,
    inputs=[
        _instruction_box,
        _input_box,
        _temperature_slider,
        _top_p_slider,
        _top_k_slider,
        _beams_slider,
        _max_tokens_slider,
        _do_sample_box,
    ],
    outputs=[_output_box],
    examples=_examples,
    theme="soft",
    description=description,
)
demo.launch()