OCR-SMALL / main.py
AnkitShrestha's picture
Add internal ollama parsing to citizenship ocr
a1c0d1f
# # ! pip uninstall -y tensorflow
# # ! pip install "python-doctr[torch,viz]"
# from fastapi import FastAPI, UploadFile, File
# from fastapi.responses import JSONResponse
# from utils import dev_number, roman_number, dev_letter, roman_letter
# import tempfile
# app = FastAPI()
# @app.post("/ocr_dev_number/")
# async def extract_dev_number(image: UploadFile = File(...)):
# # Save uploaded image temporarily
# with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
# content = await image.read()
# tmp.write(content)
# tmp_path = tmp.name
# # predict the image
# predicted_str = dev_number(tmp_path)
# # Return result as JSON
# return JSONResponse(content={"predicted_str": predicted_str})
# @app.post("/ocr_roman_number/")
# async def extract_roman_number(image: UploadFile = File(...)):
# # Save uploaded image temporarily
# with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
# content = await image.read()
# tmp.write(content)
# tmp_path = tmp.name
# # predict the image
# predicted_str = roman_number(tmp_path)
# # Return result as JSON
# return JSONResponse(content={"predicted_str": predicted_str})
# @app.post("/ocr_dev_letter/")
# async def extract_dev_letter(image: UploadFile = File(...)):
# # Save uploaded image temporarily
# with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
# content = await image.read()
# tmp.write(content)
# tmp_path = tmp.name
# # predict the image
# predicted_str = dev_letter(tmp_path)
# # Return result as JSON
# return JSONResponse(content={"predicted_str": predicted_str})
# @app.post("/ocr_roman_letter/")
# async def extract_roman_letter(image: UploadFile = File(...)):
# # Save uploaded image temporarily
# with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
# content = await image.read()
# tmp.write(content)
# tmp_path = tmp.name
# # predict the image
# predicted_str = roman_letter(tmp_path)
# # Return result as JSON
# return JSONResponse(content={"predicted_str": predicted_str})
import os
import shutil
import tempfile
from typing import Literal, Optional

from fastapi import FastAPI, UploadFile, File, HTTPException, Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Import from optimized utils
from utils import (
    dev_number,
    roman_number,
    dev_letter,
    roman_letter,
    predict_ne,
    perform_citizenship_ocr,
)
# Application object that every route below registers itself on.
# The keyword arguments are the API metadata handed to the FastAPI constructor.
app = FastAPI(
    title="OCR API",
    description="API for optical character recognition of Roman and Devanagari text",
    version="1.0.0",
)
class OCRResponse(BaseModel):
    """Response model for OCR endpoints.

    NOTE(review): none of the routes visible in this file pass this model as
    ``response_model``; they build ``JSONResponse`` objects directly.
    """

    # Text recognised in the image.
    predicted_str: str
    # Optional confidence score. The default is None, so the annotation must
    # allow None as well; a bare ``float`` annotation contradicted the default.
    confidence: Optional[float] = None
# Helper function to handle file uploads consistently
async def save_upload_file_tmp(upload_file: UploadFile) -> str:
    """Persist an uploaded file to a temporary file and return its path.

    The temporary file is created with ``delete=False``, so the caller owns
    the returned path and is responsible for removing it when done.

    Args:
        upload_file: The upload to persist. Its filename extension (if any)
            is reused as the temporary file's suffix; ``".png"`` is used when
            the filename is missing or has no extension.

    Returns:
        Path of the newly written temporary file.

    Raises:
        HTTPException: 500 if reading the upload or writing the file fails.
    """
    tmp_path = None
    try:
        # Fall back to ".png" both when there is no filename and when the
        # filename carries no extension.
        suffix = os.path.splitext(upload_file.filename or "")[1] or ".png"
        # Read before creating the file so a failed read cannot leave an
        # empty temporary file behind.
        content = await upload_file.read()
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            tmp_path = tmp.name
            tmp.write(content)
        return tmp_path
    except Exception as e:
        # delete=False means nothing else will remove a half-written file.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}") from e
# Generic OCR function that can be reused across endpoints
def _extract_text_from_result(result) -> str:
    """Join the recognised text of a page-structured OCR result with spaces."""
    pieces = []
    for page in result.pages:
        for block in page.blocks:
            if hasattr(block, "value"):
                # Result types whose blocks carry their text directly.
                pieces.append(block.value)
                continue
            # docTR documents nest text as block -> lines -> words and only
            # words expose ``.value`` (per the docTR docs).
            # NOTE(review): confirm against what utils.roman_letter returns.
            for line in block.lines:
                pieces.extend(word.value for word in line.words)
    return " ".join(pieces)


async def process_ocr_request(
    image: UploadFile = File(...),
    ocr_function=None
):
    """Run ``ocr_function`` on an uploaded image and wrap the text in JSON.

    The upload is written to a temporary file, the path is passed to
    ``ocr_function``, and the temporary file is removed afterwards whether or
    not the call succeeded.

    Args:
        image: Uploaded image file.
        ocr_function: Callable taking the image path. It may return a string,
            or an object exposing ``pages`` / ``blocks`` from which text is
            extracted.

    Returns:
        ``JSONResponse`` with body ``{"predicted_str": <text>}``.

    Raises:
        HTTPException: 500 if no OCR function is given, if saving the upload
            fails, or if OCR or text extraction raises.
    """
    if not ocr_function:
        raise HTTPException(status_code=500, detail="OCR function not specified")

    tmp_path = None
    try:
        # Save uploaded image temporarily
        tmp_path = await save_upload_file_tmp(image)
        # Process the image with the specified OCR function
        result = ocr_function(tmp_path)
        # Handle different types of results (string vs structured output)
        if isinstance(result, str):
            predicted = result
        else:
            predicted = _extract_text_from_result(result)
        return JSONResponse(content={"predicted_str": predicted})
    except HTTPException:
        # Already a well-formed HTTP error (e.g. from the upload helper);
        # do not wrap it in a second HTTPException.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"OCR processing error: {str(e)}") from e
    finally:
        # Single cleanup point for both the success and the failure path.
        if tmp_path is not None and os.path.exists(tmp_path):
            os.unlink(tmp_path)
# Endpoints with minimal duplication
@app.post("/ocr/", summary="Generic OCR endpoint")
async def extract_text(
    image: UploadFile = File(...),
    model_type: Literal["dev_number", "roman_number", "dev_letter", "roman_letter"] = "roman_letter"
):
    """
    Single entry point that dispatches to any of the supported recognisers.

    - **image**: Image file to process
    - **model_type**: Which recogniser to run on the image
    """
    # Map each accepted model_type value to the function that implements it.
    recognisers = {
        "dev_number": dev_number,
        "roman_number": roman_number,
        "dev_letter": dev_letter,
        "roman_letter": roman_letter,
    }
    selected = recognisers.get(model_type)
    if selected is None:
        raise HTTPException(status_code=400, detail=f"Invalid model type: {model_type}")
    response = await process_ocr_request(image, selected)
    return response
# For backward compatibility, keep the original endpoints
@app.post("/ocr_dev_number/")
async def extract_dev_number(image: UploadFile = File(...)):
    """Run the ``dev_number`` recogniser (Devanagari numbers) on the uploaded image."""
    response = await process_ocr_request(image, ocr_function=dev_number)
    return response
@app.post("/ocr_roman_number/")
async def extract_roman_number(image: UploadFile = File(...)):
    """Run the ``roman_number`` recogniser (Roman numbers) on the uploaded image."""
    response = await process_ocr_request(image, ocr_function=roman_number)
    return response
@app.post("/ocr_dev_letter/")
async def extract_dev_letter(image: UploadFile = File(...)):
    """Run the ``dev_letter`` recogniser (Devanagari letters) on the uploaded image."""
    response = await process_ocr_request(image, ocr_function=dev_letter)
    return response
@app.post("/ocr_roman_letter/")
async def extract_roman_letter(image: UploadFile = File(...)):
    """Run the ``roman_letter`` recogniser (Roman letters) on the uploaded image."""
    response = await process_ocr_request(image, ocr_function=roman_letter)
    return response
@app.post("/predict_ne")
async def classify_ne(image: UploadFile = File(...)):
    """Run ``predict_ne`` on an uploaded image and return its prediction.

    The upload is saved to a temporary file, passed to ``predict_ne`` by
    path, and the temporary file is removed afterwards even if prediction
    fails.

    NOTE(review): the original docstring called this named-entity prediction,
    while the commented-out model path below mentions a Nepali/English
    classifier. Confirm the actual semantics against ``utils.predict_ne``.

    Args:
        image: Uploaded image file.

    Returns:
        ``JSONResponse`` with body ``{"predicted": <prediction>}``.
    """
    image_path = await save_upload_file_tmp(image)
    try:
        prediction = predict_ne(
            image_path=image_path,
            # model="models/nepali_english_classifier.pth",  # Update with actual model path
            device="cpu",  # inference is pinned to the CPU here
        )
        return JSONResponse(content={"predicted": prediction})
    finally:
        # save_upload_file_tmp creates the file with delete=False, so it must
        # be removed here or every request leaks one temporary file.
        if os.path.exists(image_path):
            os.unlink(image_path)
@app.post("/ocr_citizenship/")
async def ocr_citizenship(image: UploadFile = File(...)):
    """OCR the provided Nepali Citizenship card.

    The upload is saved to a temporary file, passed to
    ``perform_citizenship_ocr`` by path, and the temporary file is removed
    afterwards even if OCR fails.

    Args:
        image: Uploaded image of the card.

    Returns:
        ``JSONResponse`` whose body is whatever ``perform_citizenship_ocr``
        returned.
    """
    image_path = await save_upload_file_tmp(image)
    try:
        prediction = perform_citizenship_ocr(
            image_path=image_path,
        )
        return JSONResponse(content=prediction)
    finally:
        # save_upload_file_tmp creates the file with delete=False, so it must
        # be removed here or every request leaks one temporary file.
        if os.path.exists(image_path):
            os.unlink(image_path)
# Health check endpoint
@app.get("/health")
async def health_check():
    """Liveness probe: returns a fixed payload without touching any OCR code."""
    status_payload = {"status": "healthy"}
    return status_payload