Fangrui Liu committed
Commit 0b449a5
1 Parent(s): b73f599

add selective db / feat / lang

Files changed (1):
  app.py  +90 -55
app.py CHANGED

@@ -3,7 +3,7 @@ import numpy as np
 import base64
 from io import BytesIO
 from multilingual_clip import pt_multilingual_clip
-from transformers import CLIPTokenizerFast, AutoTokenizer
+from transformers import CLIPTokenizerFast, AutoTokenizer, CLIPModel
 import torch
 import logging
 from os import environ
@@ -12,30 +12,22 @@ environ['TOKENIZERS_PARALLELISM'] = 'true'
 
 
 db_name_map = {
-    "Unsplash Photos 25K": "mqdb_demo.unsplash_25k_clip_indexer",
-    "RSICD: Remote Sensing Images 11K": "mqdb_demo.rsicd_clip_b_32",
+    "Unsplash Photos 25K": lambda feat: f"mqdb_demo.unsplash_25k_{feat}_indexer",
+    "RSICD: Remote Sensing Images 11K": lambda feat: f"mqdb_demo.rsicd_{feat}_b_32",
 }
+feat_name_map = {
+    'Vanilla CLIP': "clip",
+    'CLIP finetuned on RSICD': "cliprsicd"
+}
+
 
 DB_NAME = "mqdb_demo.unsplash_25k_clip_indexer"
-MODEL_ID = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
 DIMS = 512
 # Ignore some bad links (broken in the dataset already)
 BAD_IDS = {'9_9hzZVjV8s', 'RDs0THr4lGs', 'vigsqYux_-8',
            'rsJtMXn3p_c', 'AcG-unN00gw', 'r1R_0ZNUcx0'}
 
 
-@st.experimental_singleton(show_spinner=False)
-def init_clip():
-    """ Initialize CLIP Model
-
-    Returns:
-        Tokenizer: CLIPTokenizerFast (which convert words into embeddings)
-    """
-    clip = pt_multilingual_clip.MultilingualCLIP.from_pretrained(MODEL_ID)
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
-    return tokenizer, clip
-
-
 @st.experimental_singleton(show_spinner=False)
 def init_db():
     """ Initialize the Database Connection
@@ -82,15 +74,15 @@ def query(xq, top_k=10):
         # Using PREWHERE allows you to do column filter before vector search
         xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
             distance('topK={top_k}')(vector, {xq_s}) AS dist\
-            FROM {db_name_map[st.session_state.db_name_ref]} \
+            FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])} \
            PREWHERE id NOT IN ({exclude_list})")
     else:
         xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
             distance('topK={top_k}')(vector, {xq_s}) AS dist\
-            FROM {db_name_map[st.session_state.db_name_ref]}")
+            FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])}")
     real_xc = st.session_state.index.fetch(f"SELECT id, url, vector,\
         distance('topK={top_k}')(vector, {xq_s}) AS dist\
-        FROM {db_name_map[st.session_state.db_name_ref]}")
+        FROM {db_name_map[st.session_state.db_name_ref](feat_name_map[st.session_state.feat_name])}")
     top_k = real_xc
     xc = [xi for xi in xc if xi['id'] not in st.session_state.meta or
           st.session_state.meta[xi['id']] < 1]
@@ -166,38 +158,6 @@ class NormalizingLayer(torch.nn.Module):
         return x / torch.norm(x, dim=-1, keepdim=True)
 
 
-def prompt2vec(prompt: str):
-    """ Convert prompt into a computational vector
-
-    Args:
-        prompt (str): Text to be tokenized
-
-    Returns:
-        xq: vector from the tokenizer, representing the original prompt
-    """
-    # inputs = tokenizer(prompt, return_tensors='pt')
-    # out = clip.get_text_features(**inputs)
-    out = clip.forward(prompt, tokenizer)
-    xq = out.squeeze(0).cpu().detach().numpy().tolist()
-    return xq
-
-
-def pil_to_bytes(img):
-    """ Convert a Pillow image into base64
-
-    Args:
-        img (PIL.Image): Pillow (PIL) Image
-
-    Returns:
-        img_bin: image in base64 format
-    """
-    with BytesIO() as buf:
-        img.save(buf, format='jpeg')
-        img_bin = buf.getvalue()
-        img_bin = base64.b64encode(img_bin).decode('utf-8')
-    return img_bin
-
-
 def card(i, url):
     return f'<img id="img{i}" src="{url}" width="200px;">'
 
@@ -286,6 +246,63 @@ def delete_element(element):
     del element
 
 
+@st.experimental_singleton(show_spinner=False)
+def init_clip_mlang():
+    """ Initialize CLIP Model
+
+    Returns:
+        Tokenizer: CLIPTokenizerFast (which convert words into embeddings)
+    """
+    MODEL_ID = 'M-CLIP/XLM-Roberta-Large-Vit-B-32'
+    clip = pt_multilingual_clip.MultilingualCLIP.from_pretrained(MODEL_ID)
+    tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
+    return tokenizer, clip
+
+@st.experimental_singleton(show_spinner=False)
+def init_clip_vanilla():
+    """ Initialize CLIP Model
+
+    Returns:
+        Tokenizer: CLIPTokenizerFast (which convert words into embeddings)
+    """
+    MODEL_ID = "openai/clip-vit-base-patch32"
+    tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
+    clip = CLIPModel.from_pretrained(MODEL_ID)
+    return tokenizer, clip
+
+
+@st.experimental_singleton(show_spinner=False)
+def init_clip_rsicd():
+    """ Initialize CLIP Model
+
+    Returns:
+        Tokenizer: CLIPTokenizerFast (which convert words into embeddings)
+    """
+    MODEL_ID = "flax-community/clip-rsicd"
+    tokenizer = CLIPTokenizerFast.from_pretrained(MODEL_ID)
+    clip = CLIPModel.from_pretrained(MODEL_ID)
+    return tokenizer, clip
+
+
+def prompt2vec_mlang(prompt: str, tokenizer, clip):
+    """ Convert prompt into a computational vector
+
+    Args:
+        prompt (str): Text to be tokenized
+
+    Returns:
+        xq: vector from the tokenizer, representing the original prompt
+    """
+    out = clip.forward(prompt, tokenizer)
+    xq = out.squeeze(0).cpu().detach().numpy().tolist()
+    return xq
+
+def prompt2vec_vanilla(prompt: str, tokenizer, clip):
+    inputs = tokenizer(prompt, return_tensors='pt')
+    out = clip.get_text_features(**inputs)
+    xq = out.squeeze(0).cpu().detach().numpy().tolist()
+    return xq
+
 st.markdown("""
     <link
         rel="stylesheet"
@@ -323,13 +340,23 @@ messages = [
     """
 ]
 
+text_model_map = {
+    'Multi Lingual': {'Vanilla CLIP': [prompt2vec_mlang, ]},
+    'English': {'Vanilla CLIP': [prompt2vec_vanilla, ],
+                'CLIP finetuned on RSICD': [prompt2vec_vanilla, ],
+                }
+}
+
+
 with st.spinner("Connecting DB..."):
     st.session_state.meta, st.session_state.index = init_db()
 
 with st.spinner("Loading Models..."):
     # Initialize CLIP model
     if 'xq' not in st.session_state:
-        tokenizer, clip = init_clip()
+        text_model_map['Multi Lingual']['Vanilla CLIP'].append(init_clip_mlang())
+        text_model_map['English']['Vanilla CLIP'].append(init_clip_vanilla())
+        text_model_map['English']['CLIP finetuned on RSICD'].append(init_clip_rsicd())
         st.session_state.query_num = 0
 
 if 'xq' not in st.session_state:
@@ -347,8 +374,15 @@ if 'xq' not in st.session_state:
     start = [st.empty(), st.empty(), st.empty(), st.empty(),
              st.empty(), st.empty(), st.empty()]
     start[0].info(msg)
-    st.session_state.db_name_ref = start[1].selectbox(
-        "Select Database:", list(db_name_map.keys()))
+    start_col = start[1].columns(3)
+    st.session_state.db_name_ref = start_col[0].selectbox("Select Database:", list(db_name_map.keys()))
+    st.session_state.lang = start_col[1].selectbox("Select Language:", list(text_model_map.keys()))
+    st.session_state.feat_name = start_col[2].selectbox("Select Image Feature:",
+                                                        list(text_model_map[st.session_state.lang].keys()))
+    if st.session_state.db_name_ref == "RSICD: Remote Sensing Images 11K":
+        st.warning('If you are searching for Remote Sensing Images, \
+            try to use prompt "An aerial photograph of <your-real-query>" \
+            to obtain best search experience!')
     prompt = start[2].text_input(
         "Prompt:", value="", placeholder="Examples: playing corgi, 女人举着雨伞, mouette volant au-dessus de la mer, ガラスの花瓶の花 ...")
     if len(prompt) > 0:
@@ -388,7 +422,8 @@ if 'xq' not in st.session_state:
     else:
         print(f"Input prompt is {prompt}")
        # Tokenize the vectors
-        xq = prompt2vec(prompt)
+        p2v_func, args = text_model_map[st.session_state.lang][st.session_state.feat_name]
+        xq = p2v_func(prompt, *args)
         st.session_state.xq = xq
         st.session_state.orig_xq = xq
         _ = [elem.empty() for elem in start]
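
Taken together, the "selective db" half of the commit works by composing two small lookup tables: db_name_map now holds callables that take a feature tag, and feat_name_map turns the UI's feature choice into that tag. A minimal sketch of how query() resolves its FROM table after this change (both maps are copied from the diff above; the chosen keys are just one example):

    # Maps as introduced by the commit.
    db_name_map = {
        "Unsplash Photos 25K": lambda feat: f"mqdb_demo.unsplash_25k_{feat}_indexer",
        "RSICD: Remote Sensing Images 11K": lambda feat: f"mqdb_demo.rsicd_{feat}_b_32",
    }
    feat_name_map = {
        'Vanilla CLIP': "clip",
        'CLIP finetuned on RSICD': "cliprsicd",
    }

    # Resolve the table name the same way the rewritten FROM clause does.
    table = db_name_map["RSICD: Remote Sensing Images 11K"](
        feat_name_map["CLIP finetuned on RSICD"])
    print(table)  # -> mqdb_demo.rsicd_cliprsicd_b_32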
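
With the table resolved, the statement query() sends is structurally unchanged; only the FROM target varies. Roughly, the assembled SQL looks like the sketch below, where the vector and id values are illustrative placeholders and the distance('topK=...')/PREWHERE syntax is taken from the diff:

    top_k = 10
    xq_s = "[0.12, -0.03, ...]"                    # placeholder: serialized query vector
    exclude_list = "'9_9hzZVjV8s', 'RDs0THr4lGs'"  # placeholder: ids to skip
    table = "mqdb_demo.unsplash_25k_clip_indexer"

    # PREWHERE filters rows by id before the vector search runs.
    sql = (f"SELECT id, url, vector, "
           f"distance('topK={top_k}')(vector, {xq_s}) AS dist "
           f"FROM {table} "
           f"PREWHERE id NOT IN ({exclude_list})")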
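
On the model side, text_model_map keys encoders by language and feature: each entry starts as a one-element list holding a prompt2vec_* function, and the init_clip_* loaders append their (tokenizer, model) pair at startup. A self-contained sketch of that dispatch, with a stub encoder and string "models" standing in for the real CLIP objects:

    def prompt2vec_stub(prompt, tokenizer, clip):
        # Stands in for prompt2vec_mlang / prompt2vec_vanilla.
        return [0.0] * 512  # a DIMS-sized vector in the real app

    text_model_map = {
        'Multi Lingual': {'Vanilla CLIP': [prompt2vec_stub, ]},
        'English': {'Vanilla CLIP': [prompt2vec_stub, ],
                    'CLIP finetuned on RSICD': [prompt2vec_stub, ]},
    }

    # Startup: each init_clip_* singleton appends its (tokenizer, model) pair.
    text_model_map['English']['CLIP finetuned on RSICD'].append(('tok', 'clip-rsicd'))

    # Query time: unpack [fn, (tokenizer, model)] and call the encoder.
    p2v_func, args = text_model_map['English']['CLIP finetuned on RSICD']
    xq = p2v_func("an aerial photograph of an airport", *args)

Because each init_clip_* loader is wrapped in st.experimental_singleton, the appended pairs come from a per-process cache, so Streamlit reruns do not reload the models.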