File size: 3,673 Bytes
a1fde91
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
355d287
 
a1fde91
355d287
84f3f84
355d287
a1fde91
 
 
94c8468
355d287
94c8468
 
 
355d287
94c8468
355d287
 
 
a1fde91
 
 
 
 
 
 
57d4ed7
94c8468
a1fde91
 
03038d2
94c8468
a1fde91
94c8468
e59dcf6
a1fde91
355d287
e59dcf6
a1fde91
e59dcf6
 
a1fde91
e59dcf6
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# import gradio as gr
# import streamlit as st
# import torch 
# import re 
# from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel 

# device='cpu'
# encoder_checkpoint = "ydshieh/vit-gpt2-coco-en"
# decoder_checkpoint = "ydshieh/vit-gpt2-coco-en"
# model_checkpoint = "ydshieh/vit-gpt2-coco-eng"
# feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
# tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
# model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)

# def predict(image,max_length=64, num_beams=4):
#     input_image = Image.open(image)
#     model.eval()
#     pixel_values = feature_extractor(images=[input_image], return_tensors="pt").pixel_values
#     with torch.no_grad():
#         output_ids = model.generate(pixel_values, max_length=16, num_beams=4, return_dict_in_generate=True).sequences
#     preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
#     preds = [pred.strip() for pred in preds]  
#     return preds[0]
    
#   # image = image.convert('RGB')
#   # image = feature_extractor(image, return_tensors="pt").pixel_values.to(device)
#   # clean_text = lambda x: x.replace('<|endoftext|>','').split('\n')[0]
#   # caption_ids = model.generate(image, max_length = max_length)[0]
#   # caption_text = clean_text(tokenizer.decode(caption_ids))
#   # return caption_text 

# # st.title("Image to Text using Lora")

# inputs = gr.inputs.Image(label="Upload any Image", type = 'pil', optional=True)
# output = gr.outputs.Textbox(type="text",label="Captions")
# description = "NTT Data Bilbao team"
# title = "Image to Text using Lora"

# interface = gr.Interface(
#         fn=predict,
#         description=description,
#         inputs = inputs,
#         theme="grass",
#         outputs=output,
#         title=title,
#     )
# interface.launch(debug=True)

import torch 
import re 
import gradio as gr
from transformers import AutoTokenizer, ViTFeatureExtractor, VisionEncoderDecoderModel 

# Inference device; everything below is loaded onto / run on the CPU.
device = 'cpu'
# All three components load from the same Hub checkpoint (the name suggests a
# ViT encoder paired with a GPT-2 decoder).
encoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
decoder_checkpoint = "nlpconnect/vit-gpt2-image-captioning"
model_checkpoint = "nlpconnect/vit-gpt2-image-captioning"

# Numbered progress markers are printed around each loading step.
print("------------------------- 1 -------------------------\n")
# Image preprocessor: turns an image into the `pixel_values` tensor the model consumes.
feature_extractor = ViTFeatureExtractor.from_pretrained(encoder_checkpoint)
print("------------------------- 2 -------------------------\n")
# Tokenizer used to decode generated token ids back into caption text.
# (Previously missing its closing parenthesis, which made the file unparseable.)
tokenizer = AutoTokenizer.from_pretrained(decoder_checkpoint)
print("------------------------- 3 -------------------------\n")
# Combined encoder-decoder captioning model, moved to `device`.
model = VisionEncoderDecoderModel.from_pretrained(model_checkpoint).to(device)
print("------------------------- 4 -------------------------\n")


def predict(image, max_length=64, num_beams=4):
    """Generate a caption for a single image.

    Args:
        image: The uploaded image as a PIL image (the Gradio input is
            configured with ``type='pil'``), or ``None`` when the optional
            input was submitted empty.
        max_length: Maximum length of the generated token sequence, passed
            to ``model.generate``.
        num_beams: Number of beams for beam search, passed to
            ``model.generate``.

    Returns:
        The first line of the decoded caption with surrounding whitespace
        removed, or ``""`` when no image was provided.
    """
    # The Gradio image input is optional, so an empty submission arrives as None.
    if image is None:
        return ""

    # Normalize to 3-channel RGB (e.g. RGBA or grayscale uploads) before preprocessing.
    pixel_values = feature_extractor(
        image.convert('RGB'), return_tensors="pt"
    ).pixel_values.to(device)

    # `num_beams` was previously accepted but never forwarded to generate().
    caption_ids = model.generate(
        pixel_values, max_length=max_length, num_beams=num_beams
    )[0]

    # skip_special_tokens removes markers such as <|endoftext|> during decoding;
    # only the first line of the caption is kept.
    caption_text = tokenizer.decode(caption_ids, skip_special_tokens=True)
    return caption_text.split('\n')[0].strip()


print("------------------------- 5 -------------------------\n")
# Input component: an optional image handed to `predict` as a PIL image.
# Named `image_input` rather than `input` so the builtin is not shadowed.
# NOTE(review): gr.inputs / gr.outputs are the legacy Gradio component
# namespaces; confirm the pinned gradio version still provides them.
image_input = gr.inputs.Image(label="Upload any Image", type='pil', optional=True)
# Output component: the generated caption shown as text.
output = gr.outputs.Textbox(type="auto", label="Captions")
# Example inputs listed in the UI; the path is relative — presumably the file
# ships next to this script, verify.
examples = ["example1.jpg"]
print("------------------------- 6 -------------------------\n")
title = "Image Captioning "
description = "NTT Data"

# Wire `predict` into a single-input, single-output interface and start the server.
interface = gr.Interface(
    fn=predict,
    description=description,
    inputs=image_input,
    theme="grass",
    outputs=output,
    examples=examples,
    title=title,
)
interface.launch(debug=True)