Chintan-Shah committed
Commit 27d2338 · verified · 1 Parent(s): f99d2b8

Create app.py

Files changed (1): app.py +52 -0
app.py ADDED
@@ -0,0 +1,52 @@
+ import os
+ import clip
+ import torch
+ import gradio as gr
+ from torchvision.datasets import CIFAR100
+
+ # Load the CLIP model on the GPU when available
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model, preprocess = clip.load('ViT-B/32', device)
+
+ # Download the CIFAR-100 test set and tokenize one prompt per class label
+ cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)
+ text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)
+
+ def generateOutput(source):
+     # Prepare the image input (source arrives as a PIL image from Gradio)
+     image_input = preprocess(source).unsqueeze(0).to(device)
+
+     with torch.no_grad():
+         image_features = model.encode_image(image_input)
+         text_features = model.encode_text(text_inputs)
+
+     # Pick the top 5 most similar labels for the image
+     image_features /= image_features.norm(dim=-1, keepdim=True)
+     text_features /= text_features.norm(dim=-1, keepdim=True)
+     similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
+     values, indices = similarity[0].topk(5)
+
+     # Format the result as text
+     outputText = "\nTop predictions:\n"
+     for value, index in zip(values, indices):
+         outputText += f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}% \n"
+
+     return outputText
+
+ title = "CLIP Classification Inference Trials"
+ description = "Zero-shot CLIP classification of your own image against the CIFAR-100 label set"
+ examples = [["Elephants.jpg"], ["941398-beautiful-farm-animals-wallpaper-2000x1402-for-meizu.jpg"], ["Puppies.jpg"], ["photo2.JPG"], ["MultipleItems.jpg"]]
+ demo = gr.Interface(
+     generateOutput,
+     inputs=[
+         gr.Image(type="pil", width=256, height=256, label="Input Image"),
+     ],
+     outputs=[
+         gr.Text(),
+     ],
+     title=title,
+     description=description,
+     examples=examples,
+     cache_examples=False,
+ )
+ demo.launch()
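An earlier draft of this script left a check against cifar100[3637] commented out; running a standalone version of that check is a quick way to verify the zero-shot pipeline before wiring up Gradio. The sketch below is a hypothetical helper, not part of this commit: it assumes the same environment as app.py (torch, torchvision, and OpenAI's clip package, installable via pip install git+https://github.com/openai/CLIP.git), and index 3637 is an arbitrary pick from the test set.

# sanity_check.py: hypothetical standalone check, not part of this commit
import os
import clip
import torch
from torchvision.datasets import CIFAR100

device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load('ViT-B/32', device)
cifar100 = CIFAR100(root=os.path.expanduser("~/.cache"), download=True, train=False)

# cifar100[i] yields a (PIL.Image, class_id) pair; 3637 is an arbitrary index
image, class_id = cifar100[3637]
image_input = preprocess(image).unsqueeze(0).to(device)
text_inputs = torch.cat([clip.tokenize(f"a photo of a {c}") for c in cifar100.classes]).to(device)

with torch.no_grad():
    image_features = model.encode_image(image_input)
    text_features = model.encode_text(text_inputs)

# Cosine similarity between the image and every class prompt, as probabilities
image_features /= image_features.norm(dim=-1, keepdim=True)
text_features /= text_features.norm(dim=-1, keepdim=True)
similarity = (100.0 * image_features @ text_features.T).softmax(dim=-1)
values, indices = similarity[0].topk(5)

print(f"True label: {cifar100.classes[class_id]}")
for value, index in zip(values, indices):
    print(f"{cifar100.classes[index]:>16s}: {100 * value.item():.2f}%")

If the true label lands in the top predictions, the app itself can be started with python app.py; Gradio serves on http://127.0.0.1:7860 by default.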