Create app.py
app.py
ADDED
@@ -0,0 +1,350 @@
import torch
import torch.nn as nn
import os
import re
from peft import PeftModel
from PIL import Image
import gradio as gr
import librosa
import nltk

from transformers import PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers import CLIPProcessor, CLIPModel
from transformers import WhisperProcessor, WhisperForConditionalGeneration

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Tokenizer for the Phi-3.5 language model
model_name = "microsoft/Phi-3.5-mini-instruct"
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# Load the CLIP model and processor used for image embeddings
clipmodel = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clipprocessor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# NLTK tokenizer data (downloaded up front)
nltk.download('punkt')
nltk.download('punkt_tab')

def remove_punctuation(text):
    newtext = ''.join([char for char in text if char.isalnum() or char.isspace()])
    newtext = ' '.join(newtext.split())
    return newtext

def preprocess_text(text):
    text_no_punct = remove_punctuation(text)
    return text_no_punct

# Load Whisper model and processor
whisper_model_name = "openai/whisper-small"
whisper_processor = WhisperProcessor.from_pretrained(whisper_model_name)
whisper_model = WhisperForConditionalGeneration.from_pretrained(whisper_model_name)

def transcribe_speech(audiopath):
    # Load the audio and resample it to the 16 kHz rate Whisper expects
    speech, rate = librosa.load(audiopath, sr=16000)
    audio_input = whisper_processor(speech, return_tensors="pt", sampling_rate=16000)

    # Generate the transcription
    with torch.no_grad():
        generated_ids = whisper_model.generate(audio_input["input_features"])

    # Decode the transcription
    transcription = whisper_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    return transcription

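# Illustrative usage of transcribe_speech (an assumption for local testing, not called at
# import time): pass a path to any audio file librosa can read, e.g. the bundled sample:
#   text = transcribe_speech("./audio/03-01-01-01-01-01-01.wav")
#   print(text)
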
class ProjectionBlock(nn.Module):
    def __init__(self, input_dim_CLIP, input_dim_phi2):
        super().__init__()
        self.pre_norm = nn.LayerNorm(input_dim_CLIP)
        self.proj = nn.Sequential(
            nn.Linear(input_dim_CLIP, input_dim_phi2),
            nn.GELU(),
            nn.Linear(input_dim_phi2, input_dim_phi2)
        )

    def forward(self, x):
        x = self.pre_norm(x)
        return self.proj(x)

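# Shape sketch for ProjectionBlock (illustrative, assuming the 512 -> 3072 dimensions used in
# MultimodalPhiModel.from_pretrained below): a CLIP image embedding of shape (1, 1, 512) is
# layer-normalised and projected to (1, 1, 3072), Phi-3.5-mini's hidden size.
#   _projector = ProjectionBlock(512, 3072)
#   _out = _projector(torch.randn(1, 1, 512))   # _out.shape == torch.Size([1, 1, 3072])
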
# Wrap the Phi model, the tokenizer and the image projector so the combined model
# can be used with the HuggingFace Trainer
class MultimodalPhiModel(PreTrainedModel):

    def gradient_checkpointing_enable(self, **kwargs):
        self.phi_model.gradient_checkpointing_enable(**kwargs)

    def gradient_checkpointing_disable(self):
        self.phi_model.gradient_checkpointing_disable()

    def __init__(self, phi_model, tokenizer, projection):
        super().__init__(phi_model.config)
        self.phi_model = phi_model
        self.image_projection = projection
        self.tokenizer = tokenizer
        self.base_phi_model = None

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, debug=False, **kwargs):
        # Load the base Phi-3.5 model and tokenizer
        model_name = "microsoft/Phi-3.5-mini-instruct"
        base_phi_model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
        tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

        # The checkpoint directory holds the fine-tuned LoRA adapter
        phi_path = pretrained_model_name_or_path

        # Load the adapter and merge it into the base model
        model = PeftModel.from_pretrained(base_phi_model, phi_path)
        phi_model = model.merge_and_unload()

        # Projector dimensions: CLIP image-embedding size -> Phi-3.5 hidden size
        input_dim = 512
        output_dim = 3072

        # Load the projector weights
        projector_path = os.path.join(pretrained_model_name_or_path, "image_projector.pth")
        if os.path.exists(projector_path):
            projector_state_dict = torch.load(projector_path, map_location=phi_model.device)
            projector = ProjectionBlock(input_dim, output_dim)
            # Try to load the state dict, ignoring mismatched keys
            projector.load_state_dict(projector_state_dict, strict=False)
            print(f"Loaded projector with input_dim={input_dim}, output_dim={output_dim}")
        else:
            print(f"Projector weights not found at {projector_path}. Initializing with default dimensions.")
            input_dim = 512  # Default CLIP embedding size
            output_dim = phi_model.config.hidden_size
            projector = ProjectionBlock(input_dim, output_dim)

        # Create the wrapped model and keep a handle on the Phi model used for plain text generation
        model = cls(phi_model, tokenizer, projector)
        model.base_phi_model = base_phi_model
        return model

    def save_pretrained(self, save_directory):
        # Save the merged Phi model weights
        self.phi_model.save_pretrained(save_directory)

        # Save the projector weights
        projector_path = os.path.join(save_directory, "image_projector.pth")
        torch.save(self.image_projection.state_dict(), projector_path)

        # Save the config
        self.config.save_pretrained(save_directory)

    def encode(self, image_features):
        # Project CLIP image features into the Phi embedding space
        image_projections = self.image_projection(image_features)
        return image_projections

    def forward(self, start_input_ids, end_input_ids, image_features, attention_mask, labels):
        device = next(self.parameters()).device

        start_embeds = self.phi_model.get_input_embeddings()(start_input_ids.to(device))
        end_embeds = self.phi_model.get_input_embeddings()(end_input_ids.to(device))

        if image_features is not None:
            # Project the image features and splice them between the prompt prefix and suffix
            image_embeddings = self.encode(image_features.to(device)).bfloat16()
            input_embeds = torch.cat([start_embeds, image_embeddings, end_embeds], dim=1)
        else:
            input_embeds = torch.cat([start_embeds, end_embeds], dim=1)

        # Forward pass through the language model
        outputs = self.phi_model(inputs_embeds=input_embeds.to(device),
                                 attention_mask=attention_mask.to(device),
                                 labels=labels,
                                 return_dict=True)

        return outputs

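# Note on the sequence layout used by MultimodalPhiModel.forward above: the language model
# receives [start prompt embeddings | projected image embeddings | end prompt embeddings],
# so the attention mask must contain one extra position per image token; getInputs() below
# builds that mask by inserting a block of ones for the image positions.
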
def getImageArray(image_path):
    image = Image.open(image_path)
    return image

def getAudioArray(audio_path):
    speech, rate = librosa.load(audio_path, sr=16000)
    return speech

def getInputs(image_path, question, answer=""):

    image_features = None
    num_image_tokens = 0

    if image_path is not None:
        # Preprocess the image and generate its CLIP embedding
        image = clipprocessor(images=Image.open(image_path), return_tensors="pt")
        image_features = clipmodel.get_image_features(**image)

        # Add a sequence dimension so the embedding can sit between the text embeddings
        image_features = torch.stack([image_features])
        num_image_tokens = image_features.shape[1]

    # Prompt text that goes before the image embedding
    start_text = "<|system|>\nYou are an assistant good at understanding the objects and their relationship from the context.<|end|>\n<|user|>\n"

    # Prompt text that follows the image embedding, for causal language modeling
    end_text = f"\nPlease describe the objects and their relationship from the context.<|end|>\n<|assistant|>\n{answer}"

    # Tokenize the two text segments
    start_tokens = tokenizer(start_text, padding=True, truncation=True, max_length=512, return_tensors="pt")
    end_tokens = tokenizer(end_text, padding=True, truncation=True, max_length=512, return_tensors="pt")

    start_input_ids = start_tokens['input_ids']
    start_attention_mask = start_tokens['attention_mask']
    end_input_ids = end_tokens['input_ids']
    end_attention_mask = end_tokens['attention_mask']

    # Extend the attention mask to cover the image positions when an image is present
    if image_path is not None:
        attention_mask = torch.cat([start_attention_mask, torch.ones((1, num_image_tokens), dtype=torch.long), end_attention_mask], dim=1)
    else:
        attention_mask = torch.cat([start_attention_mask, end_attention_mask], dim=1)

    return start_input_ids, end_input_ids, image_features, attention_mask

# Location of the fine-tuned checkpoint (LoRA adapter plus image projector)
model_location = "./MM_FT_C1_V2"
print("Model location:", model_location)

model = MultimodalPhiModel.from_pretrained(model_location).to(device)

def getStringAfter(output, start_str):
    if start_str in output:
        answer = output.split(start_str)[1]
    else:
        answer = output

    answer = preprocess_text(answer)
    return answer


def getStringAfterAnswer(output):
    if "<|assistant|>" in output:
        answer = output.split("<|assistant|>")[1]
    else:
        answer = output

    answer = preprocess_text(answer)
    return answer

def generateOutput(image_path, audio_path, context_text, question, max_length=5):
    answerPart = ""
    speech_text = ""

    # Stage 1: if an image is given, iteratively decode a description of it with the
    # fine-tuned multimodal model, feeding the partial answer back in each round.
    if image_path is not None:
        for _ in range(max_length):
            start_tokens, end_tokens, image_features, attention_mask = getInputs(image_path, question, answer=answerPart)
            output = model(start_tokens, end_tokens, image_features, attention_mask, labels=None)
            tokens = output.logits.argmax(dim=-1)
            decoded = tokenizer.decode(
                tokens[0],
                skip_special_tokens=True
            )
            answerPart = getStringAfter(decoded, "<|assistant|>")
            print("Answerpart:", answerPart)

    # Stage 2: if audio is given, transcribe it with Whisper.
    if audio_path is not None:
        speech_text = transcribe_speech(audio_path)
        print("Speech Text:", speech_text)

    if (question is None) or (question == ""):
        question = "Describe the objects and their relationships in only one sentence."

    # Stage 3: combine the image description, transcription and any extra context,
    # then answer the question with the base Phi model.
    input_text = (
        "<|system|>\nPlease understand the context "
        "and answer the question based on the context in 1 or 2 summarized sentences.\n"
        f"<|end|>\n<|user|>\n<|context|>{answerPart}\n{speech_text}\n{context_text}"
        f"\n<|question|>: {question}\n<|end|>\n<|assistant|>\n"
    )
    print("input_text:", input_text)
    start_tokens = tokenizer(input_text, padding=True, truncation=True, max_length=1024, return_tensors="pt")['input_ids'].to(device)

    output_text = tokenizer.decode(
        model.base_phi_model.generate(start_tokens, max_length=1024, do_sample=False, pad_token_id=tokenizer.pad_token_id)[0],
        skip_special_tokens=True
    )

    output_text = getStringAfter(output_text, question).strip()
    return output_text

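# Illustrative call for local testing (an assumption, not executed by the app itself):
#   print(generateOutput("./images/COCO_train2014_000000581181.jpg", None, "", ""))
# With an empty question, the default one-sentence description prompt defined above is used.
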
title = "Fine-Tuned Multimodal Model"
description = "Test the fine-tuned multimodal model built from CLIP, Phi-3.5-mini-instruct and Whisper."
# Example inputs (note: title, description and examples are defined here but not yet attached to the Blocks UI below)
examples = [
    ["./images/COCO_train2014_000000581181.jpg", None, None, None, None, "Describe what is happening in this image."],
    [None, "Audio File", "./audio/03-01-01-01-01-01-01.wav", None, None, "Describe what the person is trying to say in this audio."],
]

demo = gr.Blocks()

def process_inputs(image, audio_source, audio_file, audio_mic, context_text, question):
    # Pick the audio input that matches the selected source
    if audio_source == "Microphone":
        speech = audio_mic
    elif audio_source == "Audio File":
        speech = audio_file
    else:
        speech = None

    answer = generateOutput(image, speech, context_text, question)
    return answer

with demo:
    with gr.Row():
        audio_source = gr.Radio(choices=["Microphone", "Audio File"], label="Select Audio Source")
        audio_file = gr.Audio(sources="upload", type="filepath", visible=False)
        audio_mic = gr.Audio(sources="microphone", type="filepath", visible=False)
        image_input = gr.Image(type="filepath", label="Upload Image")
        context_text = gr.Textbox(label="Context Text")
        question = gr.Textbox(label="Question")
        output_text = gr.Textbox(label="Output")

    def update_audio_input(source):
        # Show only the audio widget that matches the selected source
        if source == "Microphone":
            return gr.update(visible=True), gr.update(visible=False)
        elif source == "Audio File":
            return gr.update(visible=False), gr.update(visible=True)
        else:
            return gr.update(visible=False), gr.update(visible=False)

    audio_source.change(fn=update_audio_input, inputs=audio_source, outputs=[audio_mic, audio_file])
    submit_button = gr.Button("Submit")
    submit_button.click(fn=process_inputs, inputs=[image_input, audio_source, audio_file, audio_mic, context_text, question], outputs=output_text)

demo.launch(debug=True)
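# Rough dependency sketch (assumed from the imports above; this commit does not pin versions):
# torch, transformers, peft, gradio, librosa, nltk and pillow, plus the fine-tuned artifacts
# expected under ./MM_FT_C1_V2 (the LoRA adapter and image_projector.pth) and the ./images and
# ./audio folders referenced by the example inputs.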