zqu2004 committed on
Commit
c9406a3
·
verified ·
1 Parent(s): de2bf9d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +127 -59
app.py CHANGED
@@ -1,63 +1,131 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
- """
5
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
6
- """
7
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
8
-
9
-
10
- def respond(
11
- message,
12
- history: list[tuple[str, str]],
13
- system_message,
14
- max_tokens,
15
- temperature,
16
- top_p,
17
- ):
18
- messages = [{"role": "system", "content": system_message}]
19
-
20
- for val in history:
21
- if val[0]:
22
- messages.append({"role": "user", "content": val[0]})
23
- if val[1]:
24
- messages.append({"role": "assistant", "content": val[1]})
25
-
26
- messages.append({"role": "user", "content": message})
27
-
28
- response = ""
29
-
30
- for message in client.chat_completion(
31
- messages,
32
- max_tokens=max_tokens,
33
- stream=True,
34
- temperature=temperature,
35
- top_p=top_p,
36
- ):
37
- token = message.choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
- """
43
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
44
- """
45
- demo = gr.ChatInterface(
46
- respond,
47
- additional_inputs=[
48
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
49
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
50
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
51
- gr.Slider(
52
- minimum=0.1,
53
- maximum=1.0,
54
- value=0.95,
55
- step=0.05,
56
- label="Top-p (nucleus sampling)",
57
- ),
58
- ],
59
  )
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
 
62
- if __name__ == "__main__":
63
- demo.launch()
 
1
  import gradio as gr
2
+ from transformers import AutoModelForCausalLM, AutoProcessor, GenerationConfig
3
+ from PIL import Image
4
+ import torch
5
+ import spaces
6
+
7
# Toggle for GPU execution; the model stays on CPU unless this is enabled
# AND CUDA is actually available at runtime.
USE_GPU = False

device = torch.device("cuda" if USE_GPU and torch.cuda.is_available() else "cpu")

# Load the Molmo processor and model. `trust_remote_code=True` is required
# because Molmo ships custom modeling/processing code with the checkpoint;
# `torch_dtype='auto'` defers the dtype choice to the checkpoint's config.
processor = AutoProcessor.from_pretrained(
    'allenai/Molmo-7B-D-0924',
    trust_remote_code=True,
    torch_dtype='auto',
)

model = AutoModelForCausalLM.from_pretrained(
    'allenai/Molmo-7B-D-0924',
    trust_remote_code=True,
    torch_dtype='auto',
)
model.to(device)
26
+
27
# Canned prompt suggestions surfaced in the dropdown of the UI below.
prompts = [
    "Describe this image in detail",
    "What objects can you see in this image?",
    "What's the main subject of this image?",
    "Describe the colors in this image",
    "What emotions does this image evoke?",
]
35
+
36
def process_image_and_text(image, text, max_new_tokens, temperature, top_p):
    """Run one round of Molmo inference on an image/question pair.

    Args:
        image: numpy array as produced by ``gr.Image(type="numpy")``.
        text: the user's question / prompt.
        max_new_tokens: cap on generated tokens.
        temperature: sampling temperature from the UI slider.
        top_p: nucleus-sampling threshold from the UI slider.

    Returns:
        The decoded model answer as a plain string.
    """
    # Tokenize the prompt and preprocess the image with Molmo's processor.
    inputs = processor.process(
        images=[Image.fromarray(image)],
        text=text
    )

    # Move inputs to the model's device and add a batch dimension (batch of 1).
    inputs = {k: v.to(device).unsqueeze(0) for k, v in inputs.items()}

    output = model.generate_from_batch(
        inputs,
        GenerationConfig(
            max_new_tokens=max_new_tokens,
            # FIX: without do_sample=True, HF generation is greedy and the
            # temperature/top_p values are silently ignored, so the UI
            # sliders had no effect.
            do_sample=True,
            temperature=temperature,
            top_p=top_p,
            stop_strings="<|endoftext|>"
        ),
        tokenizer=processor.tokenizer
    )

    # Keep only the newly generated tokens (drop the prompt) and decode.
    generated_tokens = output[0, inputs['input_ids'].size(1):]
    generated_text = processor.tokenizer.decode(generated_tokens, skip_special_tokens=True)

    return generated_text
63
+
64
def chatbot(image, text, history, max_new_tokens, temperature, top_p):
    """Handle one chat turn: validate the image, run the model, extend history.

    Args:
        image: numpy image (or None when the user has not uploaded one).
        text: the user's message.
        history: list of (user, bot) tuples backing the Chatbot component.
        max_new_tokens, temperature, top_p: generation controls from the UI.

    Returns:
        The updated history list of (user, bot) tuples.
    """
    if image is None:
        # FIX: the original returned ("Please upload an image first.", None),
        # which rendered the error as a *user* bubble with no bot reply.
        # Keep the user's text on the user side and surface the error as the
        # bot's response instead.
        return history + [(text, "Please upload an image first.")]

    response = process_image_and_text(image, text, max_new_tokens, temperature, top_p)
    history.append((text, response))
    return history
71
+
72
def update_textbox(prompt):
    """Copy the selected premade prompt into the free-text question box."""
    return gr.update(value=prompt)
74
+
75
# Assemble the Gradio UI.
with gr.Blocks() as demo:
    gr.Markdown("# Image Chatbot with Molmo-7B-D-0924")

    # Image upload on the left, rendered conversation on the right.
    with gr.Row():
        img_input = gr.Image(type="numpy")
        chat_display = gr.Chatbot()

    # Free-text question plus a dropdown of canned prompts.
    with gr.Row():
        question_box = gr.Textbox(placeholder="Ask a question about the image...")
        preset_dropdown = gr.Dropdown(choices=[""] + prompts, label="Select a premade prompt", value="")

    submit_btn = gr.Button("Submit")
    clear_btn = gr.ClearButton([question_box, chat_display])

    # Generation hyper-parameters, collapsed by default.
    with gr.Accordion("Advanced options", open=False):
        max_tokens_slider = gr.Slider(minimum=1, maximum=500, value=200, step=1, label="Max new tokens")
        temp_slider = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
        top_p_slider = gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")

    chat_state = gr.State([])

    # Un-rendered copy of the latest answer, plus a copy button.
    with gr.Row():
        raw_box = gr.Textbox(label="Raw Output", interactive=False)
        copy_btn = gr.Button("Copy Raw Output")

    def _latest_answer(history):
        # Mirror the most recent bot reply into the raw-output textbox.
        if history:
            return history[-1][1]
        return ""

    # Both the button click and pressing Enter in the textbox fire the same
    # pipeline: run the chatbot, then refresh the raw-output mirror.
    shared_inputs = [img_input, question_box, chat_state, max_tokens_slider, temp_slider, top_p_slider]

    submit_btn.click(
        chatbot,
        inputs=shared_inputs,
        outputs=[chat_display]
    ).then(
        _latest_answer,
        inputs=[chat_display],
        outputs=[raw_box]
    )

    question_box.submit(
        chatbot,
        inputs=shared_inputs,
        outputs=[chat_display]
    ).then(
        _latest_answer,
        inputs=[chat_display],
        outputs=[raw_box]
    )

    # Selecting a preset overwrites the question box with that prompt.
    preset_dropdown.change(update_textbox, inputs=[preset_dropdown], outputs=[question_box])

    # NOTE(review): this routes the "copy" action into a hidden textbox; it
    # does not touch the system clipboard — behavior kept as-is.
    copy_btn.click(lambda x: gr.update(value=x), inputs=[raw_box], outputs=[gr.Textbox(visible=False)])

demo.launch()