Update README.md
Browse files
README.md
CHANGED
@@ -35,5 +35,101 @@ tags:
|
|
35 |
|
36 |
## How to Test tinyvvision:
|
37 |
Check out the provided inference script to easily test your own shapes and captions. Feel free to challenge tinyvvision with new, unseen combinations to explore its generalization capability!
|
|
|
|
|
|
|
38 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
39 |
✨ **Enjoy experimenting!** ✨
|
|
|
35 |
|
36 |
## How to Test tinyvvision:
|
37 |
Check out the provided inference script to easily test your own shapes and captions. Feel free to challenge tinyvvision with new, unseen combinations to explore its generalization capability!
|
38 |
+
from huggingface_hub import hf_hub_download
|
39 |
+
import torch, re, numpy as np, math
|
40 |
+
from PIL import Image, ImageDraw, ImageFont
|
41 |
|
42 |
+
# Hugging Face repository that hosts the checkpoint, and the file to fetch from it.
repo = "ProCreations/tinyvvision"
pth = hf_hub_download(repo, "cortexclip-mini.pth")

# Pick the compute device: MPS first, then CUDA, otherwise the CPU.
if torch.backends.mps.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.
state = torch.load(pth, map_location=device)

# The checkpoint carries its vocabulary as an index -> token sequence;
# invert it once so captions can be mapped token -> index.
idx2tok = state["vocab"]
tok2idx = {t: i for i, t in enumerate(idx2tok)}
|
48 |
+
def encode_txt(s, maxlen=16, vocab=None):
    """Convert a caption into a fixed-length list of token ids.

    The caption is lower-cased and split into runs of word characters and
    single punctuation characters. Tokens missing from the vocabulary map
    to id 0, which is also the padding id.

    Args:
        s: Caption text.
        maxlen: Exact length of the returned list; longer captions are
            truncated, shorter ones are right-padded with 0.
        vocab: Optional token -> id mapping. Defaults to the module-level
            ``tok2idx`` built from the checkpoint vocabulary.

    Returns:
        A list of ``maxlen`` integer token ids.
    """
    if vocab is None:
        vocab = tok2idx
    toks = re.findall(r"\w+|[^\w\s]", s.lower())
    ids = [vocab.get(t, 0) for t in toks][:maxlen]
    return ids + [0] * (maxlen - len(ids))
|
52 |
+
class TE(torch.nn.Module):
    """Text encoder: token embedding -> 2-layer bidirectional GRU -> linear projection.

    Sub-module names (``emb``, ``gru``, ``out_proj``) are the keys the
    checkpoint's state dict is loaded into, so they must not be renamed.
    """

    def __init__(self, vocab_size=None):
        """Build the layers.

        Args:
            vocab_size: Number of rows in the embedding table. Defaults to
                the length of the module-level ``idx2tok`` vocabulary.
        """
        super().__init__()
        if vocab_size is None:
            vocab_size = len(idx2tok)
        self.emb = torch.nn.Embedding(vocab_size, 64)
        self.gru = torch.nn.GRU(64, 128, num_layers=2, bidirectional=True, batch_first=True)
        # Bidirectional GRU with hidden size 128 emits 2 * 128 = 256 features per step.
        self.out_proj = torch.nn.Linear(256, 128)

    def forward(self, x):
        """Encode a batch of token-id sequences.

        Args:
            x: Integer tensor of token ids, batch dimension first.

        Returns:
            A tensor with 128 features per batch item.
        """
        e, _ = self.gru(self.emb(x))
        # Use the GRU output at the final time step as the sentence summary.
        # NOTE(review): presumably this matches how the checkpoint was trained -- keep as is.
        return self.out_proj(e[:, -1])
|
61 |
+
class IE(torch.nn.Module):
    """Image encoder: three conv+ReLU stages, 4x4 adaptive average pool, linear head.

    The layers live in a single ``Sequential`` named ``conv``; its name and
    layer order define the state-dict keys the checkpoint is loaded into.
    """

    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(3, 32, 5, 1, 2), torch.nn.ReLU(),
            torch.nn.Conv2d(32, 64, 3, 1, 1), torch.nn.ReLU(),
            torch.nn.Conv2d(64, 128, 3, 1, 1), torch.nn.ReLU(),
            # Pool to a fixed 4x4 grid so the linear layer sees 128*4*4 features.
            torch.nn.AdaptiveAvgPool2d((4, 4)), torch.nn.Flatten(),
            torch.nn.Linear(128 * 4 * 4, 128), torch.nn.ReLU(),
        )

    def forward(self, x):
        """Encode a batch of 3-channel images into 128 features per image."""
        return self.conv(x)
|
72 |
+
# Build both encoders on the selected device and restore their trained weights
# from the checkpoint.
te, ie = TE().to(device), IE().to(device)
te.load_state_dict(state["text_encoder"])
ie.load_state_dict(state["image_encoder"])

# Switch both modules to evaluation mode for inference.
te.eval()
ie.eval()
|
76 |
+
|
77 |
+
# ----- CUSTOMIZE YOUR EXAMPLES HERE -----
|
78 |
+
# To try your own image:
|
79 |
+
# 1. Replace the 'custom_image()' function with your image drawing/loading code.
|
80 |
+
# 2. Replace 'custom_caption' with your own caption for the image.
|
81 |
+
def custom_image():
    """Draw the user-customizable example image and return it as a tensor.

    The default draws a blue six-sided regular polygon (bounding circle
    centred at (32, 32), radius 22) on a white 64x64 RGB canvas.

    Returns:
        A float32 tensor on ``device`` with pixel values scaled to [0, 1],
        laid out channels-first with a leading batch dimension of 1.
    """
    # Example: Draw your own "blue hexagon" shape below!
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    dr.regular_polygon((32, 32, 22), n_sides=6, fill="blue")
    # uint8 pixels -> floats in [0, 1]; then reorder axes to channels-first
    # and add the batch dimension.
    arr = np.array(img).astype(np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
|
88 |
+
custom_caption = "a blue hexagon"
|
89 |
+
|
90 |
+
# ----- FUN DEMO EXAMPLES -----
|
91 |
+
def draw_red_heart():
    """Draw a simple red heart on a white 64x64 canvas and return it as a tensor.

    The heart is composed of a four-point polygon for the lower body and
    two filled ellipses for the upper lobes.

    Returns:
        A float32 tensor on ``device`` with pixel values scaled to [0, 1],
        laid out channels-first with a leading batch dimension of 1.
    """
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    dr.polygon([(32, 18), (50, 34), (32, 56), (14, 34)], fill="red")  # simple heart
    dr.ellipse((18, 12, 32, 32), fill="red")
    dr.ellipse((32, 12, 46, 32), fill="red")
    # uint8 pixels -> floats in [0, 1]; then reorder axes to channels-first
    # and add the batch dimension.
    arr = np.array(img).astype(np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
|
99 |
+
def draw_purple_star():
    """Draw a purple five-pointed star on a white 64x64 canvas and return it as a tensor.

    Five vertices are placed every 72 degrees on a circle of radius 20
    centred at (32, 32), starting at -90 degrees. Each vertex is joined to
    the vertex two positions further along with a 7-pixel-wide line.

    Returns:
        A float32 tensor on ``device`` with pixel values scaled to [0, 1],
        laid out channels-first with a leading batch dimension of 1.
    """
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    points = [
        (32 + 20 * math.cos(math.radians(a)), 32 + 20 * math.sin(math.radians(a)))
        for a in range(-90, 270, 72)
    ]
    for i in range(5):
        # Skip one vertex each time so the strokes cross and form the star.
        dr.line([points[i], points[(i + 2) % 5]], fill="purple", width=7)
    # uint8 pixels -> floats in [0, 1]; then reorder axes to channels-first
    # and add the batch dimension.
    arr = np.array(img).astype(np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
|
107 |
+
def draw_orange_pentagon():
    """Draw an orange regular pentagon on a white 64x64 canvas and return it as a tensor.

    The pentagon's bounding circle is centred at (32, 32) with radius 22.

    Returns:
        A float32 tensor on ``device`` with pixel values scaled to [0, 1],
        laid out channels-first with a leading batch dimension of 1.
    """
    img = Image.new("RGB", (64, 64), "white")
    dr = ImageDraw.Draw(img)
    dr.regular_polygon((32, 32, 22), n_sides=5, fill="orange")
    # uint8 pixels -> floats in [0, 1]; then reorder axes to channels-first
    # and add the batch dimension.
    arr = np.array(img).astype(np.float32) / 255.0
    return torch.from_numpy(arr).permute(2, 0, 1).unsqueeze(0).to(device)
|
113 |
+
|
114 |
+
# (image tensor, caption) pairs: the user-customizable example first,
# followed by the built-in demo shapes.
demo_imgs = [
    (custom_image(), custom_caption),
    (draw_red_heart(), "a red heart"),
    (draw_purple_star(), "a purple star"),
    (draw_orange_pentagon(), "an orange pentagon"),
]

# Split the pairs into parallel lists; index k in one matches index k in the other.
captions = [c for (_, c) in demo_imgs]
img_tensors = [im for (im, _) in demo_imgs]

# One row of token ids per caption, created directly on the model's device.
cap_ids = torch.tensor([encode_txt(c) for c in captions], device=device)
|
123 |
+
|
124 |
+
# Score every demo image against every caption and report the best match.
# The loop stays inside no_grad() so the similarity tensor carries no
# autograd history and can be converted to NumPy directly.
with torch.no_grad():
    txt_emb = te(cap_ids)
    for i, (img, caption) in enumerate(zip(img_tensors, captions)):
        im_emb = ie(img)
        # The single image embedding is compared against all caption
        # embeddings, yielding one cosine score per caption.
        sim = torch.nn.functional.cosine_similarity(im_emb, txt_emb).cpu().numpy()
        rank = int(np.argmax(sim))
        print(f"Input image {i+1}: '{caption}'")
        print(" Similarity scores:")
        for j, c in enumerate(captions):
            print(f" {c}: {sim[j]:.4f}")
        print(" Best match:", captions[rank], "\n")
|
135 |
✨ **Enjoy experimenting!** ✨
|