In [2]:
import pandas as pd
import json
from PIL import Image
import numpy as np

In [3]:
import os
import sys
from pathlib import Path

import torch
import torch.nn.functional as F

from src.data.embs import ImageDataset
from src.model.blip_embs import blip_embs

In [4]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [5]:
def get_blip_config(model="base"):
    config = dict()
    if model == "base":
        config[
            "pretrained"
        ] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base_capfilt_large.pth "
        config["vit"] = "base"
        config["batch_size_train"] = 32
        config["batch_size_test"] = 16
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 4
        config["init_lr"] = 1e-5
    elif model == "large":
        config[
            "pretrained"
        ] = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth"
        config["vit"] = "large"
        config["batch_size_train"] = 16
        config["batch_size_test"] = 32
        config["vit_grad_ckpt"] = True
        config["vit_ckpt_layer"] = 12
        config["init_lr"] = 5e-6

    config["image_size"] = 384
    config["queue_size"] = 57600
    config["alpha"] = 0.4
    config["k_test"] = 256
    config["negative_all_rank"] = True

    return config

In [6]:
print("Creating model")
config = get_blip_config("large")

model = blip_embs(
        pretrained=config["pretrained"],
        image_size=config["image_size"],
        vit=config["vit"],
        vit_grad_ckpt=config["vit_grad_ckpt"],
        vit_ckpt_layer=config["vit_ckpt_layer"],
        queue_size=config["queue_size"],
        negative_all_rank=config["negative_all_rank"],
    )

model = model.to(device)
model.eval()

Creating model
load checkpoint from https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_large_retrieval_coco.pth
missing keys:
[]


BLIPEmbs(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1024, out_features=3072, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=1024, out_features=1024, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
  

In [7]:
model 

BLIPEmbs(
  (visual_encoder): VisionTransformer(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
      (norm): Identity()
    )
    (pos_drop): Dropout(p=0.0, inplace=False)
    (blocks): ModuleList(
      (0): Block(
        (norm1): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (attn): Attention(
          (qkv): Linear(in_features=1024, out_features=3072, bias=True)
          (attn_drop): Dropout(p=0.0, inplace=False)
          (proj): Linear(in_features=1024, out_features=1024, bias=True)
          (proj_drop): Dropout(p=0.0, inplace=False)
        )
        (drop_path): Identity()
        (norm2): LayerNorm((1024,), eps=1e-06, elementwise_affine=True)
        (mlp): Mlp(
          (fc1): Linear(in_features=1024, out_features=4096, bias=True)
          (act): GELU(approximate='none')
          (fc2): Linear(in_features=4096, out_features=1024, bias=True)
          (drop): Dropout(p=0.0, inplace=False)
        )
  

In [8]:
# Read all database image features and create a list

In [9]:
df = pd.read_json("datasets/sidechef/my_recipes.json")

In [10]:
df.columns

Index(['recipe_name', 'recipe_time', 'recipe_yields', 'recipe_ingredients',
       'recipe_instructions', 'recipe_image', 'blogger', 'recipe_nutrients',
       'tags', 'id_'],
      dtype='object')

In [24]:
print("Loading Target Embedding")
tar_img_feats = []
for _id in df["id_"].tolist():     
    tar_img_feats.append(torch.load("datasets/sidechef/blip-embs-large/{:07d}.pth".format(_id)).unsqueeze(0))

tar_img_feats = torch.cat(tar_img_feats, dim=0)

Loading Target Embedding


In [12]:
from src.data.transforms import transform_test

transform = transform_test(384)

In [13]:
image = Image.open("datasets/sidechef/images/{:07d}.png".format(3)).convert("RGB")

In [14]:
img = transform(image).unsqueeze(0)

In [15]:
img = img.to(device)

In [16]:
img_embs = model.visual_encoder(img)

In [17]:
img_embs.shape

torch.Size([1, 577, 1024])

In [18]:
img_feats = F.normalize(model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()

In [19]:
img_feats.shape

torch.Size([1, 256])

In [20]:
tar_img_feats[0].shape

torch.Size([1, 256])

In [21]:
tar_img_feats = torch.cat(tar_img_feats, dim=0)

In [159]:
score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()

In [165]:
np.argsort(score)[::-1][0]

2

In [168]:
df.iloc[2+1]

recipe_name                               Farmers Market Breakfast Pizza
recipe_time                                                            0
recipe_yields                                                 2 servings
recipe_ingredients     [1/2 Pizza Dough, 1/2 cup Kale, 1/2 cup Onion,...
recipe_instructions    For homemade pizza sauce, finely chop the Swee...
recipe_image           https://www.sidechef.com/recipe/1cd15944-9411-...
blogger                                                     sidechef.com
recipe_nutrients       {'calories': '315 calories', 'proteinContent':...
tags                   [Breakfast, Brunch, Main Dish, Budget-Friendly...
id_                                                                    4
Name: 3, dtype: object

In [22]:
from transformers import StoppingCriteria, StoppingCriteriaList, TextIteratorStreamer

In [23]:
class StoppingCriteriaSub(StoppingCriteria):

    def __init__(self, stops=[], encounters=1):
        super().__init__()
        self.stops = stops

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor):
        for stop in self.stops:
            if torch.all(input_ids[:, -len(stop):] == stop).item():
                return True

        return False

In [None]:
image_path = "datasets/sidechef/images/{:07d}.png".format(3)

In [71]:
class Chat:

    def __init__(self, model, transform, dataframe, tar_img_feats, device='cuda:0', stopping_criteria=None):
        self.device = device
        self.model = model
        self.transform = transform
        self.df = dataframe
        self.tar_img_feats = tar_img_feats
        self.img_feats = None
        self.target_recipe = None
        self.messages = []

        if stopping_criteria is not None:
            self.stopping_criteria = stopping_criteria
        else:
            stop_words_ids = [torch.tensor([2]).to(self.device)]
            self.stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids)])

    def encode_image(self, image_path):
        img = Image.fromarray(image_path).convert("RGB")
        img = self.transform(img).unsqueeze(0)
        img = img.to(self.device)
        img_embs = model.visual_encoder(img)
        img_feats = F.normalize(model.vision_proj(img_embs[:, 0, :]), dim=-1).cpu()

        self.img_feats = img_feats 

        self.get_target(self.img_feats, self.tar_img_feats)

    def get_target(self, img_feats, tar_img_feats) : 
        score = (img_feats @ tar_img_feats.t()).squeeze(0).cpu().detach().numpy()
        index = np.argsort(score)[::-1][0] + 1
        self.target_recipe = df.iloc[index]

    def ask(self, msg):
        if "nutrition" in msg or "nutrients" in msg : 
            return json.dumps(self.target_recipe["recipe_nutrients"], indent=4)
        elif "instruction" in msg :
            return json.dumps(self.target_recipe["recipe_instructions"], indent=4)
        elif "ingredients" in msg :
            return json.dumps(self.target_recipe["recipe_ingredients"], indent=4)
        elif "tag" in msg or "class" in msg :
            return json.dumps(self.target_recipe["tags"], indent=4)
        else:
            return "Conversational capabilities will be included later."


In [72]:
chat = Chat(model,transform,df,tar_img_feats)

In [73]:
import gradio as gr 

In [80]:
example_images = gr.Dataset(components=[image], label="Food Examples",
                    samples=[
                        [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000018.png")],
                        [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000021.png")],
                        [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000035.png")],
                        [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000038.png")],
                        [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000090.png")],
                        [os.path.join(os.path.dirname("./"), "./datasets/sidechef/sample_images/0000122.png")],
                    ])

example_texts = gr.Dataset(components=[gr.Textbox(visible=False)],
                    label="Prompt Examples",
                    samples=[
                        ["Describe the given chest x-ray image in detail."],
                        ["Take a look at this chest x-ray and describe the findings and impression."],
                        ["Could you provide a detailed description of the given x-ray image?"],
                        ["Describe the given chest x-ray image as detailed as possible."],
                        ["What are the key findings in this chest x-ray image?"],
                        ["Could you highlight any abnormalities or concerns in this chest x-ray image?"],
                        ["What specific features of the lungs and heart are visible in this chest x-ray image?"],
                        ["What is the most prominent feature visible in this chest x-ray image, and how is it indicative of the patient's health?"],
                        ["Based on the findings in this chest x-ray image, what is the overall impression?"],
                    ],)



def respond_to_user(image, message):
    # Process the image and message here
    # For demonstration, I'll just return a simple text response
    chat = Chat(model,transform,df,tar_img_feats)
    chat.encode_image(image)
    response = chat.ask(message)
    return response

iface = gr.Interface(
    fn=respond_to_user,
    inputs=[gr.Image(height="70%"), gr.Textbox(label="Ask Query"),],
    outputs=[gr.Textbox(label="Nutrition-GPT")],
    title="Nutrition-GPT Demo",
    description="Upload an food image and ask queries!",
    css=".component-12 {background-color: red}",
        
)

iface.launch()

Running on local URL:  http://127.0.0.1:7874

To create a public link, set `share=True` in `launch()`.




In [82]:

def respond_to_user(image, message):
    # Process the image and message here
    # For demonstration, I'll just return a simple text response
    chat = Chat(model,transform,df,tar_img_feats)
    chat.encode_image(image)
    response = chat.ask(message)
    return response


with gr.Blocks() as demo:
    gr.Markdown("Nutrition-GPT Demo")

    with gr.Row():
        with gr.Column():
            image = gr.Image()
            text_input = gr.Textbox(label='Ask Query')
            submit_button = gr.Button(value="Upload Food Image and Submit Query", interactive=True, variant="primary")
            clear = gr.Button("Reset")
            

        with gr.Column():
            text_output = gr.Textbox(label="Nutrition-GPT")

    
    submit_button.click(respond_to_user, inputs=[image, text_input], outputs=text_output)
            

demo.launch()

Running on local URL:  http://127.0.0.1:7876

To create a public link, set `share=True` in `launch()`.




In [None]:
def set_example_image(example: list) -> dict:
    return gr.Image.update(value=example[0])

def upload_img(gr_img, text_input, chat_state):
    if gr_img is None:
        return None, None, gr.update(interactive=True), chat_state, None
    chat_state = CONV_VISION.copy()
    img_list = []
    llm_message = chat.upload_img(gr_img, chat_state, img_list)
    return gr.update(interactive=False), gr.update(interactive=True, placeholder='Type and press Enter'), gr.update(value="Start Chatting", interactive=False), chat_state, img_list




with gr.Blocks() as demo:
    gr.Markdown("Nutrition-GPT Demo")

    with gr.Row():
        with gr.Column(scale=0.5):
            image = gr.Image(type="pil")
            upload_button = gr.Button(value="Upload Food Image and Ask Queries", interactive=True, variant="primary")
            clear = gr.Button("Reset")
            

        with gr.Column():
            chat_state = gr.State()
            img_list = gr.State()
            chatbot = gr.Chatbot(label='Nutrition-GPT')
            text_input = gr.Textbox(label='User', placeholder='Please upload food image.', interactive=False)


    with gr.Row():
        example_images = gr.Dataset(components=[image], label="X-Ray Examples",
                                    samples=[
                                        [os.path.join(os.path.dirname(__file__), "./datasets/sidechef/sample_images/0000018.png")],
                                        [os.path.join(os.path.dirname(__file__), "./datasets/sidechef/sample_images/0000021.png")],
                                        [os.path.join(os.path.dirname(__file__), "./datasets/sidechef/sample_images/0000035.png")],
                                        [os.path.join(os.path.dirname(__file__), "./datasets/sidechef/sample_images/0000038.png")],
                                        [os.path.join(os.path.dirname(__file__), "./datasets/sidechef/sample_images/0000090.png")],
                                        [os.path.join(os.path.dirname(__file__), "./datasets/sidechef/sample_images/0000122.png")],
                                    ])
        

    with gr.Row():
        example_texts = gr.Dataset(components=[gr.Textbox(visible=False)],
                                    label="Prompt Examples",
                                    samples=[
                                        ["Describe the given chest x-ray image in detail."],
                                        ["Take a look at this chest x-ray and describe the findings and impression."],
                                        ["Could you provide a detailed description of the given x-ray image?"],
                                        ["Describe the given chest x-ray image as detailed as possible."],
                                        ["What are the key findings in this chest x-ray image?"],
                                        ["Could you highlight any abnormalities or concerns in this chest x-ray image?"],
                                        ["What specific features of the lungs and heart are visible in this chest x-ray image?"],
                                        ["What is the most prominent feature visible in this chest x-ray image, and how is it indicative of the patient's health?"],
                                        ["Based on the findings in this chest x-ray image, what is the overall impression?"],
                                    ],)
    
    example_images.click(fn=set_example_image, inputs=example_images, outputs=example_images.components)

    upload_button.click(upload_img, [image, text_input, chat_state], [image, text_input, upload_button, chat_state, img_list])
    
    example_texts.click(set_example_text_input, inputs=example_texts, outputs=text_input).then(
        gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list], [chatbot, chat_state, img_list]
    )
    
    text_input.submit(gradio_ask, [text_input, chatbot, chat_state], [text_input, chatbot, chat_state]).then(
        gradio_answer, [chatbot, chat_state, img_list, num_beams, temperature], [chatbot, chat_state, img_list]
    )
    clear.click(gradio_reset, [chat_state, img_list], [chatbot, image, text_input, upload_button, chat_state, img_list], queue=False)
    
    gr.Markdown(disclaimer)

demo.launch(share=True, enable_queue=True)