Add extra files for storing multimodal data in RAG
mm_rag/MLM/client.py
ADDED
@@ -0,0 +1,135 @@
"""Base interface for clients making requests/calls to a visual language model provider API"""

from abc import ABC, abstractmethod
from typing import List, Optional, Dict, Union, Iterator
import requests
import json
from utility import isBase64, encode_image, encode_image_from_path_or_url, lvlm_inference


class BaseClient(ABC):
    def __init__(self,
                 hostname: str = "127.0.0.1",
                 port: int = 8090,
                 timeout: int = 60,
                 url: Optional[str] = None):
        self.connection_url = f"http://{hostname}:{port}" if url is None else url
        self.timeout = timeout
        # self.headers = {'Content-Type': 'application/x-www-form-urlencoded'}
        self.headers = {'Content-Type': 'application/json'}

    def root(self):
        """Request for showing the welcome message"""
        connection_route = f"{self.connection_url}/"
        return requests.get(connection_route)

    @abstractmethod
    def generate(self,
                 prompt: str,
                 image: str,
                 **kwargs
                 ) -> str:
        """Send a request to the visual language model API
        and return the generated text.

        Use this method when you want to call the visual language model API to generate text without streaming.

        Args:
            prompt: A prompt.
            image: A string that can be either a path to an image or a base64 encoding of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as hyperparameters for generation.

        Returns:
            Text returned from the visual language model provider API call.
        """

    def generate_stream(
            self,
            prompt: str,
            image: str,
            **kwargs
    ) -> Iterator[str]:
        """Send a request to the visual language model API
        and return an iterator over the streamed text returned from the API call.

        Use this method when you want to call the visual language model API to stream generated text.

        Args:
            prompt: A prompt.
            image: A string that can be either a path to an image or a base64 encoding of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as hyperparameters for generation.

        Returns:
            Iterator of text streamed from the visual language model provider API call.
        """
        raise NotImplementedError()

    def generate_batch(
            self,
            prompt: List[str],
            image: List[str],
            **kwargs
    ) -> List[str]:
        """Send a request to the visual language model API for multi-batch generation
        and return a list of generated texts.

        Use this method when you want to call the visual language model API to generate text for multiple batches.
        Multi-batch generation does not support streaming.

        Args:
            prompt: List of prompts.
            image: List of strings; each can be either a path to an image or a base64 encoding of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as hyperparameters for generation.

        Returns:
            List of texts returned from the visual language model provider API call.
        """
        raise NotImplementedError()


class PredictionGuardClient(BaseClient):

    generate_kwargs = ['max_tokens',
                       'temperature',
                       'top_p',
                       'top_k']

    def filter_accepted_genkwargs(self, kwargs):
        gen_args = {}
        if "generate_kwargs" in kwargs and isinstance(kwargs["generate_kwargs"], dict):
            gen_args = {k: kwargs["generate_kwargs"][k]
                        for k in self.generate_kwargs
                        if k in kwargs["generate_kwargs"]}
        return gen_args

    def generate(self,
                 prompt: str,
                 image: str,
                 **kwargs
                 ) -> str:
        """Send a request to PredictionGuard's API
        and return the text generated by the LLaVA model.

        Use this method when you want to call the LLaVA model API to generate text without streaming.

        Args:
            prompt: A prompt.
            image: A string that can be either a path/URL to an image or a base64 encoding of an image.
            **kwargs: Arbitrary additional keyword arguments.
                These are usually passed to the model provider API call as hyperparameters for generation.

        Returns:
            Text returned from the visual language model provider API call.
        """
        assert image is not None and len(image) != 0, \
            "the input image cannot be None; it must be either a base64-encoded image or a path/URL to an image"
        if isBase64(image):
            base64_image = image
        else:  # this is a path or URL to an image
            base64_image = encode_image_from_path_or_url(image)

        args = self.filter_accepted_genkwargs(kwargs)
        return lvlm_inference(prompt=prompt, image=base64_image, **args)
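
Example usage (not part of the commit): a minimal sketch of calling the client above. It assumes the repo's `utility` module is importable and that any PredictionGuard credentials it needs are already configured in the environment; the image path is a placeholder.

from mm_rag.MLM.client import PredictionGuardClient

# Build a client; with no arguments it targets the default endpoint (http://127.0.0.1:8090).
client = PredictionGuardClient()

# Generate a caption for one frame; generate_kwargs is filtered down to the accepted keys
# (max_tokens, temperature, top_p, top_k) before being forwarded to lvlm_inference.
answer = client.generate(
    prompt="Describe what is happening in this frame.",
    image="./frames/frame_0001.jpg",  # placeholder path; a URL or base64 string also works
    generate_kwargs={"max_tokens": 128, "temperature": 0.6},
)
print(answer)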
mm_rag/MLM/lvlm.py
ADDED
@@ -0,0 +1,301 @@
from .client import PredictionGuardClient
from langchain_core.language_models.llms import LLM
from langchain_core.pydantic_v1 import Extra, root_validator
from typing import Any, Optional, List, Dict, Iterator, AsyncIterator
from langchain_core.callbacks import CallbackManagerForLLMRun
from utility import get_from_dict_or_env, MultimodalModelInput

from langchain_core.runnables import RunnableConfig, ensure_config
from langchain_core.language_models.base import LanguageModelInput
from langchain_core.prompt_values import StringPromptValue
# from langchain_core.outputs import GenerationChunk, LLMResult
from langchain_core.language_models.llms import BaseLLM
from langchain_core.callbacks import (
    # CallbackManager,
    CallbackManagerForLLMRun,
)
# from langchain_core.load import dumpd
from langchain_core.runnables.config import run_in_executor


class LVLM(LLM):
    """This class extends the LLM class to implement a custom request to an LVLM provider API"""

    client: Any = None  #: :meta private:
    hostname: Optional[str] = None
    port: Optional[int] = None
    url: Optional[str] = None
    max_new_tokens: Optional[int] = 200
    temperature: Optional[float] = 0.6
    top_k: Optional[float] = 0
    stop: Optional[List[str]] = None
    ignore_eos: Optional[bool] = False
    do_sample: Optional[bool] = True
    lazy_mode: Optional[bool] = True
    hpu_graphs: Optional[bool] = True

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the access token and python package exist in the environment if needed"""
        if values['client'] is None:
            # check whether the URL of the API is provided
            url = get_from_dict_or_env(values, 'url', "VLM_URL", None)
            if url is None:
                hostname = get_from_dict_or_env(values, 'hostname', 'VLM_HOSTNAME', None)
                port = get_from_dict_or_env(values, 'port', 'VLM_PORT', None)
                if hostname is not None and port is not None:
                    values['client'] = PredictionGuardClient(hostname=hostname, port=port)
                else:
                    # use the default hostname and port to create the client
                    values['client'] = PredictionGuardClient()
            else:
                values['client'] = PredictionGuardClient(url=url)
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm"""
        return "Large Vision Language Model"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling the Prediction Guard API."""
        return {
            "max_tokens": self.max_new_tokens,
            "temperature": self.temperature,
            "top_k": self.top_k,
            "ignore_eos": self.ignore_eos,
            "do_sample": self.do_sample,
            "stop": self.stop,
        }

    def get_params(self, **kwargs):
        params = self._default_params
        params.update(kwargs)
        return params

    def _call(
        self,
        prompt: str,
        image: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Run the VLM on the given input.

        Args:
            prompt: The prompt to generate from.
            image: Either a path to an image or a base64 encoding of the image.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
                If stop tokens are not supported consider raising NotImplementedError.
        Returns:
            The model output as a string. The actual completion DOES NOT include the prompt.
        Example: TBD
        """
        params = {}
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        params['generate_kwargs'] = self.get_params(**kwargs)
        response = self.client.generate(prompt=prompt, image=image, **params)
        return response

    def _stream(
        self,
        prompt: str,
        image: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        """Stream the VLM on the given prompt and image.

        Args:
            prompt: The prompt to generate from.
            image: Either a path to an image or a base64 encoding of the image.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
                If stop tokens are not supported consider raising NotImplementedError.
        Returns:
            An iterator of strings. The actual completion DOES NOT include the prompt.
        Example: TBD
        """
        params = {}
        params['generate_kwargs'] = self.get_params(**kwargs)
        for chunk in self.client.generate_stream(prompt=prompt, image=image, **params):
            yield chunk

    async def _astream(
        self,
        prompt: str,
        image: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """An async version of the _stream method that streams the VLM on the given prompt and image.

        Args:
            prompt: The prompt to generate from.
            image: Either a path to an image or a base64 encoding of the image.
            stop: Stop words to use when generating. Model output is cut off at the
                first occurrence of any of the stop substrings.
                If stop tokens are not supported consider raising NotImplementedError.
        Returns:
            An async iterator of strings. The actual completion DOES NOT include the prompt.
        Example: TBD
        """
        iterator = await run_in_executor(
            None,
            self._stream,
            prompt,
            image,
            stop,
            run_manager.get_sync() if run_manager else None,
            **kwargs,
        )
        done = object()
        while True:
            item = await run_in_executor(
                None,
                next,
                iterator,
                done,  # type: ignore[call-arg, arg-type]
            )
            if item is done:
                break
            yield item  # type: ignore[misc]

    def invoke(
        self,
        input: MultimodalModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        config = ensure_config(config)
        if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
            return (
                self.generate_prompt(
                    [self._convert_input(StringPromptValue(text=input['prompt']))],
                    stop=stop,
                    callbacks=config.get("callbacks"),
                    tags=config.get("tags"),
                    metadata=config.get("metadata"),
                    run_name=config.get("run_name"),
                    run_id=config.pop("run_id", None),
                    image=input['image'],
                    **kwargs,
                )
                .generations[0][0]
                .text
            )
        return (
            self.generate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                **kwargs,
            )
            .generations[0][0]
            .text
        )

    async def ainvoke(
        self,
        input: MultimodalModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        config = ensure_config(config)
        if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
            llm_result = await self.agenerate_prompt(
                [self._convert_input(StringPromptValue(text=input['prompt']))],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                image=input['image'],
                **kwargs,
            )
        else:
            llm_result = await self.agenerate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                run_id=config.pop("run_id", None),
                **kwargs,
            )
        return llm_result.generations[0][0].text

    def stream(
        self,
        input: MultimodalModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        if type(self)._stream == BaseLLM._stream:
            # model doesn't implement streaming, so use the default implementation
            yield self.invoke(input, config=config, stop=stop, **kwargs)
        else:
            if stop is not None:
                raise ValueError("stop kwargs are not permitted.")
            image = None
            prompt = None
            if isinstance(input, dict) and 'prompt' in input.keys():
                prompt = self._convert_input(input['prompt']).to_string()
            else:
                raise ValueError("prompt must be provided")
            if isinstance(input, dict) and 'image' in input.keys():
                image = input['image']

            for chunk in self._stream(
                prompt=prompt, image=image, **kwargs
            ):
                yield chunk

    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        if (
            type(self)._astream is BaseLLM._astream
            and type(self)._stream is BaseLLM._stream
        ):
            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
            return
        else:
            if stop is not None:
                raise ValueError("stop kwargs are not permitted.")
            image = None
            if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
                prompt = self._convert_input(input['prompt']).to_string()
                image = input['image']
            else:
                # an image is required for streaming generation
                raise ValueError("missing image is not permitted")

            async for chunk in self._astream(
                prompt=prompt, image=image, **kwargs
            ):
                yield chunk
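
Example usage (not part of the commit): a minimal sketch of driving the LVLM wrapper above through LangChain's runnable interface, assuming a reachable VLM endpoint (or the VLM_URL / VLM_HOSTNAME / VLM_PORT environment variables) and a placeholder image path. Note that stream()/astream() additionally require a client that implements generate_stream, which the PredictionGuardClient above leaves as NotImplementedError.

from mm_rag.MLM.lvlm import LVLM

# The root validator builds a PredictionGuardClient from the url/hostname/port fields
# or the corresponding environment variables.
lvlm = LVLM(max_new_tokens=200, temperature=0.6)

inputs = {"prompt": "What is shown in this image?", "image": "./frames/frame_0001.jpg"}
print(lvlm.invoke(inputs))  # blocking, non-streaming call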
mm_rag/embeddings/__pycache__/bridgetower_embeddings.cpython-311.pyc
ADDED
Binary file (3.23 kB).
mm_rag/embeddings/bridgetower_embeddings.py
ADDED
@@ -0,0 +1,56 @@
from typing import List
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import (
    BaseModel,
)
from utility import encode_image, bt_embedding_from_prediction_guard
from tqdm import tqdm


class BridgeTowerEmbeddings(BaseModel, Embeddings):
    """BridgeTower embedding model"""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of documents using BridgeTower.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        embeddings = []
        for text in texts:
            embedding = bt_embedding_from_prediction_guard(text, "")
            embeddings.append(embedding)
        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed a query using BridgeTower.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        return self.embed_documents([text])[0]

    def embed_image_text_pairs(self, texts: List[str], images: List[str], batch_size=2) -> List[List[float]]:
        """Embed a list of image-text pairs using BridgeTower.

        Args:
            texts: The list of texts to embed.
            images: The list of paths to the images to embed.
            batch_size: The batch size to process; defaults to 2.
        Returns:
            List of embeddings, one for each image-text pair.
        """

        # the length of texts must be equal to the length of images
        assert len(texts) == len(images), "the len of captions should be equal to the len of images"

        embeddings = []
        for path_to_img, text in tqdm(zip(images, texts), total=len(texts)):
            embedding = bt_embedding_from_prediction_guard(text, encode_image(path_to_img))
            embeddings.append(embedding)
        return embeddings
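
Example usage (not part of the commit): a minimal sketch of the embedder above, assuming the `utility` helpers (bt_embedding_from_prediction_guard, encode_image) can reach their embedding backend; the captions and frame paths are placeholders.

from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings

embedder = BridgeTowerEmbeddings()

captions = ["a person skiing down a slope", "a dog catching a frisbee"]
frames = ["./frames/ski.jpg", "./frames/dog.jpg"]

# One joint embedding per caption/frame pair.
pair_vectors = embedder.embed_image_text_pairs(captions, frames)

# Text-only embedding for a retrieval query (the image part is left empty).
query_vector = embedder.embed_query("someone skiing")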
mm_rag/vectorstores/multimodal_lancedb.py
ADDED
@@ -0,0 +1,131 @@
from typing import Any, Iterable, List, Optional
from langchain_core.embeddings import Embeddings
import uuid
from langchain_community.vectorstores.lancedb import LanceDB


class MultimodalLanceDB(LanceDB):
    """`LanceDB` vector store to process multimodal data

    To use, you should have the ``lancedb`` python package installed.
    You can install it with ``pip install lancedb``.

    Args:
        connection: LanceDB connection to use. If not provided, a new connection
            will be created.
        embedding: Embedding to use for the vectorstore.
        vector_key: Key to use for the vector in the database. Defaults to ``vector``.
        id_key: Key to use for the id in the database. Defaults to ``id``.
        text_key: Key to use for the text in the database. Defaults to ``text``.
        image_path_key: Key to use for the path to the image in the database. Defaults to ``image_path``.
        table_name: Name of the table to use. Defaults to ``vectorstore``.
        api_key: API key to use for the LanceDB cloud database.
        region: Region to use for the LanceDB cloud database.
        mode: Mode to use for adding data to the table. Defaults to ``append``.

    Example:
        .. code-block:: python
            vectorstore = MultimodalLanceDB(uri='/lancedb', embedding=embedding_function)
            vectorstore.add_texts(['text1', 'text2'])
            result = vectorstore.similarity_search('text1')
    """

    def __init__(
        self,
        connection: Optional[Any] = None,
        embedding: Optional[Embeddings] = None,
        uri: Optional[str] = "/tmp/lancedb",
        vector_key: Optional[str] = "vector",
        id_key: Optional[str] = "id",
        text_key: Optional[str] = "text",
        image_path_key: Optional[str] = "image_path",
        table_name: Optional[str] = "vectorstore",
        api_key: Optional[str] = None,
        region: Optional[str] = None,
        mode: Optional[str] = "append",
    ):
        super(MultimodalLanceDB, self).__init__(connection, embedding, uri, vector_key, id_key, text_key, table_name, api_key, region, mode)
        self._image_path_key = image_path_key

    def add_text_image_pairs(
        self,
        texts: Iterable[str],
        image_paths: Iterable[str],
        metadatas: Optional[List[dict]] = None,
        ids: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> List[str]:
        """Turn text-image pairs into embeddings and add them to the database

        Args:
            texts: Iterable of strings to combine with the corresponding images and add to the vectorstore.
            image_paths: Iterable of paths to images (as strings) to combine with the corresponding texts and add to the vectorstore.
            metadatas: Optional list of metadatas associated with the texts.
            ids: Optional list of ids to associate with the texts.

        Returns:
            List of ids of the added text-image pairs.
        """
        # the length of texts must be equal to the length of images
        assert len(texts) == len(image_paths), "the len of transcripts should be equal to the len of images"

        # Embed texts and create documents
        docs = []
        ids = ids or [str(uuid.uuid4()) for _ in texts]
        embeddings = self._embedding.embed_image_text_pairs(texts=list(texts), images=list(image_paths))  # type: ignore
        for idx, text in enumerate(texts):
            embedding = embeddings[idx]
            metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
            docs.append(
                {
                    self._vector_key: embedding,
                    self._id_key: ids[idx],
                    self._text_key: text,
                    self._image_path_key: image_paths[idx],
                    "metadata": metadata,
                }
            )

        if 'mode' in kwargs:
            mode = kwargs['mode']
        else:
            mode = self.mode
        if self._table_name in self._connection.table_names():
            tbl = self._connection.open_table(self._table_name)
            if self.api_key is None:
                tbl.add(docs, mode=mode)
            else:
                tbl.add(docs)
        else:
            self._connection.create_table(self._table_name, data=docs)
        return ids

    @classmethod
    def from_text_image_pairs(
        cls,
        texts: List[str],
        image_paths: List[str],
        embedding: Embeddings,
        metadatas: Optional[List[dict]] = None,
        connection: Any = None,
        vector_key: Optional[str] = "vector",
        id_key: Optional[str] = "id",
        text_key: Optional[str] = "text",
        image_path_key: Optional[str] = "image_path",
        table_name: Optional[str] = "vectorstore",
        **kwargs: Any,
    ):

        instance = MultimodalLanceDB(
            connection=connection,
            embedding=embedding,
            vector_key=vector_key,
            id_key=id_key,
            text_key=text_key,
            image_path_key=image_path_key,
            table_name=table_name,
        )
        instance.add_text_image_pairs(texts, image_paths, metadatas=metadatas, **kwargs)

        return instance
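
Example usage (not part of the commit): a minimal end-to-end sketch that ingests one text-image pair and runs a similarity search, assuming ``lancedb`` is installed, the BridgeTower embedder can reach its backend, and the paths and table name are placeholders.

import lancedb

from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB

db = lancedb.connect("/tmp/lancedb")  # local LanceDB database
embedder = BridgeTowerEmbeddings()

vectorstore = MultimodalLanceDB.from_text_image_pairs(
    texts=["a person skiing down a slope"],
    image_paths=["./frames/ski.jpg"],
    embedding=embedder,
    connection=db,
    table_name="demo_tbl",
)

# Query with text; each hit keeps its stored image_path alongside the text.
results = vectorstore.similarity_search("someone skiing", k=1)
print(results[0])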