whan12 committed on
Commit 8ed6e93 · verified · 1 Parent(s): 56451ce

Update app.py

Files changed (1)
  1. app.py +169 -83
app.py CHANGED
@@ -4,15 +4,19 @@ import copy
  import gradio as gr
  import PIL.Image
  import torch
- from transformers import BitsAndBytesConfig, pipeline, LlavaNextProcessor, LlavaNextForConditionalGeneration
  import re
  import time
 
  DESCRIPTION = "# LLaVA 💪 - THE IRON PUMPING MACHINE VISION BEAST"
 
  model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"
 
- pipe = LlavaNextForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True)
 
  def extract_response_pairs(text):
      turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
@@ -24,72 +28,69 @@ def extract_response_pairs(text):
 
      return conv_list
 
- def add_text(history, text):
-     history = history + [[text, None]]
-     return history, text
-
- def infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
-     outputs = pipe(images=image, prompt=prompt,
-                    generate_kwargs={"temperature": temperature,
-                                     "length_penalty": length_penalty,
-                                     "repetition_penalty": repetition_penalty,
-                                     "max_length": max_length,
-                                     "min_length": min_length,
-                                     "top_p": top_p})
-     inference_output = outputs[0]["generated_text"]
-     return inference_output
-
- def arnold_speak(text):
-     # Add Arnold Schwarzenegger-style phrases and modify speech
-     arnold_phrases = [
-         "Come with me if you want to lift!",
-         "I'll be back... after my protein shake.",
-         "Hasta la vista, baby weight!",
-         "Get to da choppa... I mean, da squat rack!",
-         "You lack discipline! But don't worry, I'm here to pump you up!"
-     ]
-
-     text = text.replace(".", "!") # More enthusiastic punctuation
-     text = text.replace("gym", "iron paradise")
-     text = text.replace("exercise", "pump iron")
-     text = text.replace("workout", "sculpt your physique")
-
-     # Add random Arnold phrase to the end
-     text += " " + arnold_phrases[torch.randint(0, len(arnold_phrases), (1,)).item()]
-
-     return text
 
- def bot(history_chat, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode):
      if text_input == "":
          gr.Warning("Please input text")
 
-     if image is None:
          gr.Warning("Please input image or wait for image to be uploaded before clicking submit.")
-
-     chat_history = " ".join([item for sublist in history_chat for item in sublist]) # Flatten history
-
-     if arnold_mode:
-         system_prompt = "you are a bodybuilding coach, and you sound like Arnold Schwarzenegger. Give advice on gains, training, and inspire me at the end. Use Arnold's catchphrases and speaking style."
-     else:
-         system_prompt = "You are a helpful AI assistant. Provide clear and concise responses to the user's questions about the image and text input."
-
-     chat_history = f"{system_prompt}\n{chat_history}\nUSER: <image>\n{text_input}\nASSISTANT:"
-
-     inference_result = infer(image, chat_history, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p)
      chat_val = extract_response_pairs(inference_result)
 
      chat_state_list = copy.deepcopy(chat_val)
-     chat_state_list[-1][1] = "" # empty last response
-
-     response = chat_val[-1][1]
-     if arnold_mode:
-         response = arnold_speak(response)
 
-     for character in response:
          chat_state_list[-1][1] += character
          time.sleep(0.05)
          yield chat_state_list
 
  css = """
  #mkd {
      height: 500px;
@@ -97,52 +98,137 @@ css = """
      border: 1px solid #ccc;
  }
  """
-
- with gr.Blocks(css=css) as demo:
      gr.Markdown(DESCRIPTION)
      gr.Markdown("""## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚡️
      See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")
      chatbot = gr.Chatbot(label="Chat", show_label=False)
      gr.Markdown("Input image and text and start chatting 👇")
      with gr.Row():
-         image = gr.Image(type="pil")
-         text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
 
      history_chat = gr.State(value=[])
-     arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode", value=False)
 
      with gr.Accordion(label="Advanced settings", open=False):
-         temperature = gr.Slider(label="Temperature", info="Used with nucleus sampling.", minimum=0.5, maximum=1.0, step=0.1, value=1.0)
-         length_penalty = gr.Slider(label="Length Penalty", info="Set to larger for longer sequence, used with beam search.", minimum=-1.0, maximum=2.0, step=0.2, value=1.0)
-         repetition_penalty = gr.Slider(label="Repetition Penalty", info="Larger value prevents repetition.", minimum=1.0, maximum=5.0, step=0.5, value=1.5)
-         max_length = gr.Slider(label="Max Length", minimum=1, maximum=500, step=1, value=200)
-         min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, step=1, value=1)
-         top_p = gr.Slider(label="Top P", info="Used with nucleus sampling.", minimum=0.5, maximum=1.0, step=0.1, value=0.9)
 
-     chat_inputs = [image, text_input, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, history_chat, arnold_mode]
 
      with gr.Row():
-         clear_chat_button = gr.Button("Clear")
-         cancel_btn = gr.Button("Stop Generation")
-         chat_button = gr.Button("Submit", variant="primary")
 
-     chat_event1 = chat_button.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(
-         bot, chat_inputs, chatbot
-     )
 
-     chat_event2 = text_input.submit(add_text, [chatbot, text_input], [chatbot, text_input]).then(
-         bot, chat_inputs, chatbot
      )
 
-     clear_chat_button.click(lambda: ([], []), inputs=None, outputs=[chatbot, history_chat], queue=False, api_name="clear")
-     image.change(lambda: ([], []), inputs=None, outputs=[chatbot, history_chat], queue=False)
-     cancel_btn.click(None, [], [], cancels=[chat_event1, chat_event2])
 
-     examples = [
-         ["./examples/baklava.png", "How to make this pastry?"],
-         ["./examples/bee.png", "Describe this image."]
-     ]
-     gr.Examples(examples=examples, inputs=[image, text_input])
 
  if __name__ == "__main__":
      demo.queue(max_size=10).launch(debug=True)
 
  import gradio as gr
  import PIL.Image
  import torch
+ from transformers import BitsAndBytesConfig, pipeline,LlavaNextProcessor, LlavaNextForConditionalGeneration
+ import torch
  import re
  import time
 
  DESCRIPTION = "# LLaVA 💪 - THE IRON PUMPING MACHINE VISION BEAST"
 
  model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"
+
+
+ pipe = LlavaNextForConditionalGeneration.from_pretrained(model_id , torch_dtype=torch.float16, low_cpu_mem_usage=True)
+
 
  def extract_response_pairs(text):
      turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
 
 
      return conv_list
 
+
+ def add_text(history, text):
+     history = history.append([text, None])
+     return history, text
+
+ def infer(image, prompt,
+           temperature,
+           length_penalty,
+           repetition_penalty,
+           max_length,
+           min_length,
+           top_p):
+
+     outputs = pipe(images=image, prompt=prompt,
+                    generate_kwargs={"temperature":temperature,
+                                     "length_penalty":length_penalty,
+                                     "repetition_penalty":repetition_penalty,
+                                     "max_length":max_length,
+                                     "min_length":min_length,
+                                     "top_p":top_p})
+     inference_output = outputs[0]["generated_text"]
+     return inference_output
+
+
+
+ def bot(history_chat, text_input, image,
+         temperature,
+         length_penalty,
+         repetition_penalty,
+         max_length,
+         min_length,
+         top_p):
+
      if text_input == "":
          gr.Warning("Please input text")
 
+     if image==None:
          gr.Warning("Please input image or wait for image to be uploaded before clicking submit.")
+     chat_history = " ".join(history_chat) # history as a str to be passed to model
+     chat_history = "you are a bodybuilding coach,and you sounds like arnold schwarzenegger, give advice on my gains, training and inspire me at the end"+chat_history + f"USER: <image>\n{text_input}\nASSISTANT:" # add text input for prompting
+     inference_result = infer(image, chat_history,
+                              temperature,
+                              length_penalty,
+                              repetition_penalty,
+                              max_length,
+                              min_length,
+                              top_p)
+     # return inference and parse for new history
      chat_val = extract_response_pairs(inference_result)
 
+     # create history list for yielding the last inference response
      chat_state_list = copy.deepcopy(chat_val)
+     chat_state_list[-1][1] = "" # empty last response
 
+     # add characters iteratively
+     for character in chat_val[-1][1]:
          chat_state_list[-1][1] += character
          time.sleep(0.05)
+         # yield history but with last response being streamed
          yield chat_state_list
 
+
  css = """
  #mkd {
      height: 500px;
      border: 1px solid #ccc;
  }
  """
+ with gr.Blocks(css="style.css") as demo:
      gr.Markdown(DESCRIPTION)
      gr.Markdown("""## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚡️
      See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")
      chatbot = gr.Chatbot(label="Chat", show_label=False)
      gr.Markdown("Input image and text and start chatting 👇")
      with gr.Row():
+
+         image = gr.Image(type="pil")
+         text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
 
      history_chat = gr.State(value=[])
 
      with gr.Accordion(label="Advanced settings", open=False):
+         temperature = gr.Slider(
+             label="Temperature",
+             info="Used with nucleus sampling.",
+             minimum=0.5,
+             maximum=1.0,
+             step=0.1,
+             value=1.0,
+         )
+         length_penalty = gr.Slider(
+             label="Length Penalty",
+             info="Set to larger for longer sequence, used with beam search.",
+             minimum=-1.0,
+             maximum=2.0,
+             step=0.2,
+             value=1.0,
+         )
+         repetition_penalty = gr.Slider(
+             label="Repetition Penalty",
+             info="Larger value prevents repetition.",
+             minimum=1.0,
+             maximum=5.0,
+             step=0.5,
+             value=1.5,
+         )
+         max_length = gr.Slider(
+             label="Max Length",
+             minimum=1,
+             maximum=500,
+             step=1,
+             value=200,
+         )
+         min_length = gr.Slider(
+             label="Minimum Length",
+             minimum=1,
+             maximum=100,
+             step=1,
+             value=1,
+         )
+         top_p = gr.Slider(
+             label="Top P",
+             info="Used with nucleus sampling.",
+             minimum=0.5,
+             maximum=1.0,
+             step=0.1,
+             value=0.9,
+         )
+     chat_output = [
+         chatbot,
+         history_chat
+     ]
 
+     chat_inputs = [
+         image,
+         text_input,
+         temperature,
+         length_penalty,
+         repetition_penalty,
+         max_length,
+         min_length,
+         top_p,
+         history_chat
+     ]
      with gr.Row():
+         clear_chat_button = gr.Button("Clear")
+         cancel_btn = gr.Button("Stop Generation")
+         chat_button = gr.Button("Submit", variant="primary")
 
+     chat_event1 = chat_button.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(bot, [chatbot, text_input,
+                                                                                                        image, temperature,
+                                                                                                        length_penalty,
+                                                                                                        repetition_penalty,
+                                                                                                        max_length,
+                                                                                                        min_length,
+                                                                                                        top_p], chatbot)
 
+     chat_event2 = text_input.submit(
+         add_text,
+         [chatbot, text_input],
+         [chatbot, text_input]
+     ).then(
+         fn=bot,
+         inputs=[chatbot, text_input, image, temperature,
+                 length_penalty,
+                 repetition_penalty,
+                 max_length,
+                 min_length,
+                 top_p],
+         outputs=chatbot
+     )
+     clear_chat_button.click(
+         fn=lambda: ([], []),
+         inputs=None,
+         outputs=[
+             chatbot,
+             history_chat
+         ],
+         queue=False,
+         api_name="clear",
      )
+     image.change(
+         fn=lambda: ([], []),
+         inputs=None,
+         outputs=[
+             chatbot,
+             history_chat
+         ],
+         queue=False)
+     cancel_btn.click(
+         None, [], [],
+         cancels=[chat_event1, chat_event2]
+     )
+     examples = [["./examples/baklava.png", "How to make this pastry?"],["./examples/bee.png","Describe this image."]]
+     gr.Examples(examples=examples, inputs=[image, text_input, chat_inputs])
 
+
 
  if __name__ == "__main__":
      demo.queue(max_size=10).launch(debug=True)
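
Note: after this change `pipe` is assigned a bare `LlavaNextForConditionalGeneration` model, while `infer` still calls it as `pipe(images=..., prompt=..., generate_kwargs=...)`, which is the calling convention of a Transformers pipeline object. For reference, a minimal sketch of how such a pipeline is typically constructed for this checkpoint is shown below; the 4-bit `BitsAndBytesConfig` settings are illustrative assumptions and are not part of this commit.

# Sketch only: builds an image-to-text pipeline whose call signature matches
# infer()'s pipe(images=..., prompt=..., generate_kwargs=...) usage.
# The 4-bit quantization settings below are assumptions for illustration.
import torch
from transformers import BitsAndBytesConfig, pipeline

model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"
quantization_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
)
pipe = pipeline(
    "image-to-text",
    model=model_id,
    model_kwargs={"quantization_config": quantization_config},
)

# Example call, mirroring infer():
# outputs = pipe(images=image, prompt=prompt, generate_kwargs={"max_length": 200})
# text = outputs[0]["generated_text"]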