renatotn7 aps committed on
Commit
52ef2c1
·
0 Parent(s):

Duplicate from flava/flava-multimodal-zero-shot

Browse files

Co-authored-by: Amanpreet Singh <aps@users.noreply.huggingface.co>

Files changed (10) hide show
  1. .gitattributes +30 -0
  2. .gitignore +2 -0
  3. README.md +13 -0
  4. app.py +131 -0
  5. cows.jpg +0 -0
  6. dog.jpg +3 -0
  7. germany.jpg +3 -0
  8. requirements.txt +4 -0
  9. rocket.jpg +3 -0
  10. sofa.jpg +0 -0
.gitattributes ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bin.* filter=lfs diff=lfs merge=lfs -text
5
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.model filter=lfs diff=lfs merge=lfs -text
12
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
13
+ *.onnx filter=lfs diff=lfs merge=lfs -text
14
+ *.ot filter=lfs diff=lfs merge=lfs -text
15
+ *.parquet filter=lfs diff=lfs merge=lfs -text
16
+ *.pb filter=lfs diff=lfs merge=lfs -text
17
+ *.pt filter=lfs diff=lfs merge=lfs -text
18
+ *.pth filter=lfs diff=lfs merge=lfs -text
19
+ *.rar filter=lfs diff=lfs merge=lfs -text
20
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
21
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
22
+ *.tflite filter=lfs diff=lfs merge=lfs -text
23
+ *.tgz filter=lfs diff=lfs merge=lfs -text
24
+ *.xz filter=lfs diff=lfs merge=lfs -text
25
+ *.zip filter=lfs diff=lfs merge=lfs -text
26
+ *.zstandard filter=lfs diff=lfs merge=lfs -text
27
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
28
+ dog.jpg filter=lfs diff=lfs merge=lfs -text
29
+ germany.jpg filter=lfs diff=lfs merge=lfs -text
30
+ rocket.jpg filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ venv/
2
+ transformers/
README.md ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ title: FLAVA MultiModal Zero Shot
3
+ emoji: 🤖
4
+ colorFrom: red
5
+ colorTo: gray
6
+ sdk: gradio
7
+ sdk_version: 3.0.5
8
+ app_file: app.py
9
+ pinned: false
10
+ duplicated_from: flava/flava-multimodal-zero-shot
11
+ ---
12
+
13
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces#reference
app.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import torch
4
+
5
+ from transformers import BertTokenizer, FlavaForPreTraining, FlavaModel, FlavaFeatureExtractor, FlavaProcessor
6
+ from PIL import Image
7
+
8
+
9
# All FLAVA components below are loaded from this one Hugging Face checkpoint.
_FLAVA_CHECKPOINT = "facebook/flava-full"

# Top-level Gradio container; the UI is assembled inside `with demo:` below.
demo = gr.Blocks()

# Text tokenizer used by every zero-shot function.
tokenizer = BertTokenizer.from_pretrained(_FLAVA_CHECKPOINT)
# Pre-training model: used for masked-token prediction (text and multimodal).
flava_pt = FlavaForPreTraining.from_pretrained(_FLAVA_CHECKPOINT)
# Base model: used for image/text embedding similarity.
flava = FlavaModel.from_pretrained(_FLAVA_CHECKPOINT)
# Joint image+text preprocessor for the multimodal path.
processor = FlavaProcessor.from_pretrained(_FLAVA_CHECKPOINT)
# Image-only preprocessor for the image classification path.
fe = FlavaFeatureExtractor.from_pretrained(_FLAVA_CHECKPOINT)


# Name of the output attribute holding the logits read by `zero_shot_text`.
PREDICTION_ATTR = "mlm_logits"
19
+
20
def zero_shot_text(text, options):
    """Rank candidate words for the ``[MASK]`` slot(s) of a text prompt.

    Args:
        text: Prompt string containing at least one ``[MASK]`` token.
        options: Candidate words separated by ``;``. Each candidate is looked
            up with ``tokenizer.convert_tokens_to_ids``, i.e. as a single
            vocabulary token (multi-word candidates presumably map to the
            unknown token -- verify against the tokenizer).

    Returns:
        Dict mapping each candidate to its probability, obtained by a softmax
        over the candidates' logits averaged across all mask positions.

    Raises:
        ValueError: If no non-empty option is given or the prompt contains no
            mask token.
    """
    # Drop empty entries so a trailing ";" does not create a bogus "" label.
    options = [option.strip() for option in options.split(";")]
    options = [option for option in options if option]
    if not options:
        raise ValueError("Provide at least one option, separated by ';'.")

    option_indices = tokenizer.convert_tokens_to_ids(options)
    tokens = tokenizer([text], return_tensors="pt")

    # Ask the tokenizer for the mask id instead of hard-coding 103.
    mask_ids = tokens["input_ids"][0] == tokenizer.mask_token_id
    if not mask_ids.any():
        raise ValueError(
            f"The prompt must contain at least one {tokenizer.mask_token} token."
        )

    with torch.no_grad():
        output = flava_pt(**tokens)

    text_logits = getattr(output, PREDICTION_ATTR)
    # Select the mask rows first and the option columns second. Indexing with
    # both at once broadcasts the two index lists against each other, which
    # fails (or silently picks diagonal pairs) when there are several masks.
    masked_logits = text_logits[0][mask_ids][:, option_indices]
    probs = torch.nn.functional.softmax(masked_logits.mean(dim=0), dim=-1)
    return {label: probs[idx].item() for idx, label in enumerate(options)}
32
+
33
+
34
def zero_shot_image(image, options):
    """Score an image against a list of textual class descriptions.

    Args:
        image: Image as an array-like accepted by ``np.uint8`` /
            ``Image.fromarray`` (the Gradio image input).
        options: Class descriptions separated by ``;``.

    Returns:
        Dict mapping each description to its probability, obtained by a
        softmax over the dot products between the text and image embeddings.

    Raises:
        ValueError: If no image or no non-empty description is given.
    """
    if image is None:
        raise ValueError("Provide an image to classify.")

    # Drop empty entries so a trailing ";" does not create a bogus "" label.
    labels = [label.strip() for label in options.split(";")]
    labels = [label for label in labels if label]
    if not labels:
        raise ValueError("Provide at least one class, separated by ';'.")

    PIL_image = Image.fromarray(np.uint8(image)).convert("RGB")
    image_input = fe([PIL_image], return_tensors="pt")
    text_inputs = tokenizer(labels, padding="max_length", return_tensors="pt")

    # Inference only: skip autograd bookkeeping, like the other zero-shot
    # functions in this file do.
    with torch.no_grad():
        # Keep only position 0 of each sequence of features.
        image_embeddings = flava.get_image_features(**image_input)[:, 0, :]
        text_embeddings = flava.get_text_features(**text_inputs)[:, 0, :]

    # One row per label and one column for the single image; drop the column
    # axis to get one score per label.
    scores = (text_embeddings @ image_embeddings.T).squeeze(-1)
    probs = torch.nn.functional.softmax(scores, dim=0)
    return {label: probs[idx].item() for idx, label in enumerate(labels)}
50
+
51
def zero_shot_multimodal(image, text, options):
    """Rank candidate words for the ``[MASK]`` slot(s) of a prompt about an image.

    Args:
        image: Image passed as-is to the FLAVA processor.
        text: Prompt string containing at least one ``[MASK]`` token.
        options: Candidate words separated by ``;``. Each candidate is looked
            up with ``tokenizer.convert_tokens_to_ids``, i.e. as a single
            vocabulary token.

    Returns:
        Dict mapping each candidate to its probability, obtained by a softmax
        over the candidates' ``mmm_text_logits`` averaged across all mask
        positions.

    Raises:
        ValueError: If the image is missing, no non-empty option is given, or
            the prompt contains no mask token.
    """
    if image is None:
        raise ValueError("Provide an image for multimodal classification.")

    # Drop empty entries so a trailing ";" does not create a bogus "" label.
    options = [option.strip() for option in options.split(";")]
    options = [option for option in options if option]
    if not options:
        raise ValueError("Provide at least one option, separated by ';'.")

    option_indices = tokenizer.convert_tokens_to_ids(options)
    tokens = processor(
        [image],
        [text],
        return_tensors="pt",
        return_codebook_pixels=True,
        return_image_mask=True,
    )

    # Ask the tokenizer for the mask id instead of hard-coding 103.
    mask_ids = tokens["input_ids"][0] == tokenizer.mask_token_id
    if not mask_ids.any():
        raise ValueError(
            f"The prompt must contain at least one {tokenizer.mask_token} token."
        )

    # Replace the processor's image mask with an all-ones mask.
    # NOTE(review): the reason for masking every position is not visible
    # here -- confirm against the FLAVA pre-training forward pass.
    tokens["bool_masked_pos"] = torch.ones_like(tokens["bool_masked_pos"])

    with torch.no_grad():
        output = flava_pt(**tokens)

    text_logits = output.mmm_text_logits
    # Select the mask rows first and the option columns second. Indexing with
    # both at once broadcasts the two index lists against each other, which
    # fails (or silently picks diagonal pairs) when there are several masks.
    masked_logits = text_logits[0][mask_ids][:, option_indices]
    probs = torch.nn.functional.softmax(masked_logits.mean(dim=0), dim=-1)
    return {label: probs[idx].item() for idx, label in enumerate(options)}
66
+
67
# Build the three-tab UI inside the Blocks container, wire the buttons and
# example tables to the zero-shot functions, then start the app.
with demo:
    # The Markdown text is kept flush-left so it cannot be parsed as an
    # indented code block.
    gr.Markdown(
        """
# Zero-Shot image, text or multimodal classification using the same FLAVA model

Click on one of the examples provided to load them into the UI and "Classify".

- For image classification, provide class options to be ranked separated by `;`.
- For text and multimodal classification, provide your 1) prompt with the word you want to be filled in as `[MASK]`, and 2) possible options to be ranked separated by `;`.
"""
    )
    with gr.Tabs():
        with gr.TabItem("Zero-Shot Image Classification"):
            with gr.Row():
                with gr.Column():
                    image_input = gr.Image()
                    text_options_i = gr.Textbox(label="Classes (separated by ;)")
                    image_button = gr.Button("Classify")
                    image_dataset = gr.Dataset(
                        components=[image_input, text_options_i],
                        samples=[
                            ["cows.jpg", "a cow; two cows in a green field; a cow in a green field"],
                            ["sofa.jpg", "a room with red sofa; a red room with sofa; ladder in a room"],
                        ],
                    )

                labels_image = gr.Label(label="Probabilities")
        with gr.TabItem("Zero-Shot Text Classification"):
            with gr.Row():
                with gr.Column():
                    text_input = gr.Textbox(label="Prompt")
                    text_options = gr.Textbox(label="Label options (separate by ;)")
                    text_button = gr.Button("Classify")
                    text_dataset = gr.Dataset(
                        components=[text_input, text_options],
                        samples=[
                            ["by far the worst movie of the year. This was [MASK]", "negative; positive"],
                            ["Lord Voldemort -- in the films; born Tom Marvolo Riddle) is a fictional character and the main antagonist in J.K. Rowling's series of Harry Potter novels. Voldemort first appeared in Harry Potter and the Philosopher's Stone, which was released in 1997. Voldemort appears either in person or in flashbacks in each book and its film adaptation in the series, except the third, Harry Potter and the Prisoner of Azkaban, where he is only mentioned. Question: are tom riddle and lord voldemort the same person? Answer: [MASK]", "no; yes"],
                        ],
                    )
                labels_text = gr.Label(label="Probabilities")
        with gr.TabItem("Zero-Shot MultiModal Classification"):
            with gr.Row():
                with gr.Column():
                    image_input_mm = gr.Image()
                    text_input_mm = gr.Textbox(label="Prompt")
                    text_options_mm = gr.Textbox(label="Options (separate by ;)")
                    multimodal_button = gr.Button("Classify")
                    # Each sample has three columns (image, prompt, options),
                    # so all three components must be listed; they are also
                    # the three outputs of the dataset click handler below.
                    multimodal_dataset = gr.Dataset(
                        components=[image_input_mm, text_input_mm, text_options_mm],
                        samples=[
                            ["cows.jpg", "What animals are in the field? They are [MASK].", "cows; lions; sheep; monkeys"],
                            ["sofa.jpg", "What furniture is in the room? It is [MASK].", "sofa; ladder; bucket"],
                        ],
                    )
                labels_multimodal = gr.Label(label="Probabilities")

    # "Classify" buttons run the matching zero-shot function.
    text_button.click(zero_shot_text, inputs=[text_input, text_options], outputs=labels_text)
    image_button.click(zero_shot_image, inputs=[image_input, text_options_i], outputs=labels_image)
    multimodal_button.click(
        zero_shot_multimodal,
        inputs=[image_input_mm, text_input_mm, text_options_mm],
        outputs=labels_multimodal,
    )
    # Clicking an example passes the dataset's value through unchanged to the
    # corresponding input components.
    text_dataset.click(lambda a: a, inputs=[text_dataset], outputs=[text_input, text_options])
    image_dataset.click(lambda a: a, inputs=[image_dataset], outputs=[image_input, text_options_i])
    multimodal_dataset.click(
        lambda a: a,
        inputs=[multimodal_dataset],
        outputs=[image_input_mm, text_input_mm, text_options_mm],
    )

demo.launch()
cows.jpg ADDED
dog.jpg ADDED

Git LFS Details

  • SHA256: fd74da609925c348d16a1d5cfdd289f736ed20382ecccdc71195fb78d0c53d94
  • Pointer size: 130 Bytes
  • Size of remote file: 54.5 kB
germany.jpg ADDED

Git LFS Details

  • SHA256: 61ddcea0c3eb97c4511d4b0162e1b80ac7f05f5c1fba3ceb5a6fde7073066869
  • Pointer size: 130 Bytes
  • Size of remote file: 10.6 kB
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers
2
+ torch
3
+ numpy
4
+ pillow
rocket.jpg ADDED

Git LFS Details

  • SHA256: 160c61f4820aff99965bf024632136ee872f493ba33b642f29b6ca0ebdd186ca
  • Pointer size: 131 Bytes
  • Size of remote file: 941 kB
sofa.jpg ADDED