# Vlm-test / app.py
from flask import Flask, request, jsonify
from flask_cors import CORS
from PIL import Image
import io
import os
from transformers import DonutProcessor, VisionEncoderDecoderModel
import torch
import fitz # PyMuPDF
# Initialize Flask
app = Flask(__name__)
CORS(app)
# Load Donut model and processor
device = "cpu"
processor = DonutProcessor.from_pretrained("naver-clova-ix/donut-base")
model = VisionEncoderDecoderModel.from_pretrained("naver-clova-ix/donut-base").to(device)
model.eval()
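
# Note: "naver-clova-ix/donut-base" is the pre-trained base checkpoint, not a
# fine-tuned one. Fine-tuned Donut checkpoints are typically prompted with a
# task token passed via decoder_input_ids; this app runs the base model as-is,
# so output quality on arbitrary documents may be limited.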
def convert_pdf_to_image(file_stream):
    """Render the first page of an uploaded PDF as an RGB PIL image."""
    doc = fitz.open(stream=file_stream.read(), filetype="pdf")
    page = doc.load_page(0)
    pix = page.get_pixmap(dpi=150)
    img = Image.frombytes("RGB", (pix.width, pix.height), pix.samples)
    doc.close()
    return img
@app.route("/ocr", methods=["POST"])
def ocr():
if "file" not in request.files:
return jsonify({"error": "No file uploaded"}), 400
file = request.files["file"]
filename = file.filename.lower()
# Convert input to PIL image
if filename.endswith(".pdf"):
image = convert_pdf_to_image(file)
else:
image = Image.open(io.BytesIO(file.read())).convert("RGB")
# Preprocess image
pixel_values = processor(image, return_tensors="pt").pixel_values.to(device)
# Run model
with torch.no_grad():
output = model.generate(pixel_values, max_length=512, return_dict_in_generate=True)
# Decode output
parsed_text = processor.batch_decode(output.sequences)[0]
parsed_text = processor.tokenizer.decode(output.sequences[0], skip_special_tokens=True)
return jsonify({"text": parsed_text})
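
# Example client call (placeholder file name; assumes the server is running
# locally on port 7860 as configured below):
#   curl -X POST -F "file=@document.pdf" http://localhost:7860/ocr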
@app.route("/", methods=["GET"])
def index():
return "Smart OCR Flask API (Donut-based)"
if __name__ == "__main__":
app.run(host="0.0.0.0", port=7860)
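
# Note: port 7860 matches the default port expected by Hugging Face Spaces;
# adjust it if deploying elsewhere.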