abd-meda committed on
Commit
3b3134e
·
1 Parent(s): 6bea929
Files changed (1) hide show
  1. app.py +109 -0
app.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import re
4
+ import gradio as gr
5
+ from pathlib import Path
6
+ from transformers import AutoTokenizer, AutoFeatureExtractor, VisionEncoderDecoderModel
7
+
8
+
9
# Pattern matching a run of two or more full stops. post_process keeps only
# the text that precedes the first such run.
regex_pattern = r"[.]{2,}"


def post_process(text):
    """Trim a generated caption and cut it at the first run of 2+ full stops.

    Args:
        text: Decoded model output. Anything that is not a ``str`` is
            returned unchanged instead of raising.

    Returns:
        The stripped text up to (not including) the first occurrence of
        ``regex_pattern``; the whole stripped text when there is no match.
    """
    # Explicit guard instead of a catch-all ``except``: the only failure the
    # string operations below can hit is a non-string input.
    if not isinstance(text, str):
        return text
    return re.split(regex_pattern, text.strip())[0]
21
+
22
+
23
def set_example_image(example: list) -> dict:
    """Build the image-component update for a clicked example row.

    Args:
        example: The selected sample row; its first element is used as the
            new image value.

    Returns:
        The result of ``gr.Image.update`` carrying that value.
    """
    selected = example[0]
    return gr.Image.update(value=selected)
25
+
26
+
27
def predict(image, max_length=64, num_beams=4):
    """Generate a plot description for a poster image.

    Args:
        image: The poster, in whatever form ``feature_extractor`` accepts.
            ``None`` (nothing uploaded, or the input was cleared) is
            answered with an empty string instead of failing inside the
            feature extractor.
        max_length: Forwarded to ``model.generate`` as ``max_length``.
        num_beams: Forwarded to ``model.generate`` as ``num_beams``.

    Returns:
        The first decoded sequence after ``post_process``, or ``""`` when
        no image was supplied.
    """
    if image is None:
        return ""

    pixel_values = feature_extractor(images=image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.to(device)

    # Inference only: no gradients are needed for generation.
    with torch.no_grad():
        output_ids = model.generate(
            pixel_values,
            max_length=max_length,
            num_beams=num_beams,
            return_dict_in_generate=True,
        ).sequences

    preds = tokenizer.batch_decode(output_ids, skip_special_tokens=True)
    # Only the first decoded sequence is returned to the UI.
    return post_process(preds[0])
43
+
44
+
45
# Checkpoint to load, and the device inference runs on (first CUDA GPU when
# one is available, CPU otherwise).
model_name_or_path = "deepklarity/poster2plot"
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Load the encoder-decoder model and move it to the chosen device.
model = VisionEncoderDecoderModel.from_pretrained(model_name_or_path)
model.to(device)
print("Loaded model")

# The feature extractor is looked up from the encoder's own name/path.
feature_extractor = AutoFeatureExtractor.from_pretrained(model.encoder.name_or_path)
print("Loaded feature_extractor")

# The tokenizer is looked up from the decoder's own name/path.
tokenizer = AutoTokenizer.from_pretrained(model.decoder.name_or_path, use_fast=True)
if model.decoder.name_or_path == "gpt2":
    # For a "gpt2" decoder the end-of-sequence token is reused as padding.
    tokenizer.pad_token = tokenizer.eos_token

print("Loaded tokenizer")
62
+
63
# Text shown by the Interface-based UI.
title = "Poster2Plot: Upload a Movie/T.V show poster to generate a plot"
description = ""

# NOTE(review): this name shadows the builtin `input`; it is kept because the
# gr.Interface call later in the file refers to it.
input = gr.inputs.Image(type="pil")

# Sorted POSIX-style paths of every *.jpg directly inside ./examples.
example_images = sorted(
    jpg.as_posix() for jpg in Path("examples").glob("*.jpg")
)
print(f"Loaded {len(example_images)} example images")

demo = gr.Blocks()

# File names at the top level of ./examples (empty list when the directory
# cannot be walked), each wrapped as a one-column sample row.
filenames = next(os.walk('examples'), (None, None, []))[2]
examples = [[f"examples/{name}"] for name in filenames]
print(examples)
77
+
78
def _reset_demo():
    """Return the cleared values for the image input and the plot textbox."""
    return None, ""


# Blocks layout: image input with Clear/Submit buttons on the left, the
# generated plot on the right, and a row of clickable examples underneath.
with demo:
    with gr.Column():
        with gr.Row():
            with gr.Column():
                input_image = gr.Image()
                with gr.Row():
                    clear_button = gr.Button(value="Clear", variant='secondary')
                    submit_button = gr.Button(value="Submit", variant='primary')
            with gr.Column():
                plot = gr.Textbox()
        with gr.Row():
            # Bound to its own name so the module-level `example_images`
            # list of paths is not overwritten by a UI component.
            example_dataset = gr.Dataset(components=[input_image], samples=examples)

    submit_button.click(fn=predict, inputs=[input_image], outputs=[plot])
    # The Clear button previously had no handler; it now empties both the
    # image input and the generated plot.
    clear_button.click(fn=_reset_demo, inputs=[], outputs=[input_image, plot])
    example_dataset.click(
        fn=set_example_image,
        inputs=[example_dataset],
        outputs=example_dataset.components,
    )

demo.launch()
95
+
96
+
97
# Interface-based UI around `predict`.
# NOTE(review): this runs after `demo.launch()`; if that call blocks, this
# interface presumably only starts once the first server exits — confirm
# whether both apps are intended.
interface = gr.Interface(
    fn=predict,
    inputs=input,
    outputs="textbox",
    title=title,
    description=description,
    # `examples` is the list of single-path sample rows built from the
    # examples directory; this argument needs plain sample data, not a UI
    # component.
    examples=examples,
    examples_per_page=20,
    live=True,
    article='<p>Made by: <a href="https://twitter.com/kartik_godawat" target="_blank" rel="noopener noreferrer">dk-crazydiv</a> and <a href="https://twitter.com/dsr_ai" target="_blank" rel="noopener noreferrer">dsr</a></p>'
)

interface.launch()