mike23415 committed on
Commit
035a6f9
·
verified ·
1 Parent(s): bace547

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +120 -22
app.py CHANGED
@@ -4,58 +4,156 @@ from PIL import Image
4
  import io
5
  import os
6
 
7
- from transformers import DonutProcessor, VisionEncoderDecoderModel
 
8
  import torch
 
 
 
 
 
 
 
9
  import fitz # PyMuPDF
10
 
11
  # Initialize Flask
12
  app = Flask(__name__)
13
  CORS(app)
14
 
15
- # Load Donut model and processor
16
- device = "cpu"
17
- processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
18
- model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base").to(device)
19
  model.eval()
20
 
 
 
 
21
  def convert_pdf_to_image(file_stream):
 
22
  doc = fitz.open(stream=file_stream.read(), filetype="pdf")
23
  page = doc.load_page(0)
24
- pix = page.get_pixmap(dpi=150)
 
25
  img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
 
26
  return img
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  @app.route("/ocr", methods=["POST"])
29
  def ocr():
30
  if "file" not in request.files:
31
  return jsonify({"error": "No file uploaded"}), 400
32
 
33
  file = request.files["file"]
 
 
 
34
  filename = file.filename.lower()
35
 
36
- # Convert input to PIL image
37
- if filename.endswith(".pdf"):
38
- image = convert_pdf_to_image(file)
39
- else:
40
- image = Image.open(io.BytesIO(file.read())).convert("RGB")
 
41
 
42
- # Preprocess image
43
- pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
44
 
45
- # Run model
46
- with torch.no_grad():
47
- output = model.generate(pixel_values, max_length=512, return_dict_in_generate=True)
48
 
49
- # Decode output
50
- parsed_text = processor.batch_decode(output.sequences)[0]
51
- parsed_text = processor.tokenizer.decode(output.sequences[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
52
 
53
- return jsonify({"text": parsed_text})
 
 
 
54
 
 
 
 
55
 
56
  @app.route("/", methods=["GET"])
57
  def index():
58
- return "Smart OCR Flask API (Donut-based)"
59
 
60
  if __name__ == "__main__":
61
- app.run(host="0.0.0.0", port=7860)
 
4
  import io
5
  import os
6
 
7
+ # Option 1: Using TrOCR (Transformer-based OCR)
8
+ from transformers import TrOCRProcessor, VisionEncoderDecoderModel
9
  import torch
10
+
11
+ # Option 2: Using EasyOCR (commented out - uncomment if you prefer this)
12
+ # import easyocr
13
+
14
+ # Option 3: Using Tesseract (commented out - uncomment if you prefer this)
15
+ # import pytesseract
16
+
17
  import fitz # PyMuPDF
18
 
19
# Flask application object; CORS is applied to the whole app.
app = Flask(__name__)
CORS(app)

# TrOCR checkpoint for printed text. Inference runs on CUDA when torch
# reports a usable GPU, otherwise on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
model = model.to(device)
model.eval()  # inference only

# Alternative: Initialize EasyOCR reader (uncomment if using EasyOCR)
# reader = easyocr.Reader(['en'])
31
+
32
def convert_pdf_to_image(file_stream):
    """Render the first page of a PDF as an RGB PIL image.

    Args:
        file_stream: Binary file-like object holding the PDF; it is read
            in full via ``read()``.

    Returns:
        A PIL image in RGB mode, rendered at 300 DPI.

    Raises:
        Whatever PyMuPDF raises for unreadable or empty documents; the
        document handle is closed before the exception propagates.
    """
    doc = fitz.open(stream=file_stream.read(), filetype="pdf")
    try:
        # Only the first page is rendered.
        page = doc.load_page(0)
        # Higher DPI for better text recognition.
        pix = page.get_pixmap(dpi=300)
        return Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
    finally:
        # Close even when rendering fails, so a bad upload does not leak
        # the open document.
        doc.close()
41
 
42
def preprocess_image(image, min_side=1000):
    """Prepare an image for OCR: force RGB mode and upscale small inputs.

    Args:
        image: PIL image to prepare.
        min_side: Smallest acceptable width and height in pixels. Images
            with either side below this are enlarged. Defaults to 1000.

    Returns:
        The prepared image. The input object itself is returned when it is
        already RGB and large enough.

    Raises:
        ValueError: If the image has a zero width or height.
    """
    # Normalise every input to RGB (not grayscale).
    if image.mode != 'RGB':
        image = image.convert('RGB')

    width, height = image.size
    if width <= 0 or height <= 0:
        # Fail with a clear message instead of a ZeroDivisionError below.
        raise ValueError(f"Cannot preprocess an empty image ({width}x{height})")

    if width < min_side or height < min_side:
        # Use the larger ratio so the shorter side is brought up to
        # min_side while the aspect ratio is preserved.
        scale_factor = max(min_side / width, min_side / height)
        new_size = (int(width * scale_factor), int(height * scale_factor))
        image = image.resize(new_size, Image.Resampling.LANCZOS)

    return image
57
+
58
def extract_text_trocr(image):
    """Recognise text with TrOCR, one horizontal strip at a time.

    The image is cut into full-width strips of up to 400 pixels in height
    (TrOCR works better on smaller sections). Each strip is decoded
    separately and the non-empty results are joined with newlines.

    Args:
        image: PIL image to read.

    Returns:
        The recognised text, or an empty string if nothing was recognised
        or any step raised.
    """
    strip_height = 400
    try:
        img_width, img_height = image.size
        lines = []

        top = 0
        while top < img_height:
            bottom = min(top + strip_height, img_height)
            strip = image.crop((0, top, img_width, bottom))

            encoded = processor(strip, return_tensors="pt")
            pixel_values = encoded.pixel_values.to(device)

            # Inference only: no gradients needed.
            with torch.no_grad():
                token_ids = model.generate(pixel_values, max_length=512)

            decoded = processor.batch_decode(token_ids, skip_special_tokens=True)[0].strip()
            if decoded:
                lines.append(decoded)

            top += strip_height

        return "\n".join(lines)
    except Exception as e:
        # Best effort: report the failure and let the caller fall back.
        print(f"TrOCR error: {e}")
        return ""
83
+
84
def extract_text_easyocr(image):
    """Extract text using EasyOCR.

    Requires the module-level ``easyocr`` import and ``reader``
    initialisation to be uncommented; without them this returns an empty
    string.

    Args:
        image: PIL image (or anything ``numpy.asarray`` accepts).

    Returns:
        Detected text fragments with confidence above 0.5, joined with
        newlines, or an empty string on any failure.
    """
    import numpy as np

    try:
        # readtext takes a file path, bytes or numpy array, not a PIL
        # image, so convert first.
        results = reader.readtext(np.asarray(image))
        # Each result is (bbox, text, confidence); drop low-confidence hits.
        confident = [text for _bbox, text, confidence in results if confidence > 0.5]
        return "\n".join(confident)
    except Exception as e:
        # Best effort: report the failure and let the caller fall back.
        print(f"EasyOCR error: {e}")
        return ""
96
+
97
def extract_text_tesseract(image):
    """Extract text using Tesseract.

    Requires the module-level ``pytesseract`` import to be uncommented;
    without it this returns an empty string.

    Args:
        image: PIL image to read.

    Returns:
        The recognised text with surrounding whitespace removed, or an
        empty string on any failure.
    """
    try:
        # Grayscale input for better OCR.
        grayscale = image.convert('L')
        return pytesseract.image_to_string(grayscale, config='--psm 6').strip()
    except Exception as err:
        # Best effort: report the failure and return nothing.
        print(f"Tesseract error: {err}")
        return ""
107
+
108
@app.route("/ocr", methods=["POST"])
def ocr():
    """Handle an OCR request for an uploaded image or PDF.

    Expects a multipart upload under the ``file`` field. Files whose name
    ends in ``.pdf`` are rendered via ``convert_pdf_to_image``; anything
    else is opened as an image with PIL.

    Returns:
        A JSON response with ``text`` and ``message`` on success (``text``
        is empty when nothing was recognised), or an ``error`` body with
        status 400 for a missing upload and 500 for processing failures.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    upload = request.files["file"]
    if not upload.filename:
        return jsonify({"error": "No file selected"}), 400

    is_pdf = upload.filename.lower().endswith(".pdf")

    try:
        # Get a PIL image from whichever input type was uploaded.
        if is_pdf:
            image = convert_pdf_to_image(upload)
        else:
            raw_bytes = upload.read()
            image = Image.open(io.BytesIO(raw_bytes)).convert("RGB")

        image = preprocess_image(image)

        # TrOCR is the primary method.
        extracted_text = extract_text_trocr(image)

        if not extracted_text:
            print("TrOCR failed, trying alternative methods...")
            # Uncomment one of these if you have the libraries installed:
            # extracted_text = extract_text_easyocr(image)
            # extracted_text = extract_text_tesseract(image)

        if extracted_text:
            payload = {
                "text": extracted_text,
                "message": "Text extracted successfully",
            }
        else:
            payload = {
                "text": "",
                "message": "No text could be extracted from the image. The image might be too blurry, have low contrast, or contain handwritten text.",
            }
        return jsonify(payload)

    except Exception as e:
        print(f"OCR processing error: {e}")
        return jsonify({"error": f"Failed to process file: {str(e)}"}), 500
153
 
154
@app.route("/", methods=["GET"])
def index():
    """Return a plain-text banner identifying the service."""
    banner = "Smart OCR Flask API (TrOCR-based)"
    return banner
157
 
158
if __name__ == "__main__":
    # Debug mode enables the interactive Werkzeug debugger, which lets
    # anyone who can reach the server run arbitrary code. The server binds
    # to all interfaces, so debug is off unless FLASK_DEBUG opts in.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in ("1", "true", "yes")
    app.run(host="0.0.0.0", port=7860, debug=debug_enabled)