Chintan-Shah committed on
Commit 596d986 · verified · 1 Parent(s): ced3a58

Create app.py

Files changed (1)
  1. app.py +350 -0
app.py ADDED
@@ -0,0 +1,350 @@
import torch
import torch.nn as nn
import os
from peft import PeftModel
from PIL import Image
import gradio as gr
import librosa
import nltk

from transformers import PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import CLIPProcessor, CLIPModel
from transformers import WhisperProcessor, WhisperForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer for Phi-3.5-mini-instruct (shared by the multimodal wrapper and the base LM)
model_name = "microsoft/Phi-3.5-mini-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the CLIP model and processor (the image encoder that produces the 512-d image features)
clipmodel = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clipprocessor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# NLTK tokenizer data, downloaded at startup
nltk.download('punkt')
nltk.download('punkt_tab')
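
# The imports above assume torch, transformers, peft, gradio, librosa, nltk and pillow are
# available in the runtime environment (typically declared in the Space's requirements.txt).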

def remove_punctuation(text):
    # Keep only alphanumeric characters and whitespace, then collapse repeated whitespace
    newtext = ''.join([char for char in text if char.isalnum() or char.isspace()])
    newtext = ' '.join(newtext.split())
    return newtext

def preprocess_text(text):
    text_no_punct = remove_punctuation(text)
    return text_no_punct

# Load Whisper model and processor for speech-to-text
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)

def transcribe_speech(audiopath):
    # Load and resample the audio to 16 kHz, as expected by Whisper
    speech, rate = librosa.load(audiopath, sr=16000)
    audio_input = whisper_processor(speech, return_tensors="pt", sampling_rate=16000)
    # print("audio_input:", audio_input)

    # Generate the transcription token ids
    with torch.no_grad():
        generated_ids = whisper_model.generate(audio_input["input_features"])

    # Decode the transcription
    transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return transcription
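
# Usage sketch (illustrative; the path is the sample audio file referenced in `examples` below):
#   text = transcribe_speech("./audio/03-01-01-01-01-01-01.wav")
#   print(text)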

class ProjectionBlock(nn.Module):
    """Projects CLIP image features into the Phi embedding space."""

    def __init__(self, input_dim_CLIP, input_dim_phi2):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_dim_CLIP)
        self.proj = nn.Sequential(
            nn.Linear(input_dim_CLIP, input_dim_phi2),
            nn.GELU(),
            nn.Linear(input_dim_phi2, input_dim_phi2)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return self.proj(x)
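
# Shape sketch (illustrative): a CLIP image feature of shape (batch, 1, 512) is layer-normalized,
# then mapped through Linear(512 -> 3072), GELU and Linear(3072 -> 3072), giving (batch, 1, 3072),
# which matches the Phi-3.5-mini hidden size used below.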

# Wrapper that combines the Phi language model with the image projection layer.
# Subclassing PreTrainedModel keeps it compatible with the HuggingFace Trainer.
class MultimodalPhiModel(PreTrainedModel):

    def gradient_checkpointing_enable(self, **kwargs):
        self.phi_model.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self):
        self.phi_model.gradient_checkpointing_disable()

    def __init__(self, phi_model, tokenizer, projection):
        super().__init__(phi_model.config)
        self.phi_model = phi_model
        self.image_projection = projection
        self.tokenizer = tokenizer
        # self.device = device
        self.base_phi_model = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, debug=False, **kwargs):

        # Load the base Phi-3.5 model and its tokenizer
        model_name = "microsoft/Phi-3.5-mini-instruct"
        base_phi_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # phi_path = os.path.join(pretrained_model_name_or_path, "phi_model")
        phi_path = pretrained_model_name_or_path

        # Load the fine-tuned LoRA adapter and merge it into the base model
        model = PeftModel.from_pretrained(base_phi_model, phi_path)
        phi_model = model.merge_and_unload()

        # # Load the base Phi-3 model
        # phi_model = AutoModelForCausalLM.from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs)
        input_dim = 512
        output_dim = 3072

        # Load the projector weights
        # projector_path = os.path.join(pretrained_model_name_or_path, "projection_layer", "pytorch_model.bin")
        projector_path = os.path.join(pretrained_model_name_or_path, "image_projector.pth")
        if os.path.exists(projector_path):
            projector_state_dict = torch.load(projector_path, map_location=phi_model.device)

            projector = ProjectionBlock(input_dim, output_dim)

            # Load the state dict, ignoring mismatched keys
            projector.load_state_dict(projector_state_dict, strict=False)
            print(f"Loaded projector with input_dim={input_dim}, output_dim={output_dim}")
        else:
            print(f"Projector weights not found at {projector_path}. Initializing with default dimensions.")
            input_dim = 512  # Default CLIP embedding size
            output_dim = phi_model.config.hidden_size
            projector = ProjectionBlock(input_dim, output_dim)

        # Create and return the MultimodalPhiModel instance
        model = cls(phi_model, tokenizer, projector)
        model.base_phi_model = base_phi_model
        return model
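
    # Checkpoint layout implied by from_pretrained: the directory passed in holds the PEFT/LoRA
    # adapter for Phi-3.5-mini-instruct plus an "image_projector.pth" state dict for ProjectionBlock.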

    def save_pretrained(self, save_directory):
        # Save the merged Phi model weights
        self.phi_model.save_pretrained(save_directory)
        # model_name = "microsoft/Phi-3.5-mini-instruct"
        # base_phi_model = AutoModelForCausalLM.from_pretrained(
        #     model_name,
        #     torch_dtype=torch.bfloat16,
        #     trust_remote_code=True,
        # )
        # # Save the base model
        # model = PeftModel.from_pretrained(base_phi_model, self.phi_model)
        # model = model.merge_and_unload()
        # model.save_pretrained(save_directory)

        # Save the projector weights
        projector_path = os.path.join(save_directory, "image_projector.pth")
        torch.save(self.image_projection.state_dict(), projector_path)

        # Save the config
        self.config.save_pretrained(save_directory)

    def encode(self, image_features):
        # Project CLIP image features into the Phi embedding space
        image_projections = self.image_projection(image_features)
        return image_projections

    def forward(self, start_input_ids, end_input_ids, image_features, attention_mask, labels):
        # print("tokenizer bos_token_id", self.tokenizer.bos_token_id, "tokenizer eos_token", self.tokenizer.eos_token,
        #       "tokenizer pad_token_id", self.tokenizer.pad_token_id, "tokenizer sep_token_id", self.tokenizer.sep_token_id,
        #       "tokenizer cls_token_id", self.tokenizer.cls_token_id, "tokenizer mask_token_id", self.tokenizer.mask_token_id,
        #       "tokenizer unk_token_id", self.tokenizer.unk_token_id)
        device = next(self.parameters()).device

        # Embed the prompt tokens that come before and after the image embedding
        start_embeds = self.phi_model.get_input_embeddings()(start_input_ids.to(device))
        end_embeds = self.phi_model.get_input_embeddings()(end_input_ids.to(device))
        if image_features is not None:
            # Project the image features and splice them between the start and end embeddings
            image_embeddings = self.encode(image_features.to(device)).bfloat16()
            input_embeds = torch.cat([start_embeds, image_embeddings, end_embeds], dim=1)
        else:
            input_embeds = torch.cat([start_embeds, end_embeds], dim=1)
        # print("input_embeds shape:", input_embeds.shape, "attention_mask shape:", attention_mask.shape)

        # Forward pass through the language model
        outputs = self.phi_model(inputs_embeds=input_embeds.to(device),
                                 attention_mask=attention_mask.to(device),
                                 labels=labels,
                                 return_dict=True)

        return outputs
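
    # Note on the embedding layout built in forward: the sequence is
    # [start prompt embeddings | projected image embeddings (if any) | end prompt embeddings],
    # and attention_mask must cover that full length (see getInputs below).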

def getImageArray(image_path):
    image = Image.open(image_path)
    return image

def getAudioArray(audio_path):
    speech, rate = librosa.load(audio_path, sr=16000)
    return speech
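
# Note: getImageArray and getAudioArray are small standalone helpers and are not referenced
# elsewhere in this script; the Gradio callbacks pass file paths directly instead.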

def getInputs(image_path, question, answer=""):
    # NOTE: `question` is not used here; the prompt below is fixed, and the user's question is
    # handled later in generateOutput.

    image_features = None
    speech_text = ""
    num_image_tokens = 0

    if image_path is not None:
        # print("type of image:", type(image_path))
        # print("image path:", image_path)
        image = clipprocessor(images=Image.open(image_path), return_tensors="pt")

        # Generate the CLIP image embedding
        image_features = clipmodel.get_image_features(**image)

        # image_features = get_clip_embeddings(image)
        image_features = torch.stack([image_features])
        num_image_tokens = image_features.shape[1]

    # Prompt text placed before the image embedding
    start_text = "<|system|>\nYou are an assistant good at understanding the objects and their relationship from the context.<|end|>\n<|user|>\n"

    # Prompt text placed after the image embedding (for causal language modeling)
    end_text = f"\nPlease describe the objects and their relationship from the context.<|end|>\n<|assistant|>\n{answer}"

    # Tokenize both prompt segments
    start_tokens = tokenizer(start_text, padding=True, truncation=True, max_length=512, return_tensors="pt")
    end_tokens = tokenizer(end_text, padding=True, truncation=True, max_length=512, return_tensors="pt")

    start_input_ids = start_tokens['input_ids']
    start_attention_mask = start_tokens['attention_mask']
    end_input_ids = end_tokens['input_ids']
    end_attention_mask = end_tokens['attention_mask']

    # Extend the attention mask to cover the image tokens spliced in by the model
    if image_path is not None:
        attention_mask = torch.cat([start_attention_mask, torch.ones((1, num_image_tokens), dtype=torch.long), end_attention_mask], dim=1)
    else:
        attention_mask = torch.cat([start_attention_mask, end_attention_mask], dim=1)

    return start_input_ids, end_input_ids, image_features, attention_mask
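
# getInputs returns (start_input_ids, end_input_ids, image_features, attention_mask); the projected
# image embedding is inserted between the start and end segments inside MultimodalPhiModel.forward.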

# Directory containing the fine-tuned checkpoint (LoRA adapter plus image projector)
model_location = "./MM_FT_C1_V2"
print("Model location:", model_location)

model = MultimodalPhiModel.from_pretrained(model_location).to(device)

import re

def getStringAfter(output, start_str):
    # Return the text after start_str (or the whole output if the marker is absent), cleaned of punctuation
    if start_str in output:
        answer = output.split(start_str)[1]
    else:
        answer = output

    answer = preprocess_text(answer)
    return answer


def getStringAfterAnswer(output):
    # Same as getStringAfter, but hard-coded to the "<|assistant|>" marker
    if "<|assistant|>" in output:
        answer = output.split("<|assistant|>")[1]
    else:
        answer = output

    answer = preprocess_text(answer)
    return answer

def generateOutput(image_path, audio_path, context_text, question, max_length=5):
    answerPart = ""
    speech_text = ""

    # Stage 1: if an image is given, iteratively grow an image description with the multimodal model
    # (max_length controls the number of refinement passes)
    if image_path is not None:
        for i in range(max_length):
            start_tokens, end_tokens, image_features, attention_mask = getInputs(image_path, question, answer=answerPart)
            # print("image_features dtype:", image_features.dtype)
            output = model(start_tokens, end_tokens, image_features, attention_mask, labels=None)
            tokens = output.logits.argmax(dim=-1)
            output = tokenizer.decode(
                tokens[0],
                skip_special_tokens=True
            )
            answerPart = getStringAfter(output, "<|assistant|>")
            print("Answerpart:", answerPart)

    # Stage 2: if audio is given, transcribe it with Whisper
    if audio_path is not None:
        speech_text = transcribe_speech(audio_path)
        print("Speech Text:", speech_text)

    if (question is None) or (question == ""):
        question = "Provide only in 1 sentence to describe the objects and their relationships in it."

    # Stage 3: combine the image description, the transcript and the context text into one prompt
    # and answer the question with the base Phi-3.5 model
    input_text = (
        "<|system|>\nPlease understand the context "
        "and answer the question based on the context in 1 or 2 summarized sentences.\n"
        f"<|end|>\n<|user|>\n<|context|>{answerPart}\n{speech_text}\n{context_text}"
        f"\n<|question|>: {question}\n<|end|>\n<|assistant|>\n"
    )
    print("input_text:", input_text)
    start_tokens = tokenizer(input_text, padding=True, truncation=True, max_length=1024, return_tensors="pt")['input_ids'].to(device)
    # base_phi_model.generate(start_tokens, max_length=2, do_sample=False, pad_token_id=tokenizer.pad_token_id)

    output_text = tokenizer.decode(
        model.base_phi_model.generate(start_tokens, max_length=1024, do_sample=False, pad_token_id=tokenizer.pad_token_id)[0],
        skip_special_tokens=True
    )

    output_text = getStringAfter(output_text, question).strip()
    return output_text
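
# Usage sketch (illustrative; the image path and question are the sample values from `examples` below):
#   answer = generateOutput("./images/COCO_train2014_000000581181.jpg", None, "",
#                           "Describe what is happening in this image.")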

title = "Created Fine Tuned MultiModal model"
description = "Test the fine tuned multimodal model created using clip, phi3.5 mini instruct, whisper models"
examples = [
    ["./images/COCO_train2014_000000581181.jpg", None, None, None, None, "Describe what is happening in this image."],
    [None, "Audio File", "./audio/03-01-01-01-01-01-01.wav", None, None, "Describe what is the person trying to tell in this audio."],
]

# [None, "Microphone", None, "example_audio_mic.wav", "Context without image.", "What is the result?"],

demo = gr.Blocks()
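
# Note: title, description and examples are defined above but are not attached to the Blocks UI;
# they would need to be wired in explicitly (for example via gr.Examples) to appear in the interface.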

def process_inputs(image, audio_source, audio_file, audio_mic, context_text, question):
    # Pick the audio path that matches the selected source
    if audio_source == "Microphone":
        speech = audio_mic
    elif audio_source == "Audio File":
        speech = audio_file
    else:
        speech = None

    # image_features = get_clip_embeddings(image) if image else None
    answer = generateOutput(image, speech, context_text, question)

    return answer
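
# Both audio widgets below use type="filepath", so `speech` is always a filesystem path (or None),
# which is what transcribe_speech expects when it calls librosa.load.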

with demo:
    with gr.Row():
        audio_source = gr.Radio(choices=["Microphone", "Audio File"], label="Select Audio Source")
        audio_file = gr.Audio(sources="upload", type="filepath", visible=False)
        audio_mic = gr.Audio(sources="microphone", type="filepath", visible=False)
        image_input = gr.Image(type="filepath", label="Upload Image")
        context_text = gr.Textbox(label="Context Text")
        question = gr.Textbox(label="Question")
        output_text = gr.Textbox(label="Output")

    # Show only the audio widget that matches the selected source
    def update_audio_input(source):
        if source == "Microphone":
            return gr.update(visible=True), gr.update(visible=False)
        elif source == "Audio File":
            return gr.update(visible=False), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    audio_source.change(fn=update_audio_input, inputs=audio_source, outputs=[audio_mic, audio_file])
    submit_button = gr.Button("Submit")
    submit_button.click(fn=process_inputs, inputs=[image_input, audio_source, audio_file, audio_mic, context_text, question], outputs=output_text)

demo.launch(debug=True)