"""Gradio UI for a multimodal RAG chatbot.

Retrieves video segments (frame + transcript) from LanceDB via BridgeTower
embeddings, answers with an LVLM served through PredictionGuard, and plays
back the matching sub-clip.
"""
import dataclasses
import os
import time
from enum import auto, Enum
from pathlib import Path
from typing import Any, List, Optional

import gradio as gr
import lancedb
from langchain_core.runnables import (
    RunnableLambda,
    RunnableParallel,
    RunnablePassthrough,
)
from moviepy.video.io.VideoFileClip import VideoFileClip

from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
from mm_rag.MLM.client import PredictionGuardClient
from mm_rag.MLM.lvlm import LVLM
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB
from utility import (
    encode_image,
    lvlm_inference_with_conversation,
    prediction_guard_llava_conv,
)

server_error_msg = "**NETWORK ERROR DUE TO HIGH TRAFFIC. PLEASE REGENERATE OR REFRESH THIS PAGE.**"


def split_video(
    video_path,
    timestamp_in_ms,
    output_video_path: str = "./shared_data/splitted_videos",
    output_video_name: str = "video_tmp.mp4",
    play_before_sec: int = 3,
    play_after_sec: int = 3,
):
    """Cut a short clip around `timestamp_in_ms` and write it to disk.

    The clip spans `play_before_sec` seconds before the timestamp to
    `play_after_sec` seconds after it, clamped to the video bounds.
    Returns the path of the written clip.
    """
    timestamp_in_sec = int(timestamp_in_ms / 1000)
    Path(output_video_path).mkdir(parents=True, exist_ok=True)
    output_video = os.path.join(output_video_path, output_video_name)
    with VideoFileClip(video_path) as video:
        duration = video.duration
        start_time = max(timestamp_in_sec - play_before_sec, 0)
        end_time = min(timestamp_in_sec + play_after_sec, duration)
        # NOTE: `subclip` is the moviepy 1.x API; moviepy 2.x renamed it
        # to `subclipped`.
        new = video.subclip(start_time, end_time)
        new.write_videofile(output_video, audio_codec='aac')
    return output_video
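

# Example usage (hypothetical paths), clipping 3 s on either side of the
# 90-second mark:
#   clip_path = split_video("./shared_data/videos/demo.mp4", timestamp_in_ms=90_000)
#   # -> "./shared_data/splitted_videos/video_tmp.mp4"
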
prompt_template = """The transcript associated with the image is '{transcript}'. {user_query}"""
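# e.g. prompt_template.format(transcript="Houston, we have liftoff.",
#                             user_query="What is happening?")
# -> "The transcript associated with the image is 'Houston, we have liftoff.'. What is happening?"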


def get_default_rag_chain():
    """Build the default multimodal RAG chain.

    retrieve (LanceDB, k=1) -> build an LVLM prompt from the retrieved
    transcript/frame -> run LVLM inference, passing the LVLM input through
    alongside the answer.
    """
    LANCEDB_HOST_FILE = "./shared_data/.lancedb"
    TBL_NAME = "demo_tbl"

    # Vector store over (frame, transcript) pairs embedded with BridgeTower.
    db = lancedb.connect(LANCEDB_HOST_FILE)
    embedder = BridgeTowerEmbeddings()
    vectorstore = MultimodalLanceDB(
        uri=LANCEDB_HOST_FILE, embedding=embedder, table_name=TBL_NAME
    )
    retriever_module = vectorstore.as_retriever(
        search_type='similarity', search_kwargs={"k": 1}
    )

    # LVLM inference served through PredictionGuard.
    client = PredictionGuardClient()
    lvlm_inference_module = LVLM(client=client)

    def prompt_processing(input):
        """Turn the top retrieval result into an LVLM prompt, image, and metadata."""
        retrieved_results, user_query = (
            input['retrieved_results'],
            input['user_query'],
        )
        # Only the single best match is used (k=1 above).
        retrieved_result = retrieved_results[0]
        metadata_retrieved_video_segment = retrieved_result.metadata['metadata']
        transcript = metadata_retrieved_video_segment['transcript']
        frame_path = metadata_retrieved_video_segment['extracted_frame_path']
        return {
            'prompt': prompt_template.format(
                transcript=transcript, user_query=user_query
            ),
            'image': frame_path,
            'metadata': metadata_retrieved_video_segment,
        }

    prompt_processing_module = RunnableLambda(prompt_processing)

    mm_rag_chain_with_retrieved_image = (
        RunnableParallel({
            "retrieved_results": retriever_module,
            "user_query": RunnablePassthrough(),
        })
        | prompt_processing_module
        | RunnableParallel({
            'final_text_output': lvlm_inference_module,
            'input_to_lvlm': RunnablePassthrough(),
        })
    )
    return mm_rag_chain_with_retrieved_image
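

# A minimal sketch of how the chain is consumed (mirrors http_bot below):
#   chain = get_default_rag_chain()
#   response = chain.invoke("What does the astronaut say?")
#   response['final_text_output']          # the LVLM answer
#   response['input_to_lvlm']['metadata']  # retrieved segment metadata
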


class SeparatorStyle(Enum):
    """Different separator styles."""
    SINGLE = auto()


@dataclasses.dataclass
class GradioInstance:
    """Keeps all conversation history plus the retrieved media for one session."""
    system: str
    roles: List[str]
    messages: List[List[str]]
    offset: int
    sep_style: SeparatorStyle = SeparatorStyle.SINGLE
    sep: str = "\n"
    sep2: Optional[str] = None
    version: str = "Unknown"
    path_to_img: Optional[str] = None
    video_title: Optional[str] = None
    path_to_video: Optional[str] = None
    caption: Optional[str] = None
    mm_rag_chain: Any = None

    skip_next: bool = False

    def _template_caption(self):
        out = ""
        if self.caption is not None:
            out = f"The caption associated with the image is '{self.caption}'. "
        return out

    def get_prompt_for_rag(self):
        """Return the user query of a fresh two-message conversation."""
        messages = self.messages
        assert len(messages) == 2, "length of current conversation should be 2"
        assert messages[1][1] is None, (
            "the response message of the current conversation should still be None"
        )
        ret = messages[0][1]
        return ret

    def get_conversation_for_lvlm(self):
        """Rebuild the conversation in PredictionGuard's LLaVA format.

        The first user turn carries the base64-encoded retrieved frame; the
        latest user turn is wrapped with the transcript prompt template.
        """
        pg_conv = prediction_guard_llava_conv.copy()
        image_path = self.path_to_img
        b64_img = encode_image(image_path)
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if msg is None:
                break
            if i == 0:
                # First user message: attach the retrieved frame.
                pg_conv.append_message(
                    prediction_guard_llava_conv.roles[0], [msg, b64_img]
                )
            elif i == len(self.messages[self.offset:]) - 2:
                # Latest user message: prepend the transcript context.
                pg_conv.append_message(
                    role,
                    [prompt_template.format(transcript=self.caption, user_query=msg)],
                )
            else:
                pg_conv.append_message(role, [msg])
        return pg_conv
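
    # In follow-up turns, http_bot() feeds this conversation straight into
    # lvlm_inference_with_conversation(...) instead of re-running retrieval.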

    def append_message(self, role, message):
        self.messages.append([role, message])

    def get_images(self, return_pil=False):
        # `return_pil` is kept for interface compatibility; only the stored
        # image path is returned here.
        images = []
        if self.path_to_img is not None:
            path_to_image = self.path_to_img
            images.append(path_to_image)
        return images

    def to_gradio_chatbot(self):
        """Convert stored messages into Gradio chatbot [user, assistant] pairs."""
        ret = []
        for i, (role, msg) in enumerate(self.messages[self.offset:]):
            if i % 2 == 0:
                if type(msg) is tuple:
                    import base64
                    from io import BytesIO

                    msg, image, image_process_mode = msg
                    # Downscale so the shortest edge is at most 400 px
                    # (800 px cap on the longest edge), preserving aspect ratio.
                    max_hw, min_hw = max(image.size), min(image.size)
                    aspect_ratio = max_hw / min_hw
                    max_len, min_len = 800, 400
                    shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
                    longest_edge = int(shortest_edge * aspect_ratio)
                    W, H = image.size
                    if H > W:
                        H, W = longest_edge, shortest_edge
                    else:
                        H, W = shortest_edge, longest_edge
                    image = image.resize((W, H))
                    buffered = BytesIO()
                    image.save(buffered, format="JPEG")
                    img_b64_str = base64.b64encode(buffered.getvalue()).decode()
                    # MIME type matches the JPEG encoding above.
                    img_str = (
                        f'<img src="data:image/jpeg;base64,{img_b64_str}" '
                        f'alt="user upload image" />'
                    )
                    msg = img_str + msg.replace('<image>', '').strip()
                    ret.append([msg, None])
                else:
                    ret.append([msg, None])
            else:
                ret[-1][-1] = msg
        return ret

    def copy(self):
        return GradioInstance(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            offset=self.offset,
            sep_style=self.sep_style,
            sep=self.sep,
            sep2=self.sep2,
            version=self.version,
            mm_rag_chain=self.mm_rag_chain,
        )

    def dict(self):
        return {
            "system": self.system,
            "roles": self.roles,
            "messages": self.messages,
            "offset": self.offset,
            "sep": self.sep,
            "sep2": self.sep2,
            "path_to_img": self.path_to_img,
            "video_title": self.video_title,
            "path_to_video": self.path_to_video,
            "caption": self.caption,
        }

    def get_path_to_subvideos(self):
        """Resolve the path of the sub-video to show in the player."""
        if self.video_title is not None and self.path_to_img is not None:
            # `video_helper_map` (title -> {'path': ..., 'prefix': ...}) is
            # expected to be defined elsewhere; it is not part of this module.
            info = video_helper_map[self.video_title]
            path = info['path']
            prefix = info['prefix']
            # Derive the segment index from the extracted-frame filename,
            # e.g. ".../frame_12.jpg" -> "12" -> "<prefix>12.mp4".
            vid_index = self.path_to_img.split('/')[-1]
            vid_index = vid_index.split('_')[-1]
            vid_index = vid_index.replace('.jpg', '')
            ret = f"{prefix}{vid_index}.mp4"
            ret = os.path.join(path, ret)
            return ret
        elif self.path_to_video is not None:
            return self.path_to_video
        return None


def get_gradio_instance(mm_rag_chain=None):
    """Create a fresh conversation state, building the default chain if needed."""
    if mm_rag_chain is None:
        mm_rag_chain = get_default_rag_chain()

    instance = GradioInstance(
        system="",
        roles=prediction_guard_llava_conv.roles,
        messages=[],
        offset=0,
        sep_style=SeparatorStyle.SINGLE,
        sep="\n",
        path_to_img=None,
        video_title=None,
        caption=None,
        mm_rag_chain=mm_rag_chain,
    )
    return instance


gr.set_static_paths(paths=["./assets/"])

theme = gr.themes.Base(
    primary_hue=gr.themes.Color(
        c50="#eff6ff", c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd",
        c400="#60a5fa", c500="#0054ae", c600="#00377c", c700="#00377c",
        c800="#1e40af", c900="#1e3a8a", c950="#0a0c2b",
    ),
    secondary_hue=gr.themes.Color(
        c50="#eff6ff", c100="#dbeafe", c200="#bfdbfe", c300="#93c5fd",
        c400="#60a5fa", c500="#0054ae", c600="#0054ae", c700="#0054ae",
        c800="#1e40af", c900="#1e3a8a", c950="#1d3660",
    ),
).set(
    body_background_fill_dark='*primary_950',
    body_text_color_dark='*neutral_300',
    border_color_accent='*primary_700',
    border_color_accent_dark='*neutral_800',
    block_background_fill_dark='*primary_950',
    block_border_width='2px',
    block_border_width_dark='2px',
    button_primary_background_fill_dark='*primary_500',
    button_primary_border_color_dark='*primary_500',
)

css = '''
@font-face {
    font-family: IntelOne;
    src: url("/file=./assets/intelone-bodytext-font-family-regular.ttf");
}
.gradio-container {background-color: #0a0c2b}
table {
    border-collapse: collapse;
    border: none;
}
'''
html_title = '''
<table style="border: none; border-color: #0a0c2b;">
  <tr style="height: 150px; border: none;">
    <td style="border: none;"><img src="/file=./assets/header.png"></td>
  </tr>
</table>
'''


dropdown_list = [
    "What is the name of one of the astronauts?",
    "An astronaut's spacewalk",
    "What does the astronaut say?",
]

no_change_btn = gr.Button()
enable_btn = gr.Button(interactive=True)
disable_btn = gr.Button(interactive=False)


def clear_history(state, request: gr.Request):
    """Reset the conversation, keeping the already-built RAG chain."""
    state = get_gradio_instance(state.mm_rag_chain)
    return (state, state.to_gradio_chatbot(), "", None) + (disable_btn,) * 1


def add_text(state, text, request: gr.Request):
    """Append the user's query (and an empty assistant slot) to the state."""
    if len(text) <= 0:
        state.skip_next = True
        # Match the four outputs wired in get_demo:
        # state, chatbot, textbox, and the clear button.
        return (state, state.to_gradio_chatbot(), "") + (no_change_btn,) * 1

    text = text[:1536]  # truncate overly long queries

    state.append_message(state.roles[0], text)
    state.append_message(state.roles[1], None)
    state.skip_next = False
    return (state, state.to_gradio_chatbot(), "") + (disable_btn,) * 1


def http_bot(state, request: gr.Request):
    """Answer the latest query and stream UI updates back to Gradio."""
    start_tstamp = time.time()

    if state.skip_next:
        # Empty query: leave the state untouched.
        path_to_sub_videos = state.get_path_to_subvideos()
        yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (no_change_btn,) * 1
        return

    if len(state.messages) == state.offset + 2:
        # First query of this conversation: start from a clean state.
        new_state = get_gradio_instance(state.mm_rag_chain)
        new_state.append_message(new_state.roles[0], state.messages[-2][1])
        new_state.append_message(new_state.roles[1], None)
        state = new_state

    all_images = state.get_images(return_pil=False)

    # No cached image means retrieval has not run yet: treat this as the
    # very first query and run the full RAG chain; otherwise chat directly
    # with the LVLM using the accumulated conversation.
    is_very_first_query = True
    if len(all_images) == 0:
        prompt_or_conversation = state.get_prompt_for_rag()
    else:
        is_very_first_query = False
        prompt_or_conversation = state.get_conversation_for_lvlm()

    if is_very_first_query:
        executor = state.mm_rag_chain
    else:
        executor = lvlm_inference_with_conversation

    # Show a typing cursor while inference is running.
    state.messages[-1][-1] = "▌"
    path_to_sub_videos = state.get_path_to_subvideos()
    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (disable_btn,) * 1

    try:
        if is_very_first_query:
            # Run the RAG chain and cache the retrieved frame, sub-video,
            # and transcript on the state for follow-up turns.
            response = executor.invoke(prompt_or_conversation)
            message = response['final_text_output']
            if 'metadata' in response['input_to_lvlm']:
                metadata = response['input_to_lvlm']['metadata']
                if (state.path_to_img is None
                        and 'input_to_lvlm' in response
                        and 'image' in response['input_to_lvlm']):
                    state.path_to_img = response['input_to_lvlm']['image']

                if state.path_to_video is None and 'video_path' in metadata:
                    video_path = metadata['video_path']
                    mid_time_ms = metadata['mid_time_ms']
                    split_video_path = split_video(video_path, mid_time_ms)
                    state.path_to_video = split_video_path

                if state.caption is None and 'transcript' in metadata:
                    state.caption = metadata['transcript']
            else:
                raise ValueError("Response format has changed")
        else:
            # Follow-up query: plain LVLM chat over the conversation.
            message = executor(prompt_or_conversation)

    except Exception as e:
        print(e)
        state.messages[-1][-1] = server_error_msg
        yield (state, state.to_gradio_chatbot(), None) + (enable_btn,)
        return

    state.messages[-1][-1] = message
    path_to_sub_videos = state.get_path_to_subvideos()

    yield (state, state.to_gradio_chatbot(), path_to_sub_videos) + (enable_btn,) * 1

    finish_tstamp = time.time()
    return


def get_demo(rag_chain=None):
    """Build the Gradio Blocks app around the given (or default) RAG chain."""
    if rag_chain is None:
        rag_chain = get_default_rag_chain()

    with gr.Blocks(theme=theme, css=css) as demo:
        instance = get_gradio_instance(rag_chain)
        state = gr.State(instance)
        # Force the dark theme via the `__theme` query parameter.
        demo.load(
            None,
            None,
            js="""
            () => {
                const params = new URLSearchParams(window.location.search);
                if (!params.has('__theme')) {
                    params.set('__theme', 'dark');
                    window.location.search = params.toString();
                }
            }""",
        )
        gr.HTML(value=html_title)
        with gr.Row():
            with gr.Column(scale=4):
                video = gr.Video(height=512, width=512, elem_id="video", interactive=False)
            with gr.Column(scale=7):
                chatbot = gr.Chatbot(
                    elem_id="chatbot", label="Multimodal RAG Chatbot", height=512,
                )
        with gr.Row():
            with gr.Column(scale=8):
                textbox = gr.Dropdown(
                    dropdown_list,
                    allow_custom_value=True,
                    label="Query",
                    info="Enter your query here or choose a sample from the dropdown list!",
                )
            with gr.Column(scale=1, min_width=50):
                submit_btn = gr.Button(
                    value="Send", variant="primary", interactive=True
                )
        with gr.Row(elem_id="buttons") as button_row:
            clear_btn = gr.Button(value="🗑️ Clear history", interactive=False)

        btn_list = [clear_btn]

        clear_btn.click(
            clear_history, [state], [state, chatbot, textbox, video] + btn_list
        )
        submit_btn.click(
            add_text,
            [state, textbox],
            [state, chatbot, textbox] + btn_list,
        ).then(
            http_bot,
            [state],
            [state, chatbot, video] + btn_list,
        )
    return demo
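

# Minimal local entry point (a sketch; the hosting notebook or script may
# launch the demo differently):
#   if __name__ == "__main__":
#       get_demo().launch()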