|
from flask import Flask, request, jsonify |
|
from flask_cors import CORS |
|
from PIL import Image |
|
import io |
|
import os |
|
|
|
from transformers import DonutProcessor, VisionEncoderDecoderModel |
|
import torch |
|
import fitz |
|
|
|
|
|
# Flask application object, with flask_cors applied using its default options.
app = Flask(__name__)
CORS(app)

# All tensors and the model are placed on the CPU.
device = "cpu"

# Checkpoint identifier shared by the processor and the model.
_DONUT_CHECKPOINT = "naver-clova-ix/donut-base"

# Loaded once at import time so every request reuses the same objects.
processor = DonutProcessor.from_pretrained(_DONUT_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_DONUT_CHECKPOINT)
model = model.to(device)

# Switch to evaluation mode; the model is only used for generation below.
model.eval()
|
|
|
def convert_pdf_to_image(file_stream, page_index=0, dpi=150):
    """Render one page of a PDF as an RGB PIL image.

    Args:
        file_stream: Binary file-like object; its ``read()`` must return the
            raw PDF bytes.
        page_index: Index of the page to render, passed straight to
            ``Document.load_page``. Defaults to 0 (the first page).
        dpi: Rendering resolution passed to ``Page.get_pixmap``. Defaults
            to 150.

    Returns:
        A ``PIL.Image.Image`` in ``"RGB"`` mode holding the rendered page.

    Raises:
        Any exception PyMuPDF raises for unreadable data or a missing page
        propagates unchanged; nothing is caught here.
    """
    # The context manager closes the document even if rendering fails, so the
    # parsed PDF is not kept alive after the request.
    with fitz.open(stream=file_stream.read(), filetype="pdf") as doc:
        page = doc.load_page(page_index)
        # Request RGB without alpha explicitly: Image.frombytes("RGB", ...)
        # below needs exactly three samples per pixel.
        pix = page.get_pixmap(dpi=dpi, colorspace=fitz.csRGB, alpha=False)
        # Build the image while the document is still open.
        return Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
|
|
|
@app.route("/ocr", methods=["POST"])
def ocr():
    """Run Donut on an uploaded PDF or image and return the decoded text.

    Expects a multipart form upload under the field name ``file``. Files whose
    name ends in ``.pdf`` are rendered via ``convert_pdf_to_image``; anything
    else is opened with PIL and converted to RGB.

    Returns:
        ``{"text": <decoded string>}`` on success, or ``{"error": ...}`` with
        HTTP status 400 when no file was sent or the upload cannot be decoded.
    """
    if "file" not in request.files:
        return jsonify({"error": "No file uploaded"}), 400

    upload = request.files["file"]
    # The filename can be missing (None) on some uploads; treat that as
    # "not a PDF" instead of crashing on ``None.lower()``.
    filename = (upload.filename or "").lower()

    # Keep the try body limited to decoding the upload so model errors are
    # not misreported as bad input.
    try:
        if filename.endswith(".pdf"):
            image = convert_pdf_to_image(upload)
        else:
            image = Image.open(io.BytesIO(upload.read())).convert("RGB")
    except (OSError, ValueError, RuntimeError):
        # PIL reports undecodable image data via OSError subclasses.
        # NOTE(review): PyMuPDF's file errors are presumably RuntimeError /
        # ValueError based; confirm against the installed version.
        return jsonify({"error": "Could not read the uploaded file as a PDF or image"}), 400

    pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)

    # Gradients are not needed for generation.
    with torch.no_grad():
        output = model.generate(
            pixel_values,
            max_length=512,
            return_dict_in_generate=True,
        )

    # Decode the single generated sequence once, dropping special tokens.
    # (The earlier batch_decode call was discarded immediately and is removed.)
    parsed_text = processor.tokenizer.decode(
        output.sequences[0], skip_special_tokens=True
    )

    return jsonify({"text": parsed_text})
|
|
|
|
|
@app.route("/", methods=["GET"])
def index():
    """Return the plain-text banner served at the root URL."""
    banner = "Smart OCR Flask API (Donut-based)"
    return banner
|
|
|
if __name__ == "__main__":
    # Started only when the file is executed directly: listen on every
    # network interface, port 7860.
    app.run(port=7860, host="0.0.0.0")