whan12 committed on
Commit
56451ce
·
verified ·
1 Parent(s): 872e236

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -169
app.py CHANGED
@@ -4,19 +4,15 @@ import copy
4
  import gradio as gr
5
  import PIL.Image
6
  import torch
7
- from transformers import BitsAndBytesConfig, pipeline,LlavaNextProcessor, LlavaNextForConditionalGeneration
8
- import torch
9
  import re
10
  import time
11
 
12
  DESCRIPTION = "# LLaVA 💪 - THE IRON PUMPING MACHINE VISION BEAST"
13
 
14
  model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"
15
-
16
-
17
- pipe = LlavaNextForConditionalGeneration.from_pretrained(model_id , torch_dtype=torch.float16, low_cpu_mem_usage=True)
18
-
19
 
 
20
 
21
  def extract_response_pairs(text):
22
  turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
@@ -28,69 +24,72 @@ def extract_response_pairs(text):
28
 
29
  return conv_list
30
 
31
-
32
-
33
  def add_text(history, text):
34
- history = history.append([text, None])
35
- return history, text
36
-
37
- def infer(image, prompt,
38
- temperature,
39
- length_penalty,
40
- repetition_penalty,
41
- max_length,
42
- min_length,
43
- top_p):
44
-
45
- outputs = pipe(images=image, prompt=prompt,
46
- generate_kwargs={"temperature":temperature,
47
- "length_penalty":length_penalty,
48
- "repetition_penalty":repetition_penalty,
49
- "max_length":max_length,
50
- "min_length":min_length,
51
- "top_p":top_p})
52
- inference_output = outputs[0]["generated_text"]
53
- return inference_output
54
-
55
-
56
-
57
- def bot(history_chat, text_input, image,
58
- temperature,
59
- length_penalty,
60
- repetition_penalty,
61
- max_length,
62
- min_length,
63
- top_p):
64
-
 
 
 
 
65
  if text_input == "":
66
  gr.Warning("Please input text")
67
 
68
- if image==None:
69
  gr.Warning("Please input image or wait for image to be uploaded before clicking submit.")
70
- chat_history = " ".join(history_chat) # history as a str to be passed to model
71
- chat_history = "you are a bodybuilding coach,and you sounds like arnold schwarzenegger, give advice on my gains, training and inspire me at the end"+chat_history + f"USER: <image>\n{text_input}\nASSISTANT:" # add text input for prompting
72
- inference_result = infer(image, chat_history,
73
- temperature,
74
- length_penalty,
75
- repetition_penalty,
76
- max_length,
77
- min_length,
78
- top_p)
79
- # return inference and parse for new history
 
80
  chat_val = extract_response_pairs(inference_result)
81
 
82
- # create history list for yielding the last inference response
83
  chat_state_list = copy.deepcopy(chat_val)
84
- chat_state_list[-1][1] = "" # empty last response
 
 
 
 
85
 
86
- # add characters iteratively
87
- for character in chat_val[-1][1]:
88
  chat_state_list[-1][1] += character
89
  time.sleep(0.05)
90
- # yield history but with last response being streamed
91
  yield chat_state_list
92
 
93
-
94
  css = """
95
  #mkd {
96
  height: 500px;
@@ -98,137 +97,52 @@ css = """
98
  border: 1px solid #ccc;
99
  }
100
  """
101
- with gr.Blocks(css="style.css") as demo:
 
102
  gr.Markdown(DESCRIPTION)
103
  gr.Markdown("""## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚡️
104
  See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")
105
  chatbot = gr.Chatbot(label="Chat", show_label=False)
106
  gr.Markdown("Input image and text and start chatting 👇")
107
  with gr.Row():
108
-
109
- image = gr.Image(type="pil")
110
- text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
111
 
112
  history_chat = gr.State(value=[])
 
113
 
114
  with gr.Accordion(label="Advanced settings", open=False):
115
- temperature = gr.Slider(
116
- label="Temperature",
117
- info="Used with nucleus sampling.",
118
- minimum=0.5,
119
- maximum=1.0,
120
- step=0.1,
121
- value=1.0,
122
- )
123
- length_penalty = gr.Slider(
124
- label="Length Penalty",
125
- info="Set to larger for longer sequence, used with beam search.",
126
- minimum=-1.0,
127
- maximum=2.0,
128
- step=0.2,
129
- value=1.0,
130
- )
131
- repetition_penalty = gr.Slider(
132
- label="Repetition Penalty",
133
- info="Larger value prevents repetition.",
134
- minimum=1.0,
135
- maximum=5.0,
136
- step=0.5,
137
- value=1.5,
138
- )
139
- max_length = gr.Slider(
140
- label="Max Length",
141
- minimum=1,
142
- maximum=500,
143
- step=1,
144
- value=200,
145
- )
146
- min_length = gr.Slider(
147
- label="Minimum Length",
148
- minimum=1,
149
- maximum=100,
150
- step=1,
151
- value=1,
152
- )
153
- top_p = gr.Slider(
154
- label="Top P",
155
- info="Used with nucleus sampling.",
156
- minimum=0.5,
157
- maximum=1.0,
158
- step=0.1,
159
- value=0.9,
160
- )
161
- chat_output = [
162
- chatbot,
163
- history_chat
164
- ]
165
 
 
166
 
167
- chat_inputs = [
168
- image,
169
- text_input,
170
- temperature,
171
- length_penalty,
172
- repetition_penalty,
173
- max_length,
174
- min_length,
175
- top_p,
176
- history_chat
177
- ]
178
  with gr.Row():
179
- clear_chat_button = gr.Button("Clear")
180
- cancel_btn = gr.Button("Stop Generation")
181
- chat_button = gr.Button("Submit", variant="primary")
182
 
183
- chat_event1 = chat_button.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(bot, [chatbot, text_input,
184
- image, temperature,
185
- length_penalty,
186
- repetition_penalty,
187
- max_length,
188
- min_length,
189
- top_p], chatbot)
190
-
191
- chat_event2 = text_input.submit(
192
- add_text,
193
- [chatbot, text_input],
194
- [chatbot, text_input]
195
- ).then(
196
- fn=bot,
197
- inputs=[chatbot, text_input, image, temperature,
198
- length_penalty,
199
- repetition_penalty,
200
- max_length,
201
- min_length,
202
- top_p],
203
- outputs=chatbot
204
  )
205
- clear_chat_button.click(
206
- fn=lambda: ([], []),
207
- inputs=None,
208
- outputs=[
209
- chatbot,
210
- history_chat
211
- ],
212
- queue=False,
213
- api_name="clear",
214
  )
215
- image.change(
216
- fn=lambda: ([], []),
217
- inputs=None,
218
- outputs=[
219
- chatbot,
220
- history_chat
221
- ],
222
- queue=False)
223
- cancel_btn.click(
224
- None, [], [],
225
- cancels=[chat_event1, chat_event2]
226
- )
227
- examples = [["./examples/baklava.png", "How to make this pastry?"],["./examples/bee.png","Describe this image."]]
228
- gr.Examples(examples=examples, inputs=[image, text_input, chat_inputs])
229
 
 
 
 
230
 
231
-
 
 
 
 
232
 
233
  if __name__ == "__main__":
234
  demo.queue(max_size=10).launch(debug=True)
 
4
  import gradio as gr
5
  import PIL.Image
6
  import torch
7
+ from transformers import BitsAndBytesConfig, pipeline, LlavaNextProcessor, LlavaNextForConditionalGeneration
 
8
  import re
9
  import time
10
 
11
  DESCRIPTION = "# LLaVA 💪 - THE IRON PUMPING MACHINE VISION BEAST"
12
 
13
  model_id = "llava-hf/llava-v1.6-vicuna-7b-hf"
 
 
 
 
14
 
15
+ pipe = LlavaNextForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.float16, low_cpu_mem_usage=True)
16
 
17
  def extract_response_pairs(text):
18
  turns = re.split(r'(USER:|ASSISTANT:)', text)[1:]
 
24
 
25
  return conv_list
26
 
 
 
27
  def add_text(history, text):
28
+ history = history + [[text, None]]
29
+ return history, text
30
+
31
+ def infer(image, prompt, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p):
32
+ outputs = pipe(images=image, prompt=prompt,
33
+ generate_kwargs={"temperature": temperature,
34
+ "length_penalty": length_penalty,
35
+ "repetition_penalty": repetition_penalty,
36
+ "max_length": max_length,
37
+ "min_length": min_length,
38
+ "top_p": top_p})
39
+ inference_output = outputs[0]["generated_text"]
40
+ return inference_output
41
+
42
+ def arnold_speak(text):
43
+ # Add Arnold Schwarzenegger-style phrases and modify speech
44
+ arnold_phrases = [
45
+ "Come with me if you want to lift!",
46
+ "I'll be back... after my protein shake.",
47
+ "Hasta la vista, baby weight!",
48
+ "Get to da choppa... I mean, da squat rack!",
49
+ "You lack discipline! But don't worry, I'm here to pump you up!"
50
+ ]
51
+
52
+ text = text.replace(".", "!") # More enthusiastic punctuation
53
+ text = text.replace("gym", "iron paradise")
54
+ text = text.replace("exercise", "pump iron")
55
+ text = text.replace("workout", "sculpt your physique")
56
+
57
+ # Add random Arnold phrase to the end
58
+ text += " " + arnold_phrases[torch.randint(0, len(arnold_phrases), (1,)).item()]
59
+
60
+ return text
61
+
62
+ def bot(history_chat, text_input, image, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, arnold_mode):
63
  if text_input == "":
64
  gr.Warning("Please input text")
65
 
66
+ if image is None:
67
  gr.Warning("Please input image or wait for image to be uploaded before clicking submit.")
68
+
69
+ chat_history = " ".join([item for sublist in history_chat for item in sublist]) # Flatten history
70
+
71
+ if arnold_mode:
72
+ system_prompt = "you are a bodybuilding coach, and you sound like Arnold Schwarzenegger. Give advice on gains, training, and inspire me at the end. Use Arnold's catchphrases and speaking style."
73
+ else:
74
+ system_prompt = "You are a helpful AI assistant. Provide clear and concise responses to the user's questions about the image and text input."
75
+
76
+ chat_history = f"{system_prompt}\n{chat_history}\nUSER: <image>\n{text_input}\nASSISTANT:"
77
+
78
+ inference_result = infer(image, chat_history, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p)
79
  chat_val = extract_response_pairs(inference_result)
80
 
 
81
  chat_state_list = copy.deepcopy(chat_val)
82
+ chat_state_list[-1][1] = "" # empty last response
83
+
84
+ response = chat_val[-1][1]
85
+ if arnold_mode:
86
+ response = arnold_speak(response)
87
 
88
+ for character in response:
 
89
  chat_state_list[-1][1] += character
90
  time.sleep(0.05)
 
91
  yield chat_state_list
92
 
 
93
  css = """
94
  #mkd {
95
  height: 500px;
 
97
  border: 1px solid #ccc;
98
  }
99
  """
100
+
101
+ with gr.Blocks(css=css) as demo:
102
  gr.Markdown(DESCRIPTION)
103
  gr.Markdown("""## LLaVA, one of the greatest multimodal chat models is now available in Transformers with 4-bit quantization! ⚡️
104
  See the docs here: https://huggingface.co/docs/transformers/main/en/model_doc/llava.""")
105
  chatbot = gr.Chatbot(label="Chat", show_label=False)
106
  gr.Markdown("Input image and text and start chatting 👇")
107
  with gr.Row():
108
+ image = gr.Image(type="pil")
109
+ text_input = gr.Text(label="Chat Input", show_label=False, max_lines=3, container=False)
 
110
 
111
  history_chat = gr.State(value=[])
112
+ arnold_mode = gr.Checkbox(label="Arnold Schwarzenegger Mode", value=False)
113
 
114
  with gr.Accordion(label="Advanced settings", open=False):
115
+ temperature = gr.Slider(label="Temperature", info="Used with nucleus sampling.", minimum=0.5, maximum=1.0, step=0.1, value=1.0)
116
+ length_penalty = gr.Slider(label="Length Penalty", info="Set to larger for longer sequence, used with beam search.", minimum=-1.0, maximum=2.0, step=0.2, value=1.0)
117
+ repetition_penalty = gr.Slider(label="Repetition Penalty", info="Larger value prevents repetition.", minimum=1.0, maximum=5.0, step=0.5, value=1.5)
118
+ max_length = gr.Slider(label="Max Length", minimum=1, maximum=500, step=1, value=200)
119
+ min_length = gr.Slider(label="Minimum Length", minimum=1, maximum=100, step=1, value=1)
120
+ top_p = gr.Slider(label="Top P", info="Used with nucleus sampling.", minimum=0.5, maximum=1.0, step=0.1, value=0.9)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
+ chat_inputs = [image, text_input, temperature, length_penalty, repetition_penalty, max_length, min_length, top_p, history_chat, arnold_mode]
123
 
 
 
 
 
 
 
 
 
 
 
 
124
  with gr.Row():
125
+ clear_chat_button = gr.Button("Clear")
126
+ cancel_btn = gr.Button("Stop Generation")
127
+ chat_button = gr.Button("Submit", variant="primary")
128
 
129
+ chat_event1 = chat_button.click(add_text, [chatbot, text_input], [chatbot, text_input]).then(
130
+ bot, chat_inputs, chatbot
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
131
  )
132
+
133
+ chat_event2 = text_input.submit(add_text, [chatbot, text_input], [chatbot, text_input]).then(
134
+ bot, chat_inputs, chatbot
 
 
 
 
 
 
135
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
136
 
137
+ clear_chat_button.click(lambda: ([], []), inputs=None, outputs=[chatbot, history_chat], queue=False, api_name="clear")
138
+ image.change(lambda: ([], []), inputs=None, outputs=[chatbot, history_chat], queue=False)
139
+ cancel_btn.click(None, [], [], cancels=[chat_event1, chat_event2])
140
 
141
+ examples = [
142
+ ["./examples/baklava.png", "How to make this pastry?"],
143
+ ["./examples/bee.png", "Describe this image."]
144
+ ]
145
+ gr.Examples(examples=examples, inputs=[image, text_input])
146
 
147
  if __name__ == "__main__":
148
  demo.queue(max_size=10).launch(debug=True)