intuitive262 commited on
Commit
3bc9acc
·
1 Parent(s): 99c2cdc

code files

Browse files
Files changed (2) hide show
  1. app.py +92 -0
  2. requirements.txt +10 -0
app.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import numpy as np
3
+ from PIL import Image
4
+ import cv2
5
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
6
+ from huggingface_hub import hf_hub_download
7
+ import torch
8
+ import re
9
+
10
+ # Download and load the GOT OCR model
11
+ got_model_path = hf_hub_download(repo_id="junyeopkim/got_2.0_torch_script", filename="got_2.0_tiny.torchscript")
12
+ got_model = torch.jit.load(got_model_path)
13
+
14
+ # Load the Surya-OCR model
15
+ surya_processor = TrOCRProcessor.from_pretrained("suryavarmaaddala/suryaocr")
16
+ surya_model = VisionEncoderDecoderModel.from_pretrained("suryavarmaaddala/suryaocr")
17
+
18
+ def preprocess_image(image):
19
+ if isinstance(image, str):
20
+ image = Image.open(image).convert("RGB")
21
+ elif isinstance(image, np.ndarray):
22
+ image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
23
+ return image
24
+
25
+ def got_ocr(image):
26
+ image = preprocess_image(image)
27
+ image = image.resize((224, 224))
28
+ input_tensor = torch.from_numpy(np.array(image)).permute(2, 0, 1).float() / 255.0
29
+ input_tensor = input_tensor.unsqueeze(0)
30
+
31
+ with torch.no_grad():
32
+ output = got_model(input_tensor)
33
+
34
+ return output[0].item()
35
+
36
+ def surya_ocr(image):
37
+ image = preprocess_image(image)
38
+ pixel_values = surya_processor(image, return_tensors="pt").pixel_values
39
+
40
+ generated_ids = surya_model.generate(pixel_values)
41
+ generated_text = surya_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
42
+
43
+ return generated_text
44
+
45
+ def post_process_text(text):
46
+ # Simple post-processing to split into lines
47
+ return '\n'.join(text.split('. '))
48
+
49
+ def search_text(text, query):
50
+ try:
51
+ pattern = re.compile(query, re.IGNORECASE)
52
+ lines = text.split('\n')
53
+ matching_lines = [line for line in lines if pattern.search(line)]
54
+ return '\n'.join(matching_lines) if matching_lines else "No matches found."
55
+ except re.error:
56
+ return "Invalid regex pattern. Please try again."
57
+
58
+ def process_and_search(image, search_query):
59
+ try:
60
+ got_score = got_ocr(image)
61
+ surya_text = surya_ocr(image)
62
+
63
+ result = f"GOT OCR Score: {got_score:.4f}\n\nExtracted Text:\n{surya_text}"
64
+ processed_text = post_process_text(result)
65
+
66
+ search = None
67
+ if search_query:
68
+ search = search_text(processed_text, search_query)
69
+ return image, processed_text, search
70
+ except Exception as e:
71
+ return None, f"An error occurred: {str(e)}", None
72
+
73
+ with gr.Blocks() as demo:
74
+ with gr.Row():
75
+ with gr.Column(scale=1):
76
+ image_input = gr.Image(type="filepath", label="Upload your image")
77
+ search_query_input = gr.Textbox(label="Enter search query")
78
+ submit_button = gr.Button("Submit")
79
+
80
+ with gr.Column(scale=2):
81
+ displayed_image = gr.Image(label="Uploaded Image")
82
+ ocr_result = gr.Textbox(label="OCR Result", lines=10)
83
+ search_result = gr.Textbox(label="Search Result", lines=5)
84
+
85
+ submit_button.click(
86
+ fn=process_and_search,
87
+ inputs=[image_input, search_query_input],
88
+ outputs=[displayed_image, ocr_result, search_result]
89
+ )
90
+
91
+ if __name__ == "__main__":
92
+ demo.launch()
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ gradio
2
+ Pillow
3
+ surya-ocr
4
+ torch
5
+ transformers
6
+ tiktoken
7
+ torchvision
8
+ verovio
9
+ accelerate
10
+ rapidfuzz