Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from datasets import load_dataset | |
| import random | |
| import numpy as np | |
| from transformers import CLIPProcessor, CLIPModel | |
| from os import environ | |
| import clip | |
| import pickle | |
| import requests | |
| import torch | |
| import os | |
| from huggingface_hub import hf_hub_download | |
| from torch import nn | |
| import torch.nn.functional as nnf | |
| import sys | |
| from typing import Tuple, List, Union, Optional | |
| from transformers import GPT2Tokenizer, GPT2LMHeadModel, AdamW, get_linear_schedule_with_warmup | |
# Short aliases used in type annotations below (e.g. T and D in the model classes).
# numpy-flavoured aliases
N = type(None)
V = np.array
ARRAY = np.ndarray
ARRAYS = Union[Tuple[ARRAY, ...], List[ARRAY]]
VS = Union[Tuple[V, ...], List[V]]
VN = Union[V, N]
VNS = Union[VS, N]
# torch-flavoured aliases
T = torch.Tensor
TS = Union[Tuple[T, ...], List[T]]
TN = Optional[T]
TNS = Union[Tuple[TN, ...], List[TN]]
TSN = Optional[TS]
TA = Union[T, ARRAY]
D = torch.device
CPU = torch.device('cpu')
# Device string for the captioning model: GPU when available, otherwise CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# Pre-trained CLIP (transformers port): used for classification, caption
# prefixes and text-to-image search.
_CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"
clip_model = CLIPModel.from_pretrained(_CLIP_CHECKPOINT)
clip_processor = CLIPProcessor.from_pretrained(_CLIP_CHECKPOINT)

# Unsplash photo metadata; all 25K images are in the "train" split.
dataset = load_dataset("jamescalam/unsplash-25k-photos", split="train")
dataset_size = len(dataset)

# GPT-2 language model + tokenizer, and the two CLIP-prefix captioning
# checkpoints (COCO / Conceptual Captions) fetched from the Hugging Face Hub.
_GPT2_CHECKPOINT = 'gpt2'
gpt = GPT2LMHeadModel.from_pretrained(_GPT2_CHECKPOINT)
tokenizer = GPT2Tokenizer.from_pretrained(_GPT2_CHECKPOINT)
conceptual_weight = hf_hub_download(
    repo_id="akhaliq/CLIP-prefix-captioning-conceptual-weights",
    filename="conceptual_weights.pt",
)
coco_weight = hf_hub_download(
    repo_id="akhaliq/CLIP-prefix-captioning-COCO-weights",
    filename="coco_weights.pt",
)

# Image height used for Unsplash `?h=` requests and for the image widgets.
height = 256
def predict(image, labels):
    """Zero-shot classify ``image`` against ``labels`` with CLIP.

    Each label is scored with the prompt ``"a photo of {label}"`` and the
    image-text logits are softmaxed across the labels.

    Args:
        image: value of the Gradio ``Image`` component; passed to ``clip_processor``.
        labels: list of label strings (e.g. as produced by ``set_labels``).

    Returns:
        dict mapping each label to its softmax probability.

    Raises:
        gr.Error: when no image is loaded or no labels are set. The user sees
            an actionable message instead of empty or missing input reaching CLIP.
    """
    if image is None:
        raise gr.Error("Please get an image first.")
    if not labels:
        raise gr.Error("Please enter at least one label.")
    prompts = [f"a photo of {c}" for c in labels]
    with torch.no_grad():
        inputs = clip_processor(text=prompts, images=image, return_tensors="pt", padding=True)
        outputs = clip_model(**inputs)
        # Image-text similarity logits; softmax over the label axis gives
        # per-label probabilities, and row 0 is the single input image.
        probs = outputs.logits_per_image.softmax(dim=1).cpu().numpy()
    return {k: float(v) for k, v in zip(labels, probs[0])}
| # def predict2(image, labels): | |
| # image = orig_clip_processor(image).unsqueeze(0).to(device) | |
| # text = clip.tokenize(labels).to(device) | |
| # with torch.no_grad(): | |
| # image_features = orig_clip_model.encode_image(image) | |
| # text_features = orig_clip_model.encode_text(text) | |
| # logits_per_image, logits_per_text = orig_clip_model(image, text) | |
| # probs = logits_per_image.softmax(dim=-1).cpu().numpy() | |
| # return {k: float(v) for k, v in zip(labels, probs[0])} | |
def rand_image():
    """Pick a uniformly random dataset photo and return its sized URL.

    Returns:
        The photo's ``photo_image_url`` with ``?h={height}`` appended.
    """
    idx = random.randrange(0, dataset.num_rows)
    photo_url = dataset[idx]["photo_image_url"]
    # Unsplash allows dynamic requests, including size of image.
    return photo_url + f"?h={height}"
def set_labels(text):
    """Parse a comma-separated label string into a list of labels.

    Whitespace around each label is stripped and empty entries (from ``"a,,b"``
    or a trailing comma) are dropped. For example, ``"spring, summer"`` yields
    ``["spring", "summer"]`` rather than ``["spring", " summer"]``. That keeps
    the CLIP prompts (``"a photo of {label}"``) and the displayed label names clean.

    Args:
        text: raw textbox contents; ``None`` is treated as empty.

    Returns:
        list of non-empty, stripped label strings (possibly empty).
    """
    parts = (text or "").split(",")
    return [label.strip() for label in parts if label.strip()]
| # get_caption = gr.load("ryaalbr/caption", src="spaces", hf_token=environ["api_key"]) | |
| # def generate_text(image, model_name): | |
| # return get_caption(image, model_name) | |
class MLP(nn.Module):
    """Stack of ``nn.Linear`` layers with an activation between consecutive layers.

    ``sizes`` lists the feature width at each stage, so ``MLP((a, b, c))`` builds
    ``Linear(a, b) -> act() -> Linear(b, c)``. No activation follows the last layer.
    The layers are held in ``self.model`` (an ``nn.Sequential``), which fixes the
    state-dict key layout (``model.0.weight``, ``model.2.weight``, ...).

    Args:
        sizes: feature widths, input first and output last.
        bias: passed to every ``nn.Linear``.
        act: activation class (not an instance); a fresh instance is created per gap.
    """

    def __init__(self, sizes: Tuple[int, ...], bias=True, act=nn.Tanh):
        super().__init__()
        transitions = list(zip(sizes[:-1], sizes[1:]))
        final_idx = len(transitions) - 1
        stages = []
        for idx, (width_in, width_out) in enumerate(transitions):
            stages.append(nn.Linear(width_in, width_out, bias=bias))
            if idx != final_idx:  # activation only between layers, never after the last
                stages.append(act())
        self.model = nn.Sequential(*stages)

    def forward(self, x: T) -> T:
        """Apply the layer stack to ``x``."""
        return self.model(x)
class ClipCaptionModel(nn.Module):
    """Project a CLIP image embedding into a GPT-2 prefix and run GPT-2 on it.

    ``clip_project`` maps a ``prefix_size``-dim vector to ``prefix_length``
    GPT-2 token embeddings. ``forward`` prepends those embeddings to the
    embedded text tokens.

    Args:
        prefix_length: number of GPT-2 embedding slots produced from one CLIP vector.
        prefix_size: input width of the projection (default 512).
    """

    def __init__(self, prefix_length: int, prefix_size: int = 512):
        super().__init__()
        self.prefix_length = prefix_length
        # NOTE(review): this is the module-level `gpt` object, not a copy. Every
        # ClipCaptionModel shares it, and load_state_dict on any instance
        # overwrites the shared GPT-2 weights.
        self.gpt = gpt
        self.gpt_embedding_size = self.gpt.transformer.wte.weight.shape[1]
        projected_width = self.gpt_embedding_size * prefix_length
        if prefix_length > 10:  # not enough memory for the two-layer MLP
            self.clip_project = nn.Linear(prefix_size, projected_width)
        else:
            self.clip_project = MLP((prefix_size, projected_width // 2, projected_width))

    def get_dummy_token(self, batch_size: int, device: D) -> T:
        """Return int64 zeros of shape (batch_size, prefix_length) used as prefix label slots."""
        return torch.zeros(batch_size, self.prefix_length, dtype=torch.int64, device=device)

    def forward(self, tokens: T, prefix: T, mask: Optional[T] = None, labels: Optional[T] = None):
        """Run GPT-2 on ``[projected prefix ; embedded tokens]``.

        When ``labels`` is given, its value is ignored and replaced by
        ``[dummy zeros ; tokens]`` so labels line up with the concatenated input.
        """
        text_embeds = self.gpt.transformer.wte(tokens)
        prefix_embeds = self.clip_project(prefix).view(-1, self.prefix_length, self.gpt_embedding_size)
        joined_embeds = torch.cat((prefix_embeds, text_embeds), dim=1)
        if labels is not None:
            placeholder = self.get_dummy_token(tokens.shape[0], tokens.device)
            labels = torch.cat((placeholder, tokens), dim=1)
        return self.gpt(inputs_embeds=joined_embeds, labels=labels, attention_mask=mask)
| #clip_model, preprocess = clip.load("ViT-B/32", device=device, jit=False) | |
def get_caption(img, model_name):
    """Generate a caption for ``img`` with a CLIP-prefix GPT-2 captioner.

    Args:
        img: value of the Gradio ``Image`` component; passed to ``clip_processor``.
        model_name: ``"COCO"`` selects the COCO weights; any other value selects
            the Conceptual Captions weights.

    Returns:
        The decoded caption, with one trailing ``"."`` removed if present.

    Raises:
        gr.Error: when no image is loaded.
    """
    if img is None:
        raise gr.Error("Please get an image first.")
    prefix_length = 10
    model = ClipCaptionModel(prefix_length)
    model_path = coco_weight if model_name == "COCO" else conceptual_weight
    # The weights are reloaded on every call. All ClipCaptionModel instances
    # share the module-level `gpt`, so a cached model would be overwritten
    # whenever the other checkpoint is loaded.
    model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))
    model = model.eval().to(device)
    with torch.no_grad():
        # Feed CLIP on the device clip_model actually lives on; only the
        # captioning model is moved to `device`. Then move the features there.
        clip_inputs = clip_processor(images=img, return_tensors="pt").to(clip_model.device)
        prefix = clip_model.get_image_features(**clip_inputs).to(device)
        prefix_embed = model.clip_project(prefix).reshape(1, prefix_length, -1)
        output = model.gpt.generate(inputs_embeds=prefix_embed,
                                    num_beams=1,
                                    do_sample=False,
                                    num_return_sequences=1,
                                    no_repeat_ngram_size=1,
                                    max_new_tokens=67,
                                    pad_token_id=tokenizer.eos_token_id,
                                    eos_token_id=tokenizer.encode('.')[0],  # stop at the first period
                                    renormalize_logits=True)
    caption_text = tokenizer.decode(output[0], skip_special_tokens=True)
    # endswith() (rather than indexing [-1]) keeps an empty decode from raising IndexError.
    return caption_text[:-1] if caption_text.endswith(".") else caption_text
| # get_images = gr.load("ryaalbr/ImageSearch", src="spaces", hf_token=environ["api_key"]) | |
| # def search_images(text): | |
| # return get_images(text, api_name="images") | |
# Precomputed search index, unpacked as (id2url, img_names, img_emb).
# search() uses them as follows:
#   id2url.get(photo_id) -> URL
#   img_names[indices]   -> file names
#   img_emb.T            -> the embedding matrix
# Presumably img_emb is a numpy array with one row per photo; confirm.
# NOTE(review): pickle.load can execute arbitrary code; this must only ever
# read the file shipped with the app, never an untrusted one.
emb_filename = "unsplash-25k-photos-embeddings-indexes.pkl"
with open(emb_filename, mode="rb") as emb:
    (id2url, img_names, img_emb) = pickle.load(emb)
def search(search_query):
    """Return up to five Unsplash image URLs that best match ``search_query``.

    The query is embedded with CLIP's text encoder and scored against every
    row of the precomputed ``img_emb`` index. The 15 best-scoring photos are
    scanned in order, and the first five with a known URL are returned, each
    with ``?h=512`` appended.

    Args:
        search_query: free-text description.

    Returns:
        list of at most five URL strings, best match first.
    """
    max_results = 5  # images shown in the gallery
    num_candidates = 15  # ranked photos scanned; some may be missing from id2url
    with torch.no_grad():
        inputs = clip_processor(text=search_query, images=None, return_tensors="pt", padding=True)
        text_features = clip_model.get_text_features(**inputs).cpu().numpy()
    # Dot product of the query with every photo embedding. The query vector
    # is not normalised, but for a single query that only rescales all scores
    # and leaves the ranking unchanged.
    # NOTE(review): the ranking equals cosine similarity only if the img_emb
    # rows are unit-normalised; confirm.
    similarities = (text_features @ img_emb.T).squeeze(0)
    ranked = similarities.argsort()[::-1][:num_candidates]
    imgs = []
    for file_name in img_names[ranked]:
        # Strip only the final extension, so names containing extra dots
        # (or none at all) do not raise during unpacking.
        photo_id = os.path.splitext(file_name)[0]
        url = id2url.get(photo_id, "")
        if not url:
            continue
        imgs.append(url + "?h=512")
        if len(imgs) == max_results:
            break
    return imgs
# Gradio app with three tabs (classification, captioning, text search).
# All three tabs work on images from the Unsplash dataset loaded above.
# NOTE(review): `.style(...)` (including `item_container`) is an older Gradio
# API. Confirm the installed gradio version still provides it; the page this
# was taken from reports a runtime error.
with gr.Blocks() as demo:
    # Tab 1: zero-shot classification of a random image against user labels.
    with gr.Tab("Classification"):
        labels = gr.State([])  # hidden component holding the parsed label list; initial value is an empty list
        instructions = """## Instructions:
1. Enter list of labels separated by commas (or select one of the examples below)
2. Click **Get Random Image** to grab a random image from dataset
3. Click **Classify Image** to analyze current image against the labels (including after changing labels)
"""
        gr.Markdown(instructions)
        with gr.Row(variant="compact"):
            label_text = gr.Textbox(show_label=False, placeholder="Enter classification labels").style(container=False)
            # Disabled: explicit submit button (label parsing is wired to textbox events below).
            #submit_btn = gr.Button("Submit").style(full_width=False)
        gr.Examples(["spring, summer, fall, winter",
                     "mountain, city, beach, ocean, desert, forest, valley",
                     "red, blue, green, white, black, purple, brown",
                     "person, animal, landscape, something else",
                     "day, night, dawn, dusk"], inputs=label_text)
        with gr.Row():
            with gr.Column(variant="panel"):
                im = gr.Image(interactive=False).style(height=height)
                with gr.Row():
                    get_btn = gr.Button("Get Random Image").style(full_width=False)
                    class_btn = gr.Button("Classify Image").style(full_width=False)
            cf = gr.Label()
        #submit_btn.click(fn=set_labels, inputs=label_text)
        # The label list is re-parsed on edit, on blur and on Enter, so `labels`
        # is up to date before Classify is clicked.
        label_text.change(fn=set_labels, inputs=label_text, outputs=labels)  # parse list if changed
        label_text.blur(fn=set_labels, inputs=label_text, outputs=labels)  # parse list when focus moves elsewhere
        label_text.submit(fn=set_labels, inputs=label_text, outputs=labels)  # parse list when the user hits Enter
        get_btn.click(fn=rand_image, outputs=im)
        # Disabled: auto-classify whenever the image changes.
        #im.change(predict, inputs=[im, labels], outputs=cf)
        class_btn.click(predict, inputs=[im, labels], outputs=cf)
        gr.HTML(f"Dataset: <a href='https://github.com/unsplash/datasets' target='_blank'>Unsplash Lite</a>; Number of Images: {dataset_size}")
    # Tab 2: caption a random image with the COCO or Conceptual Captions weights.
    with gr.Tab("Captioning"):
        instructions = """## Instructions:
1. Click **Get Random Image** to grab a random image from dataset
1. Click **Create Caption** to generate a caption for the image
1. Different models can be selected:
* **COCO** generally produces more straight-forward captions, but it is a smaller dataset and therefore struggles to recognize certain objects
* **Conceptual Captions** is a much larger dataset but sometimes produces results that resemble social media posts
"""
        gr.Markdown(instructions)
        with gr.Row():
            with gr.Column(variant="panel"):
                im_cap = gr.Image(interactive=False).style(height=height)
                # The selected choice string is passed to get_caption as `model_name`.
                model_name = gr.Radio(choices=["COCO","Conceptual Captions"], type="value", value="COCO", label="Model").style(container=True, item_container = False)
                with gr.Row():
                    get_btn_cap = gr.Button("Get Random Image").style(full_width=False)
                    caption_btn = gr.Button("Create Caption").style(full_width=False)
            caption = gr.Textbox(label='Caption', elem_classes="caption-text")
        get_btn_cap.click(fn=rand_image, outputs=im_cap)
        # Disabled: auto-caption whenever the image changes.
        #im_cap.change(generate_text, inputs=im_cap, outputs=caption)
        caption_btn.click(get_caption, inputs=[im_cap, model_name], outputs=caption)
        gr.HTML(f"Dataset: <a href='https://github.com/unsplash/datasets' target='_blank'>Unsplash Lite</a>; Number of Images: {dataset_size}")
    # Tab 3: text-to-image search over the precomputed CLIP image index.
    with gr.Tab("Search"):
        instructions = """## Instructions:
1. Enter a search query (or select one of the examples below)
2. Click **Find Images** to find images that match the query (top 5 are shown in order from left to right)
3. Keep in mind that the dataset contains mostly nature-focused images"""
        gr.Markdown(instructions)
        with gr.Column(variant="panel"):
            desc = gr.Textbox(show_label=False, placeholder="Enter description").style(container=False)
            gr.Examples(["someone holding flowers",
                         "someone holding pink flowers",
                         "red fruit in a person's hands",
                         "an aerial view of forest",
                         "a waterfall with a rainbow"
                         ], inputs=desc)
            search_btn = gr.Button("Find Images").style(full_width=False)
        gallery = gr.Gallery(show_label=False).style(grid=(2,2,3,5))
        # search() returns raw URL strings. postprocess=False skips Gradio's output
        # postprocessing for the gallery, presumably so the browser loads the
        # Unsplash URLs directly; confirm this works with the installed gradio version.
        search_btn.click(search,inputs=desc, outputs=gallery, postprocess=False)
        gr.HTML(f"Dataset: <a href='https://github.com/unsplash/datasets' target='_blank'>Unsplash Lite</a>; Number of Images: {dataset_size}")
# Start the app server.
demo.launch()