Spaces:
Sleeping
Sleeping
# # ! pip uninstall -y tensorflow | |
# # ! pip install "python-doctr[torch,viz]" | |
# from fastapi import FastAPI, UploadFile, File | |
# from fastapi.responses import JSONResponse | |
# from utils import dev_number, roman_number, dev_letter, roman_letter | |
# import tempfile | |
# app = FastAPI() | |
# @app.post("/ocr_dev_number/") | |
# async def extract_dev_number(image: UploadFile = File(...)): | |
# # Save uploaded image temporarily | |
# with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp: | |
# content = await image.read() | |
# tmp.write(content) | |
# tmp_path = tmp.name | |
# # predict the image | |
# predicted_str = dev_number(tmp_path) | |
# # Return result as JSON | |
# return JSONResponse(content={"predicted_str": predicted_str}) | |
# @app.post("/ocr_roman_number/") | |
# async def extract_roman_number(image: UploadFile = File(...)): | |
# # Save uploaded image temporarily | |
# with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp: | |
# content = await image.read() | |
# tmp.write(content) | |
# tmp_path = tmp.name | |
# # predict the image | |
# predicted_str = roman_number(tmp_path) | |
# # Return result as JSON | |
# return JSONResponse(content={"predicted_str": predicted_str}) | |
# @app.post("/ocr_dev_letter/") | |
# async def extract_dev_letter(image: UploadFile = File(...)): | |
# # Save uploaded image temporarily | |
# with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp: | |
# content = await image.read() | |
# tmp.write(content) | |
# tmp_path = tmp.name | |
# # predict the image | |
# predicted_str = dev_letter(tmp_path) | |
# # Return result as JSON | |
# return JSONResponse(content={"predicted_str": predicted_str}) | |
# @app.post("/ocr_roman_letter/") | |
# async def extract_roman_letter(image: UploadFile = File(...)): | |
# # Save uploaded image temporarily | |
# with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp: | |
# content = await image.read() | |
# tmp.write(content) | |
# tmp_path = tmp.name | |
# # predict the image | |
# predicted_str = roman_letter(tmp_path) | |
# # Return result as JSON | |
# return JSONResponse(content={"predicted_str": predicted_str}) | |
import os
import shutil
import tempfile
from typing import Literal, Optional

from fastapi import FastAPI, UploadFile, File, HTTPException, Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel

# Import from optimized utils
from utils import dev_number, roman_number, dev_letter, roman_letter, predict_ne, perform_citizenship_ocr
# Application object; the metadata below is passed straight to FastAPI.
app = FastAPI(
    title="OCR API",
    description="API for optical character recognition of Roman and Devanagari text",
    version="1.0.0",
)
class OCRResponse(BaseModel):
    """Response model for OCR endpoints"""

    # Text produced by the recognizer.
    predicted_str: str
    # Optional confidence field. Declared Optional so that the None default
    # agrees with the annotation (a bare `float` annotation does not admit None).
    confidence: Optional[float] = None
# Helper function to handle file uploads consistently | |
async def save_upload_file_tmp(upload_file: UploadFile) -> str:
    """Save an upload file to a temporary file and return the path.

    The temporary file is created with ``delete=False``, so on success the
    caller owns the returned path and is responsible for removing it.

    Args:
        upload_file: The uploaded file; its ``filename`` (if any) supplies the
            suffix of the temporary file, otherwise ``".png"`` is used.

    Returns:
        Path of the temporary file containing the uploaded bytes.

    Raises:
        HTTPException: 500 if the upload cannot be read or written. Any
            partially written temporary file is removed first.
    """
    tmp_path = None
    try:
        # Reuse the uploaded file's extension; fall back to ".png" only when
        # no filename was supplied at all.
        suffix = os.path.splitext(upload_file.filename)[1] if upload_file.filename else ".png"
        with tempfile.NamedTemporaryFile(delete=False, suffix=suffix) as tmp:
            # Record the path before reading so a failed read/write can still
            # be cleaned up below.
            tmp_path = tmp.name
            content = await upload_file.read()
            tmp.write(content)
        return tmp_path
    except Exception as e:
        # delete=False means nothing else will remove the file for us.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; report the original failure instead.
                pass
        raise HTTPException(status_code=500, detail=f"Error processing file: {str(e)}") from e
# Generic OCR function that can be reused across endpoints | |
def _result_to_text(result) -> str:
    """Flatten an OCR result into a single string.

    Strings are returned unchanged. Any other result is expected to expose
    ``pages`` whose items expose ``blocks``. A block with a ``value``
    attribute contributes that value; otherwise the block's
    ``lines`` -> ``words`` -> ``value`` are collected.
    NOTE(review): the fallback assumes a doctr-style document layout; confirm
    against what the functions in ``utils`` actually return.
    """
    if isinstance(result, str):
        return result
    pieces = []
    for page in result.pages:
        for block in page.blocks:
            if hasattr(block, "value"):
                pieces.append(block.value)
            else:
                pieces.extend(word.value for line in block.lines for word in line.words)
    return " ".join(pieces)


async def process_ocr_request(
    image: UploadFile = File(...),
    ocr_function=None,
):
    """Process an OCR request using the specified OCR function.

    Args:
        image: Uploaded image to run OCR on.
        ocr_function: Callable that takes the path of the saved image and
            returns either a string or a structured OCR result.

    Returns:
        JSONResponse whose body is ``{"predicted_str": <text>}``.

    Raises:
        HTTPException: 500 if no OCR function is given, if the upload cannot
            be saved, or if OCR processing fails.
    """
    if not ocr_function:
        raise HTTPException(status_code=500, detail="OCR function not specified")

    tmp_path = None
    try:
        # Save uploaded image temporarily
        tmp_path = await save_upload_file_tmp(image)
        # Process the image with the specified OCR function
        result = ocr_function(tmp_path)
        return JSONResponse(content={"predicted_str": _result_to_text(result)})
    except HTTPException:
        # Already a well-formed HTTP error (e.g. from saving the upload);
        # pass it through instead of re-wrapping it as an OCR failure.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"OCR processing error: {str(e)}") from e
    finally:
        # Remove the temporary file on both the success and failure paths.
        if tmp_path is not None:
            try:
                os.unlink(tmp_path)
            except OSError:
                # Best-effort cleanup; never mask the response or the error.
                pass
# Endpoints with minimal duplication | |
async def extract_text(
    image: UploadFile = File(...),
    model_type: Literal["dev_number", "roman_number", "dev_letter", "roman_letter"] = "roman_letter",
):
    """
    Generic OCR endpoint that can handle any supported recognition type.
    - **image**: Image file to process
    - **model_type**: Type of OCR to perform
    """
    # Dispatch table: model-type name -> recognizer function from utils.
    recognizers = {
        "dev_number": dev_number,
        "roman_number": roman_number,
        "dev_letter": dev_letter,
        "roman_letter": roman_letter,
    }
    selected = recognizers.get(model_type)
    # Guard for callers that bypass the Literal annotation.
    if selected is None:
        raise HTTPException(status_code=400, detail=f"Invalid model type: {model_type}")
    response = await process_ocr_request(image, selected)
    return response
# For backward compatibility, keep the original endpoints | |
async def extract_dev_number(image: UploadFile = File(...)):
    """Extract Devanagari numbers from an image"""
    # Thin wrapper: run the shared OCR pipeline with the dev_number recognizer.
    response = await process_ocr_request(image=image, ocr_function=dev_number)
    return response
async def extract_roman_number(image: UploadFile = File(...)):
    """Extract Roman numbers from an image"""
    # Thin wrapper: run the shared OCR pipeline with the roman_number recognizer.
    response = await process_ocr_request(image=image, ocr_function=roman_number)
    return response
async def extract_dev_letter(image: UploadFile = File(...)):
    """Extract Devanagari letters from an image"""
    # Thin wrapper: run the shared OCR pipeline with the dev_letter recognizer.
    response = await process_ocr_request(image=image, ocr_function=dev_letter)
    return response
async def extract_roman_letter(image: UploadFile = File(...)):
    """Extract Roman letters from an image"""
    # Thin wrapper: run the shared OCR pipeline with the roman_letter recognizer.
    response = await process_ocr_request(image=image, ocr_function=roman_letter)
    return response
async def classify_ne(image: UploadFile = File(...)):
    """Predict Named Entities from an image"""
    # NOTE(review): the commented-out model path below mentions a
    # "nepali_english_classifier"; confirm whether this endpoint is really
    # named-entity prediction or a Nepali/English classifier.
    image_path = await save_upload_file_tmp(image)
    try:
        prediction = predict_ne(
            image_path=image_path,
            # model="models/nepali_english_classifier.pth",  # Update with actual model path
            device="cpu",  # inference device is hard-coded to CPU
        )
    finally:
        # The upload helper creates the file with delete=False, so it must be
        # removed here or every request leaves a file behind.
        try:
            os.unlink(image_path)
        except OSError:
            # Best-effort cleanup; do not mask the prediction or its error.
            pass
    return JSONResponse(content={"predicted": prediction})
async def ocr_citizenship(image: UploadFile = File(...)):
    """OCR the provided Nepali Citizenship card"""
    image_path = await save_upload_file_tmp(image)
    try:
        prediction = perform_citizenship_ocr(
            image_path=image_path,
        )
    finally:
        # The upload helper creates the file with delete=False, so it must be
        # removed here or every request leaves a file behind.
        try:
            os.unlink(image_path)
        except OSError:
            # Best-effort cleanup; do not mask the OCR result or its error.
            pass
    # The OCR result is returned as the JSON body as-is.
    return JSONResponse(content=prediction)
# Health check endpoint | |
async def health_check():
    """Health check endpoint to verify the API is running"""
    # Fixed payload; no dependencies are probed.
    payload = {"status": "healthy"}
    return payload