vykanand committed
Commit 699fe26 · verified · 1 Parent(s): ed33eff

Update app.py

Files changed (1)
  1. app.py +28 -310
app.py CHANGED
@@ -1,313 +1,31 @@
 import gradio as gr
-import spaces
-from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, TextIteratorStreamer
-from qwen_vl_utils import process_vision_info
-import torch
+from torchvision import models, transforms
 from PIL import Image
+import torch
-import os
-import uuid
-import io
-from threading import Thread
-from reportlab.lib.pagesizes import A4
-from reportlab.lib.styles import getSampleStyleSheet
-from reportlab.lib import colors
-from reportlab.platypus import SimpleDocTemplate, Image as RLImage, Paragraph, Spacer
-from reportlab.lib.units import inch
-from reportlab.pdfbase import pdfmetrics
-from reportlab.pdfbase.ttfonts import TTFont
-import docx
-from docx.enum.text import WD_ALIGN_PARAGRAPH
-
-# Define model options
-MODEL_OPTIONS = {
-    "Qwen2VL Base": "Qwen/Qwen2-VL-2B-Instruct",
-    "Latex OCR": "prithivMLmods/Qwen2-VL-OCR-2B-Instruct",
-    "Math Prase": "prithivMLmods/Qwen2-VL-Math-Prase-2B-Instruct",
-    "Text Analogy Ocrtest": "prithivMLmods/Qwen2-VL-Ocrtest-2B-Instruct"
-}
-
-# Preload models and processors into CPU
-models = {}
-processors = {}
-for name, model_id in MODEL_OPTIONS.items():
-    print(f"Loading {name}...")
-    models[name] = Qwen2VLForConditionalGeneration.from_pretrained(
-        model_id,
-        trust_remote_code=True,
-        torch_dtype=torch.float32  # Use float32 to ensure CPU usage
-    ).eval()  # No `.to('cuda')`, will run on CPU by default
-    processors[name] = AutoProcessor.from_pretrained(model_id, trust_remote_code=True)
-
-image_extensions = Image.registered_extensions()
-
-def identify_and_save_blob(blob_path):
-    """Identifies if the blob is an image and saves it."""
-    try:
-        with open(blob_path, 'rb') as file:
-            blob_content = file.read()
-            try:
-                Image.open(io.BytesIO(blob_content)).verify()  # Check if it's a valid image
-                extension = ".png"  # Default to PNG for saving
-                media_type = "image"
-            except (IOError, SyntaxError):
-                raise ValueError("Unsupported media type. Please upload a valid image.")
-
-            filename = f"temp_{uuid.uuid4()}_media{extension}"
-            with open(filename, "wb") as f:
-                f.write(blob_content)
-
-            return filename, media_type
-
-    except FileNotFoundError:
-        raise ValueError(f"The file {blob_path} was not found.")
-    except Exception as e:
-        raise ValueError(f"An error occurred while processing the file: {e}")
-
-def qwen_inference(model_name, media_input, text_input=None):
-    """Handles inference for the selected model on CPU."""
-    model = models[model_name]
-    processor = processors[model_name]
-
-    if isinstance(media_input, str):
-        media_path = media_input
-        if media_path.endswith(tuple([i for i in image_extensions.keys()])):
-            media_type = "image"
-        else:
-            try:
-                media_path, media_type = identify_and_save_blob(media_input)
-            except Exception as e:
-                raise ValueError("Unsupported media type. Please upload a valid image.")
-
-    messages = [
-        {
-            "role": "user",
-            "content": [
-                {
-                    "type": media_type,
-                    media_type: media_path
-                },
-                {"type": "text", "text": text_input},
-            ],
-        }
-    ]
-
-    text = processor.apply_chat_template(
-        messages, tokenize=False, add_generation_prompt=True
-    )
-    image_inputs, _ = process_vision_info(messages)
-    inputs = processor(
-        text=[text],
-        images=image_inputs,
-        padding=True,
-        return_tensors="pt",
-    )
-
-    # Ensure model runs on CPU
-    streamer = TextIteratorStreamer(
-        processor.tokenizer, skip_prompt=True, skip_special_tokens=True
-    )
-    generation_kwargs = dict(inputs, streamer=streamer, max_new_tokens=1024)
-
-    thread = Thread(target=model.generate, kwargs=generation_kwargs)
-    thread.start()
-
-    buffer = ""
-    for new_text in streamer:
-        buffer += new_text
-        # Remove <|im_end|> or similar tokens from the output
-        buffer = buffer.replace("<|im_end|>", "")
-        yield buffer
-
-def format_plain_text(output_text):
-    """Formats the output text as plain text without LaTeX delimiters."""
-    plain_text = output_text.replace("\\(", "").replace("\\)", "").replace("\\[", "").replace("\\]", "")
-    return plain_text
-
-def generate_document(media_path, output_text, file_format, font_choice, font_size, line_spacing, alignment, image_size):
-    """Generates a document with the input image and plain text output."""
-    plain_text = format_plain_text(output_text)
-    if file_format == "pdf":
-        return generate_pdf(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size)
-    elif file_format == "docx":
-        return generate_docx(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size)
-
-def generate_pdf(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size):
-    """Generates a PDF document."""
-    filename = f"output_{uuid.uuid4()}.pdf"
-    doc = SimpleDocTemplate(
-        filename,
-        pagesize=A4,
-        rightMargin=inch,
-        leftMargin=inch,
-        topMargin=inch,
-        bottomMargin=inch
-    )
-    styles = getSampleStyleSheet()
-    styles["Normal"].fontName = font_choice
-    styles["Normal"].fontSize = int(font_size)
-    styles["Normal"].leading = int(font_size) * line_spacing
-    styles["Normal"].alignment = {
-        "Left": 0,
-        "Center": 1,
-        "Right": 2,
-        "Justified": 4
-    }[alignment]
-
-    font_path = f"font/{font_choice}"
-    pdfmetrics.registerFont(TTFont(font_choice, font_path))
-
-    story = []
-
-    image_sizes = {
-        "Small": (200, 200),
-        "Medium": (400, 400),
-        "Large": (600, 600)
-    }
-    img = RLImage(media_path, width=image_sizes[image_size][0], height=image_sizes[image_size][1])
-    story.append(img)
-    story.append(Spacer(1, 12))
-
-    text = Paragraph(plain_text, styles["Normal"])
-    story.append(text)
-
-    doc.build(story)
-    return filename
-
-def generate_docx(media_path, plain_text, font_choice, font_size, line_spacing, alignment, image_size):
-    """Generates a DOCX document."""
-    filename = f"output_{uuid.uuid4()}.docx"
-    doc = docx.Document()
-
-    image_sizes = {
-        "Small": docx.shared.Inches(2),
-        "Medium": docx.shared.Inches(4),
-        "Large": docx.shared.Inches(6)
-    }
-    doc.add_picture(media_path, width=image_sizes[image_size])
-    doc.add_paragraph()
-
-    paragraph = doc.add_paragraph()
-    paragraph.paragraph_format.line_spacing = line_spacing
-    paragraph.paragraph_format.alignment = {
-        "Left": WD_ALIGN_PARAGRAPH.LEFT,
-        "Center": WD_ALIGN_PARAGRAPH.CENTER,
-        "Right": WD_ALIGN_PARAGRAPH.RIGHT,
-        "Justified": WD_ALIGN_PARAGRAPH.JUSTIFY
-    }[alignment]
-    run = paragraph.add_run(plain_text)
-    run.font.name = font_choice
-    run.font.size = docx.shared.Pt(int(font_size))
-
-    doc.save(filename)
-    return filename
-
-# Gradio app setup
-with gr.Blocks() as demo:
-    gr.Markdown("# Qwen2VL Models: Vision and Language Processing")
-
-    with gr.Tab(label="Image Input"):
-
-        with gr.Row():
-            with gr.Column():
-                model_choice = gr.Dropdown(
-                    label="Model Selection",
-                    choices=list(MODEL_OPTIONS.keys()),
-                    value="Latex OCR"
-                )
-                input_media = gr.File(
-                    label="Upload Image", type="filepath"
-                )
-                text_input = gr.Textbox(label="Question", placeholder="Ask a question about the image...")
-                submit_btn = gr.Button(value="Submit")
-
-            with gr.Column():
-                output_text = gr.Textbox(label="Output Text", lines=10)
-                plain_text_output = gr.Textbox(label="Standardized Plain Text", lines=10)
-
-        submit_btn.click(
-            qwen_inference, [model_choice, input_media, text_input], [output_text]
-        ).then(
-            lambda output_text: format_plain_text(output_text), [output_text], [plain_text_output]
-        )
-
-        with gr.Row():
-            gr.Examples(
-                examples=[
-                    ["examples/1.png", "summarize the letter", "Text Analogy Ocrtest"],
-                    ["examples/2.jpg", "Summarize the full image in detail", "Latex OCR"],
-                    ["examples/3.png", "Describe the photo", "Qwen2VL Base"],
-                    ["examples/4.png", "summarize and solve the problem", "Math Prase"],
-                ],
-                inputs=[input_media, text_input, model_choice],
-                outputs=[output_text, plain_text_output],
-                fn=lambda img, question, model: qwen_inference(model, img, question),
-                cache_examples=False,
-            )
-
-        with gr.Row():
-            with gr.Column():
-                line_spacing = gr.Dropdown(
-                    choices=[0.5, 1.0, 1.15, 1.5, 2.0, 2.5, 3.0],
-                    value=1.5,
-                    label="Line Spacing"
-                )
-                font_size = gr.Dropdown(
-                    choices=["8", "10", "12", "14", "16", "18", "20", "22", "24"],
-                    value="18",
-                    label="Font Size"
-                )
-                font_choice = gr.Dropdown(
-                    choices=[
-                        "DejaVuMathTeXGyre.ttf",
-                        "FiraCode-Medium.ttf",
-                        "InputMono-Light.ttf",
-                        "JetBrainsMono-Thin.ttf",
-                        "ProggyCrossed Regular Mac.ttf",
-                        "SourceCodePro-Black.ttf",
-                        "arial.ttf",
-                        "calibri.ttf",
-                        "mukta-malar-extralight.ttf",
-                        "noto-sans-arabic-medium.ttf",
-                        "times new roman.ttf",
-                        "ANGSA.ttf",
-                        "Book-Antiqua.ttf",
-                        "CONSOLA.TTF",
-                        "COOPBL.TTF",
-                        "Rockwell-Bold.ttf",
-                        "Candara Light.TTF",
-                        "Carlito-Regular.ttf Carlito-Regular.ttf",
-                        "Castellar.ttf",
-                        "Courier New.ttf",
-                        "LSANS.TTF",
-                        "Lucida Bright Regular.ttf",
-                        "TRTempusSansITC.ttf",
-                        "Verdana.ttf",
-                        "bell-mt.ttf",
-                        "eras-itc-light.ttf",
-                        "fonnts.com-aptos-light.ttf",
-                        "georgia.ttf",
-                        "segoeuithis.ttf",
-                        "youyuan.TTF",
-                        "TfPonetoneExpanded-7BJZA.ttf",
-                    ],
-                    value="youyuan.TTF",
-                    label="Font Choice"
-                )
-                alignment = gr.Dropdown(
-                    choices=["Left", "Center", "Right", "Justified"],
-                    value="Justified",
-                    label="Text Alignment"
-                )
-                image_size = gr.Dropdown(
-                    choices=["Small", "Medium", "Large"],
-                    value="Small",
-                    label="Image Size"
-                )
-                file_format = gr.Radio(["pdf", "docx"], label="File Format", value="pdf")
-                get_document_btn = gr.Button(value="Get Document")
-
-        get_document_btn.click(
-            generate_document, [input_media, output_text, file_format, font_choice, font_size, line_spacing, alignment, image_size], gr.File(label="Download Document")
-        )
 
-demo.launch(debug=True)
+# Load the pre-trained MobileNetV2 model
+model = models.mobilenet_v2(pretrained=True)
+model.eval()
+
+# Image transformation for input
+transform = transforms.Compose([
+    transforms.Resize(256),
+    transforms.CenterCrop(224),
+    transforms.ToTensor(),
+    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+])
+
+def classify_image(image):
+    # Apply transformations
+    img_tensor = transform(image).unsqueeze(0)
+
+    # Perform inference
+    with torch.no_grad():
+        outputs = model(img_tensor)
+        _, predicted_class = torch.max(outputs, 1)
+
+    return predicted_class.item()
+
+# Gradio interface
+interface = gr.Interface(fn=classify_image, inputs=gr.Image(type="pil"), outputs="text")
+interface.launch()
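
Note: the committed classify_image returns the raw ImageNet class index, so the Gradio "text" output displays a number rather than a label, and pretrained=True is deprecated in recent torchvision. A minimal sketch of a labeled variant, not part of this commit, assuming torchvision >= 0.13 where weight enums bundle class names and preprocessing:

import torch
from torchvision import models

# Hypothetical variant (not in the commit): the weights enum supplies both
# the ImageNet class names and the matching preprocessing pipeline.
weights = models.MobileNet_V2_Weights.DEFAULT
model = models.mobilenet_v2(weights=weights)  # replaces deprecated pretrained=True
model.eval()
labels = weights.meta["categories"]  # the 1000 ImageNet class names

def classify_image(image):
    # weights.transforms() bundles the resize/crop/normalize steps
    img_tensor = weights.transforms()(image).unsqueeze(0)
    with torch.no_grad():
        outputs = model(img_tensor)
    return labels[outputs.argmax(1).item()]  # human-readable label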