mike23415 committed on
Commit
03f2dd0
·
verified ·
1 Parent(s): 8d65f7c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -0
app.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, request, jsonify
2
+ from flask_cors import CORS
3
+ from PIL import Image
4
+ import io
5
+ import os
6
+
7
+ from transformers import DonutProcessor, VisionEncoderDecoderModel
8
+ import torch
9
+ import fitz # PyMuPDF
10
+
11
# --- Web application ---------------------------------------------------------
# Flask app with CORS enabled using flask_cors' default settings.
app = Flask(__name__)
CORS(app)

# --- OCR model ---------------------------------------------------------------
# The Donut processor and model are loaded once, at import time, and shared by
# every request. Inference is pinned to the CPU.
device = "cpu"
_DONUT_CHECKPOINT = "naver-clova-ix/donut-base"
processor = DonutProcessor.from_pretrained(_DONUT_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_DONUT_CHECKPOINT)
model = model.to(device)
model.eval()  # inference only: put the model in evaluation mode
20
+
21
def convert_pdf_to_image(file_stream, page_index=0, dpi=150):
    """Render one page of a PDF stream to an RGB PIL image.

    Args:
        file_stream: Binary file-like object. Its remaining content is read
            in full and parsed as a PDF.
        page_index: Zero-based index of the page to render. Defaults to the
            first page.
        dpi: Rendering resolution handed to PyMuPDF. Defaults to 150.

    Returns:
        A ``PIL.Image.Image`` in RGB mode holding the rendered page.

    Raises:
        ValueError: If the document has no page at ``page_index``.
    """
    # The context manager closes the document even if rendering fails, so the
    # underlying PyMuPDF resources are not leaked on every request.
    with fitz.open(stream=file_stream.read(), filetype="pdf") as doc:
        if not 0 <= page_index < doc.page_count:
            raise ValueError(
                f"PDF has {doc.page_count} page(s); "
                f"page index {page_index} is out of range"
            )
        page = doc.load_page(page_index)
        # alpha=False keeps the sample buffer at 3 bytes per pixel, which is
        # what Image.frombytes("RGB", ...) expects.
        pix = page.get_pixmap(dpi=dpi, alpha=False)
        # frombytes builds the image from the bytes in pix.samples.
        return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
27
+
28
@app.route("/ocr", methods=["POST"])
def ocr():
    """Run the Donut model on an uploaded image or PDF and return its text.

    Expects a multipart upload under the form field ``file``. Uploads whose
    name ends in ``.pdf`` are rasterised with ``convert_pdf_to_image``; any
    other upload is opened with Pillow and converted to RGB.

    Returns:
        A JSON response ``{"text": ...}`` on success, or ``{"error": ...}``
        with HTTP status 400 when no file was sent or the upload cannot be
        decoded as an image or PDF.
    """
    upload = request.files.get("file")
    if upload is None:
        return jsonify({"error": "No file uploaded"}), 400

    # The filename can be missing; treat that as "not a PDF" rather than
    # crashing on None.lower().
    filename = (upload.filename or "").lower()

    # Convert input to a PIL image. Only the decoding step is guarded so that
    # model errors further down are not mislabelled as bad input.
    try:
        if filename.endswith(".pdf"):
            image = convert_pdf_to_image(upload)
        else:
            image = Image.open(io.BytesIO(upload.read())).convert("RGB")
    except (OSError, RuntimeError, ValueError):
        # Pillow signals unreadable images with OSError subclasses.
        # NOTE(review): PyMuPDF is expected to raise RuntimeError subclasses
        # or ValueError for unreadable PDFs; confirm for the installed version.
        return jsonify(
            {"error": "Could not read the uploaded file as an image or PDF"}
        ), 400

    # Preprocess image
    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    # Run model without tracking gradients (inference only)
    with torch.no_grad():
        output = model.generate(
            pixel_values, max_length=512, return_dict_in_generate=True
        )

    # Decode once, dropping special tokens. The earlier batch_decode call was
    # immediately overwritten and has been removed.
    parsed_text = processor.tokenizer.decode(
        output.sequences[0], skip_special_tokens=True
    )

    return jsonify({"text": parsed_text})
54
+
55
+
56
@app.route("/", methods=["GET"])
def index():
    """Return a plain-text banner identifying the service."""
    banner = "Smart OCR Flask API (Donut-based)"
    return banner
59
+
60
if __name__ == "__main__":
    # Start Flask's built-in server on 0.0.0.0:7860 when run as a script.
    listen_host, listen_port = "0.0.0.0", 7860
    app.run(host=listen_host, port=listen_port)