import torch
import kelip
import gradio as gr

def load_model():
    # Build the KELIP ViT-B/32 model and move it to GPU if one is available.
    model, preprocess_img, tokenizer = kelip.build_model("ViT-B/32")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = model.to(device)
    model.eval()
    model_dict = {
        "model": model,
        "preprocess_img": preprocess_img,
        "tokenizer": tokenizer,
    }
    return model_dict

def classify(img, user_text):
    # Preprocess the image and move it to the same device as the model.
    preprocess_img = model_dict["preprocess_img"]
    input_img = preprocess_img(img).unsqueeze(0)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    input_img = input_img.to(device)

    # Parse the comma-separated candidate labels; bail out if none were given.
    user_texts = [t.strip() for t in user_text.split(",") if t.strip()]
    if not user_texts:
        return {}

    input_texts = model_dict["tokenizer"].encode(user_texts)
    if torch.cuda.is_available():
        input_texts = input_texts.cuda()

    # Extract image and text features without tracking gradients.
    with torch.no_grad():
        image_features = model_dict["model"].encode_image(input_img)
        text_features = model_dict["model"].encode_text(input_texts)

    # L2-normalize the features, then score labels by scaled cosine similarity.
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
    values, indices = similarity[0].topk(len(user_texts))

    result = {}
    for value, index in zip(values, indices):
        result[user_texts[index]] = value.item()
    return result
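
# A note on the scoring above: with unit-norm features, image @ text.T is the
# cosine similarity, scaled by 100 (the usual CLIP-style logit scale) and
# softmaxed into a probability distribution over the candidate labels.
# A toy sketch with random tensors (the 512-dim width is illustrative only):
#
#   img_f = torch.nn.functional.normalize(torch.randn(1, 512), dim=-1)
#   txt_f = torch.nn.functional.normalize(torch.randn(3, 512), dim=-1)
#   probs = (100.0 * img_f @ txt_f.T).softmax(dim=-1)  # shape (1, 3), rows sum to 1
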
if __name__ == "__main__":
    model_dict = load_model()

    inputs = [
        gr.inputs.Image(type="pil", label="Image"),
        gr.inputs.Textbox(lines=5, label="Caption"),
    ]
    outputs = ["label"]

    title = "KELIP"
    description = "Zero-shot classification with KELIP -- a Korean and English bilingual contrastive Language-Image Pre-training model trained on 1.1 billion collected image-text pairs (708 million Korean and 476 million English).<br> <br><a href='https://arxiv.org/abs/2203.14463' target='_blank'>Arxiv</a> | <a href='https://github.com/navervision/KELIP' target='_blank'>Github</a>"
    article = ""

    iface = gr.Interface(
        fn=classify,
        inputs=inputs,
        outputs=outputs,
        title=title,
        description=description,
        article=article,
    )
    iface.launch()
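
# Example of calling classify() directly (a hypothetical local run, assuming
# kelip and its weights are installed; image path and labels are made up):
#
#   model_dict = load_model()
#   from PIL import Image
#   print(classify(Image.open("example.jpg"), "a cat, a dog, a bird"))
#   # -> a dict mapping each label to its softmax probability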