import os
import torch
import gradio as gr
from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel

device = 'cpu'  # CPU inference; change to 'cuda' if a GPU is available
model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"

# One checkpoint provides the ViT encoder, the GPT-2 decoder, and the tokenizer.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_checkpoint)
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint)
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
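
# Note: despite the "LoRA" in the title, this demo serves the base checkpoint
# as-is. To actually fine-tune it with LoRA one would typically wrap the model
# with Hugging Face's `peft` library, roughly as below (a hedged sketch: the
# target module names are an assumption, and `peft` must be installed):
#
#   from peft import LoraConfig, get_peft_model
#   lora_config = LoraConfig(r=8, lora_alpha=16, target_modules=["c_attn"])
#   model = get_peft_model(model, lora_config)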


def predict(image, max_length=64, num_beams=4):
    """Generate a caption for a PIL image using beam search."""
    image = image.convert('RGB')
    pixel_values = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
    caption_ids = model.generate(pixel_values, max_length=max_length, num_beams=num_beams)[0]
    # Drop the GPT-2 end-of-text token and keep only the first line.
    caption_text = tokenizer.decode(caption_ids).replace('<|endoftext|>', '').split('\n')[0].strip()
    return caption_text


with gr.Blocks() as demo:
    gr.HTML(
        """
        <div style="text-align: center; max-width: 1200px; margin: 20px auto;">
        <h1 style="font-weight: 900; font-size: 3rem; margin: 0rem">
            ViT Image-to-Text with LoRA
        </h1>
        <h2 style="text-align: left; font-weight: 450; font-size: 1rem; margin-top: 2rem; margin-bottom: 1.5rem">
        Fine-tuning large language models has long been a hard problem. With behemoth models such as GPT-3 carrying billions of parameters, adapting them to a specific task or domain has become exorbitantly expensive. Microsoft's <b>Low-Rank Adaptation (LoRA)</b> offers a way out: it freezes the weights of the pre-trained model and injects trainable <b>rank-decomposition matrices into each transformer block</b>. This significantly reduces the number of trainable parameters and the GPU memory footprint, since gradients no longer need to be computed for the vast majority of the model weights.
        <br>
        <br>
        You can find more info here: <a href="https://www.linkedin.com/pulse/fine-tuning-image-to-text-algorithms-with-lora-daniel-puente-viejo" target="_blank">Linkedin article</a>
        </h2>
        </div>
        """)
    # (See the end of this file for a minimal code sketch of the LoRA idea.)
    with gr.Row():
        with gr.Column(scale=1):
            img = gr.Image(label="Upload any Image", type='pil')
            button = gr.Button(value="Describe")
        with gr.Column(scale=1):
            out = gr.Textbox(label="Caption")

    button.click(predict, inputs=[img], outputs=[out])
    gr.Examples(
        examples=[os.path.join(os.path.dirname(__file__), "example1.jpg")],
        inputs=img,
        outputs=out,
        fn=predict,
        cache_examples=True,
    )
demo.launch(debug=True)
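
# ---------------------------------------------------------------------------
# Appendix: the LoRA idea from the description above, in code. A minimal
# sketch, not the `peft` implementation a real fine-tune would use; the class
# name and hyperparameters are illustrative. It is defined after demo.launch()
# and is never used by the app itself.
# ---------------------------------------------------------------------------
class LoRALinear(torch.nn.Module):
    """Wrap a frozen nn.Linear with a trainable low-rank update B @ A."""

    def __init__(self, base: torch.nn.Linear, r: int = 8, alpha: int = 16):
        super().__init__()
        self.base = base
        for p in self.base.parameters():
            p.requires_grad = False  # freeze the pretrained weight (and bias)
        # Rank-decomposition pair: the effective weight is W + (alpha / r) * B @ A.
        self.lora_A = torch.nn.Parameter(torch.randn(r, base.in_features) * 0.01)
        self.lora_B = torch.nn.Parameter(torch.zeros(base.out_features, r))  # zero init: no update at start
        self.scaling = alpha / r

    def forward(self, x):
        # Frozen path plus the scaled low-rank update.
        return self.base(x) + (x @ self.lora_A.T @ self.lora_B.T) * self.scaling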