88hours committed on
Commit 7d9878f · 1 Parent(s): 24ad9c0

Add extra files for storing multimodal data in RAG

mm_rag/MLM/client.py ADDED
@@ -0,0 +1,135 @@
+ """Base interface for a client that makes requests/calls to a visual language model provider API"""
+
+ from abc import ABC, abstractmethod
+ from typing import List, Optional, Dict, Union, Iterator
+ import requests
+ import json
+ from utility import isBase64, encode_image, encode_image_from_path_or_url, lvlm_inference
+
+ class BaseClient(ABC):
+     def __init__(self,
+                  hostname: str = "127.0.0.1",
+                  port: int = 8090,
+                  timeout: int = 60,
+                  url: Optional[str] = None):
+         self.connection_url = f"http://{hostname}:{port}" if url is None else url
+         self.timeout = timeout
+         # self.headers = {'Content-Type': 'application/x-www-form-urlencoded'}
+         self.headers = {'Content-Type': 'application/json'}
+
+     def root(self):
+         """Request the API's welcome message"""
+         connection_route = f"{self.connection_url}/"
+         return requests.get(connection_route)
+
+     @abstractmethod
+     def generate(self,
+                  prompt: str,
+                  image: str,
+                  **kwargs
+                  ) -> str:
+         """Send a request to the visual language model API
+         and return the generated text.
+
+         Use this method when you want to call the visual language model API to generate text without streaming.
+
+         Args:
+             prompt: A prompt.
+             image: A string that is either a path to an image or a base64 encoding of an image.
+             **kwargs: Arbitrary additional keyword arguments.
+                 These are usually passed to the model provider API call as hyperparameters for generation.
+
+         Returns:
+             Text returned from the visual language model provider API call.
+         """
+
+     def generate_stream(
+         self,
+         prompt: str,
+         image: str,
+         **kwargs
+     ) -> Iterator[str]:
+         """Send a request to the visual language model API
+         and return an iterator over the streamed text returned from the API call.
+
+         Use this method when you want to call the visual language model API to stream generated text.
+
+         Args:
+             prompt: A prompt.
+             image: A string that is either a path to an image or a base64 encoding of an image.
+             **kwargs: Arbitrary additional keyword arguments.
+                 These are usually passed to the model provider API call as hyperparameters for generation.
+
+         Returns:
+             Iterator over the text streamed from the visual language model provider API call.
+         """
+         raise NotImplementedError()
+
+     def generate_batch(
+         self,
+         prompt: List[str],
+         image: List[str],
+         **kwargs
+     ) -> List[str]:
+         """Send a request to the visual language model API for multi-batch generation
+         and return a list of generated texts.
+
+         Use this method when you want to call the visual language model API for multi-batch text generation.
+         Multi-batch generation does not support streaming.
+
+         Args:
+             prompt: List of prompts.
+             image: List of strings, each of which is either a path to an image or a base64 encoding of an image.
+             **kwargs: Arbitrary additional keyword arguments.
+                 These are usually passed to the model provider API call as hyperparameters for generation.
+
+         Returns:
+             List of texts returned from the visual language model provider API call.
+         """
+         raise NotImplementedError()
+
+ class PredictionGuardClient(BaseClient):
+
+     generate_kwargs = ['max_tokens',
+                        'temperature',
+                        'top_p',
+                        'top_k']
+
+     def filter_accepted_genkwargs(self, kwargs):
+         gen_args = {}
+         if "generate_kwargs" in kwargs and isinstance(kwargs["generate_kwargs"], dict):
+             gen_args = {k: kwargs["generate_kwargs"][k]
+                         for k in self.generate_kwargs
+                         if k in kwargs["generate_kwargs"]}
+         return gen_args
+
+     def generate(self,
+                  prompt: str,
+                  image: str,
+                  **kwargs
+                  ) -> str:
+         """Send a request to PredictionGuard's API
+         and return the text generated by the LLaVA model.
+
+         Use this method when you want to call the LLaVA model API to generate text without streaming.
+
+         Args:
+             prompt: A prompt.
+             image: A string that is either a path/URL to an image or a base64 encoding of an image.
+             **kwargs: Arbitrary additional keyword arguments.
+                 These are usually passed to the model provider API call as hyperparameters for generation.
+
+         Returns:
+             Text returned from the visual language model provider API call.
+         """
+
+         assert image is not None and image != "", "the input image cannot be empty; it must be either a base64-encoded image or a path/URL to an image"
+         if isBase64(image):
+             base64_image = image
+         else:  # this is a path or URL to an image
+             base64_image = encode_image_from_path_or_url(image)
+
+         args = self.filter_accepted_genkwargs(kwargs)
+         return lvlm_inference(prompt=prompt, image=base64_image, **args)
+
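A minimal usage sketch for the client added above. The image path and generation settings are illustrative, and it assumes the repo's utility module is importable and that Prediction Guard access is already configured for lvlm_inference.

# Hedged usage sketch for PredictionGuardClient (illustrative values only):
from mm_rag.MLM.client import PredictionGuardClient

client = PredictionGuardClient()                     # defaults to http://127.0.0.1:8090 unless url is given
answer = client.generate(
    prompt="What is shown in this image?",
    image="./frames/frame_0001.jpg",                 # hypothetical path; a URL or base64 string also works
    generate_kwargs={"max_tokens": 128, "temperature": 0.6},  # filtered by filter_accepted_genkwargs
)
print(answer)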
mm_rag/MLM/lvlm.py ADDED
@@ -0,0 +1,301 @@
+ from .client import PredictionGuardClient
+ from langchain_core.language_models.llms import LLM
+ from langchain_core.pydantic_v1 import Extra, root_validator
+ from typing import Any, Optional, List, Dict, Iterator, AsyncIterator
+ from langchain_core.callbacks import CallbackManagerForLLMRun
+ from utility import get_from_dict_or_env, MultimodalModelInput
+
+ from langchain_core.runnables import RunnableConfig, ensure_config
+ from langchain_core.language_models.base import LanguageModelInput
+ from langchain_core.prompt_values import StringPromptValue
+ # from langchain_core.outputs import GenerationChunk, LLMResult
+ from langchain_core.language_models.llms import BaseLLM
+ from langchain_core.callbacks import (
+     # CallbackManager,
+     CallbackManagerForLLMRun,
+ )
+ # from langchain_core.load import dumpd
+ from langchain_core.runnables.config import run_in_executor
+
+ class LVLM(LLM):
+     """This class extends the LLM class to implement custom requests to an LVLM provider API"""
+
+     client: Any = None  #: :meta private:
+     hostname: Optional[str] = None
+     port: Optional[int] = None
+     url: Optional[str] = None
+     max_new_tokens: Optional[int] = 200
+     temperature: Optional[float] = 0.6
+     top_k: Optional[float] = 0
+     stop: Optional[List[str]] = None
+     ignore_eos: Optional[bool] = False
+     do_sample: Optional[bool] = True
+     lazy_mode: Optional[bool] = True
+     hpu_graphs: Optional[bool] = True
+
+     @root_validator()
+     def validate_environment(cls, values: Dict) -> Dict:
+         """Validate that the access token and python package exist in the environment if needed"""
+         if values['client'] is None:
+             # check if the URL of the API is provided
+             url = get_from_dict_or_env(values, 'url', "VLM_URL", None)
+             if url is None:
+                 hostname = get_from_dict_or_env(values, 'hostname', 'VLM_HOSTNAME', None)
+                 port = get_from_dict_or_env(values, 'port', 'VLM_PORT', None)
+                 if hostname is not None and port is not None:
+                     values['client'] = PredictionGuardClient(hostname=hostname, port=port)
+                 else:
+                     # use the default hostname and port to create the client
+                     values['client'] = PredictionGuardClient()
+             else:
+                 values['client'] = PredictionGuardClient(url=url)
+         return values
+
+     @property
+     def _llm_type(self) -> str:
+         """Return the type of llm"""
+         return "Large Vision Language Model"
+
+     @property
+     def _default_params(self) -> Dict[str, Any]:
+         """Get the default parameters for calling the Prediction Guard API."""
+         return {
+             "max_tokens": self.max_new_tokens,
+             "temperature": self.temperature,
+             "top_k": self.top_k,
+             "ignore_eos": self.ignore_eos,
+             "do_sample": self.do_sample,
+             "stop": self.stop,
+         }
+
+     def get_params(self, **kwargs):
+         params = self._default_params
+         params.update(kwargs)
+         return params
+
+     def _call(
+         self,
+         prompt: str,
+         image: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> str:
+         """Run the VLM on the given input.
+
+         Args:
+             prompt: The prompt to generate from.
+             image: This can be either a path to an image or a base64 encoding of the image.
+             stop: Stop words to use when generating. Model output is cut off at the
+                 first occurrence of any of the stop substrings.
+                 If stop tokens are not supported, consider raising NotImplementedError.
+         Returns:
+             The model output as a string. The actual completion DOES NOT include the prompt.
+             Example: TBD
+         """
+         params = {}
+         if stop is not None:
+             raise ValueError("stop kwargs are not permitted.")
+         params['generate_kwargs'] = self.get_params(**kwargs)
+         response = self.client.generate(prompt=prompt, image=image, **params)
+         return response
+
+     def _stream(
+         self,
+         prompt: str,
+         image: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> Iterator[str]:
+         """Stream the VLM on the given prompt and image.
+
+         Args:
+             prompt: The prompt to generate from.
+             image: This can be either a path to an image or a base64 encoding of the image.
+             stop: Stop words to use when generating. Model output is cut off at the
+                 first occurrence of any of the stop substrings.
+                 If stop tokens are not supported, consider raising NotImplementedError.
+         Returns:
+             An iterator of strings. The actual completions DO NOT include the prompt.
+             Example: TBD
+         """
+         params = {}
+         params['generate_kwargs'] = self.get_params(**kwargs)
+         for chunk in self.client.generate_stream(prompt=prompt, image=image, **params):
+             yield chunk
+
+     async def _astream(
+         self,
+         prompt: str,
+         image: str,
+         stop: Optional[List[str]] = None,
+         run_manager: Optional[CallbackManagerForLLMRun] = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[str]:
+         """An async version of the _stream method that streams the VLM on the given prompt and image.
+
+         Args:
+             prompt: The prompt to generate from.
+             image: This can be either a path to an image or a base64 encoding of the image.
+             stop: Stop words to use when generating. Model output is cut off at the
+                 first occurrence of any of the stop substrings.
+                 If stop tokens are not supported, consider raising NotImplementedError.
+         Returns:
+             An async iterator of strings. The actual completions DO NOT include the prompt.
+             Example: TBD
+         """
+         iterator = await run_in_executor(
+             None,
+             self._stream,
+             prompt,
+             image,
+             stop,
+             run_manager.get_sync() if run_manager else None,
+             **kwargs,
+         )
+         done = object()
+         while True:
+             item = await run_in_executor(
+                 None,
+                 next,
+                 iterator,
+                 done,  # type: ignore[call-arg, arg-type]
+             )
+             if item is done:
+                 break
+             yield item  # type: ignore[misc]
+
+     def invoke(
+         self,
+         input: MultimodalModelInput,
+         config: Optional[RunnableConfig] = None,
+         *,
+         stop: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> str:
+         config = ensure_config(config)
+         if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
+             return (
+                 self.generate_prompt(
+                     [self._convert_input(StringPromptValue(text=input['prompt']))],
+                     stop=stop,
+                     callbacks=config.get("callbacks"),
+                     tags=config.get("tags"),
+                     metadata=config.get("metadata"),
+                     run_name=config.get("run_name"),
+                     run_id=config.pop("run_id", None),
+                     image=input['image'],
+                     **kwargs,
+                 )
+                 .generations[0][0]
+                 .text
+             )
+         return (
+             self.generate_prompt(
+                 [self._convert_input(input)],
+                 stop=stop,
+                 callbacks=config.get("callbacks"),
+                 tags=config.get("tags"),
+                 metadata=config.get("metadata"),
+                 run_name=config.get("run_name"),
+                 run_id=config.pop("run_id", None),
+                 **kwargs,
+             )
+             .generations[0][0]
+             .text
+         )
+
+     async def ainvoke(
+         self,
+         input: MultimodalModelInput,
+         config: Optional[RunnableConfig] = None,
+         *,
+         stop: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> str:
+         config = ensure_config(config)
+         if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
+             llm_result = await self.agenerate_prompt(
+                 [self._convert_input(StringPromptValue(text=input['prompt']))],
+                 stop=stop,
+                 callbacks=config.get("callbacks"),
+                 tags=config.get("tags"),
+                 metadata=config.get("metadata"),
+                 run_name=config.get("run_name"),
+                 run_id=config.pop("run_id", None),
+                 image=input['image'],
+                 **kwargs,
+             )
+         else:
+             llm_result = await self.agenerate_prompt(
+                 [self._convert_input(input)],
+                 stop=stop,
+                 callbacks=config.get("callbacks"),
+                 tags=config.get("tags"),
+                 metadata=config.get("metadata"),
+                 run_name=config.get("run_name"),
+                 run_id=config.pop("run_id", None),
+                 **kwargs,
+             )
+         return llm_result.generations[0][0].text
+
+     def stream(
+         self,
+         input: MultimodalModelInput,
+         config: Optional[RunnableConfig] = None,
+         *,
+         stop: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> Iterator[str]:
+         if type(self)._stream == BaseLLM._stream:
+             # model doesn't implement streaming, so use the default implementation
+             yield self.invoke(input, config=config, stop=stop, **kwargs)
+         else:
+             if stop is not None:
+                 raise ValueError("stop kwargs are not permitted.")
+             image = None
+             prompt = None
+             if isinstance(input, dict) and 'prompt' in input.keys():
+                 prompt = self._convert_input(input['prompt']).to_string()
+             else:
+                 raise ValueError("prompt must be provided")
+             if isinstance(input, dict) and 'image' in input.keys():
+                 image = input['image']
+
+             for chunk in self._stream(
+                 prompt=prompt, image=image, **kwargs
+             ):
+                 yield chunk
+
+     async def astream(
+         self,
+         input: LanguageModelInput,
+         config: Optional[RunnableConfig] = None,
+         *,
+         stop: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> AsyncIterator[str]:
+         if (
+             type(self)._astream is BaseLLM._astream
+             and type(self)._stream is BaseLLM._stream
+         ):
+             yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
+             return
+         else:
+             if stop is not None:
+                 raise ValueError("stop kwargs are not permitted.")
+             image = None
+             if isinstance(input, dict) and 'prompt' in input.keys() and 'image' in input.keys():
+                 prompt = self._convert_input(input['prompt']).to_string()
+                 image = input['image']
+             else:
+                 raise ValueError("both a prompt and an image must be provided")
+
+             async for chunk in self._astream(
+                 prompt=prompt, image=image, **kwargs
+             ):
+                 yield chunk
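A minimal usage sketch for the LVLM wrapper above, assuming the same environment as client.py; the prompt, image path, and generation settings are illustrative.

# Hedged usage sketch for LVLM (illustrative values only):
from mm_rag.MLM.lvlm import LVLM

lvlm = LVLM(max_new_tokens=150, temperature=0.5)     # the root validator builds a PredictionGuardClient
response = lvlm.invoke({
    "prompt": "Describe the scene in one sentence.",
    "image": "./frames/frame_0001.jpg",              # hypothetical path or a base64-encoded image
})
print(response)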
mm_rag/embeddings/__pycache__/bridgetower_embeddings.cpython-311.pyc ADDED
Binary file (3.23 kB).
 
mm_rag/embeddings/bridgetower_embeddings.py ADDED
@@ -0,0 +1,56 @@
+ from typing import List
+ from langchain_core.embeddings import Embeddings
+ from langchain_core.pydantic_v1 import (
+     BaseModel,
+ )
+ from utility import encode_image, bt_embedding_from_prediction_guard
+ from tqdm import tqdm
+
+ class BridgeTowerEmbeddings(BaseModel, Embeddings):
+     """BridgeTower embedding model"""
+
+     def embed_documents(self, texts: List[str]) -> List[List[float]]:
+         """Embed a list of documents using BridgeTower.
+
+         Args:
+             texts: The list of texts to embed.
+
+         Returns:
+             List of embeddings, one for each text.
+         """
+         embeddings = []
+         for text in texts:
+             embedding = bt_embedding_from_prediction_guard(text, "")
+             embeddings.append(embedding)
+         return embeddings
+
+     def embed_query(self, text: str) -> List[float]:
+         """Embed a query using BridgeTower.
+
+         Args:
+             text: The text to embed.
+
+         Returns:
+             Embeddings for the text.
+         """
+         return self.embed_documents([text])[0]
+
+     def embed_image_text_pairs(self, texts: List[str], images: List[str], batch_size=2) -> List[List[float]]:
+         """Embed a list of image-text pairs using BridgeTower.
+
+         Args:
+             texts: The list of texts to embed.
+             images: The list of paths to the images to embed.
+             batch_size: The batch size to process; defaults to 2.
+         Returns:
+             List of embeddings, one for each image-text pair.
+         """
+
+         # the number of texts must equal the number of images
+         assert len(texts) == len(images), "the number of captions should be equal to the number of images"
+
+         embeddings = []
+         for path_to_img, text in tqdm(zip(images, texts), total=len(texts)):
+             embedding = bt_embedding_from_prediction_guard(text, encode_image(path_to_img))
+             embeddings.append(embedding)
+         return embeddings
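A minimal usage sketch for the embedding class above; the texts and image path are illustrative, and it assumes Prediction Guard access for bt_embedding_from_prediction_guard.

# Hedged usage sketch for BridgeTowerEmbeddings (illustrative values only):
from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings

embedder = BridgeTowerEmbeddings()
text_vectors = embedder.embed_documents(["a cat on a sofa", "a red car"])
pair_vectors = embedder.embed_image_text_pairs(
    texts=["a cat on a sofa"],
    images=["./images/cat.jpg"],                     # hypothetical path
)
print(len(text_vectors), len(pair_vectors))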
mm_rag/vectorstores/multimodal_lancedb.py ADDED
@@ -0,0 +1,131 @@
+ from typing import Any, Iterable, List, Optional
+ from langchain_core.embeddings import Embeddings
+ import uuid
+ from langchain_community.vectorstores.lancedb import LanceDB
+
+ class MultimodalLanceDB(LanceDB):
+     """`LanceDB` vector store to process multimodal data
+
+     To use, you should have the ``lancedb`` python package installed.
+     You can install it with ``pip install lancedb``.
+
+     Args:
+         connection: LanceDB connection to use. If not provided, a new connection
+             will be created.
+         embedding: Embedding to use for the vectorstore.
+         vector_key: Key to use for the vector in the database. Defaults to ``vector``.
+         id_key: Key to use for the id in the database. Defaults to ``id``.
+         text_key: Key to use for the text in the database. Defaults to ``text``.
+         image_path_key: Key to use for the path to the image in the database. Defaults to ``image_path``.
+         table_name: Name of the table to use. Defaults to ``vectorstore``.
+         api_key: API key to use for the LanceDB cloud database.
+         region: Region to use for the LanceDB cloud database.
+         mode: Mode to use for adding data to the table. Defaults to ``append``.
+
+     Example:
+         .. code-block:: python
+
+             vectorstore = MultimodalLanceDB(uri='/lancedb', embedding_function)
+             vectorstore.add_texts(['text1', 'text2'])
+             result = vectorstore.similarity_search('text1')
+     """
+
+     def __init__(
+         self,
+         connection: Optional[Any] = None,
+         embedding: Optional[Embeddings] = None,
+         uri: Optional[str] = "/tmp/lancedb",
+         vector_key: Optional[str] = "vector",
+         id_key: Optional[str] = "id",
+         text_key: Optional[str] = "text",
+         image_path_key: Optional[str] = "image_path",
+         table_name: Optional[str] = "vectorstore",
+         api_key: Optional[str] = None,
+         region: Optional[str] = None,
+         mode: Optional[str] = "append",
+     ):
+         super(MultimodalLanceDB, self).__init__(connection, embedding, uri, vector_key, id_key, text_key, table_name, api_key, region, mode)
+         self._image_path_key = image_path_key
+
+     def add_text_image_pairs(
+         self,
+         texts: Iterable[str],
+         image_paths: Iterable[str],
+         metadatas: Optional[List[dict]] = None,
+         ids: Optional[List[str]] = None,
+         **kwargs: Any,
+     ) -> List[str]:
+         """Turn text-image pairs into embeddings and add them to the database.
+
+         Args:
+             texts: Iterable of strings to combine with the corresponding images and add to the vectorstore.
+             image_paths: Iterable of paths to images (as strings) to combine with the corresponding texts and add to the vectorstore.
+             metadatas: Optional list of metadatas associated with the texts.
+             ids: Optional list of ids to associate with the texts.
+
+         Returns:
+             List of ids of the added text-image pairs.
+         """
+         # the number of texts must equal the number of images
+         assert len(texts) == len(image_paths), "the number of texts should be equal to the number of image paths"
+
+         # Embed texts and create documents
+         docs = []
+         ids = ids or [str(uuid.uuid4()) for _ in texts]
+         embeddings = self._embedding.embed_image_text_pairs(texts=list(texts), images=list(image_paths))  # type: ignore
+         for idx, text in enumerate(texts):
+             embedding = embeddings[idx]
+             metadata = metadatas[idx] if metadatas else {"id": ids[idx]}
+             docs.append(
+                 {
+                     self._vector_key: embedding,
+                     self._id_key: ids[idx],
+                     self._text_key: text,
+                     self._image_path_key: image_paths[idx],
+                     "metadata": metadata,
+                 }
+             )
+
+         if 'mode' in kwargs:
+             mode = kwargs['mode']
+         else:
+             mode = self.mode
+         if self._table_name in self._connection.table_names():
+             tbl = self._connection.open_table(self._table_name)
+             if self.api_key is None:
+                 tbl.add(docs, mode=mode)
+             else:
+                 tbl.add(docs)
+         else:
+             self._connection.create_table(self._table_name, data=docs)
+         return ids
+
+     @classmethod
+     def from_text_image_pairs(
+         cls,
+         texts: List[str],
+         image_paths: List[str],
+         embedding: Embeddings,
+         metadatas: Optional[List[dict]] = None,
+         connection: Any = None,
+         vector_key: Optional[str] = "vector",
+         id_key: Optional[str] = "id",
+         text_key: Optional[str] = "text",
+         image_path_key: Optional[str] = "image_path",
+         table_name: Optional[str] = "vectorstore",
+         **kwargs: Any,
+     ):
+
+         instance = MultimodalLanceDB(
+             connection=connection,
+             embedding=embedding,
+             vector_key=vector_key,
+             id_key=id_key,
+             text_key=text_key,
+             image_path_key=image_path_key,
+             table_name=table_name,
+         )
+         instance.add_text_image_pairs(texts, image_paths, metadatas=metadatas, **kwargs)
+
+         return instance
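A minimal end-to-end sketch tying the new pieces together; the LanceDB URI, table name, texts, and image paths are illustrative, and it assumes the same Prediction Guard setup as the embedding class above.

# Hedged usage sketch for MultimodalLanceDB (illustrative values only):
import lancedb
from mm_rag.embeddings.bridgetower_embeddings import BridgeTowerEmbeddings
from mm_rag.vectorstores.multimodal_lancedb import MultimodalLanceDB

db = lancedb.connect("/tmp/lancedb")                 # local LanceDB database
vectorstore = MultimodalLanceDB.from_text_image_pairs(
    texts=["a cat on a sofa", "a red car"],
    image_paths=["./images/cat.jpg", "./images/car.jpg"],   # hypothetical paths
    embedding=BridgeTowerEmbeddings(),
    connection=db,
    table_name="demo_table",
)
results = vectorstore.similarity_search("cat", k=1)  # returns LangChain Documents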