rodrigomasini's picture
Update main.py
08b4034 verified
raw
history blame
15.2 kB
import os
import time
import tempfile
import traceback
import logging # Added
from pathlib import Path
from typing import List, Union, Optional, Dict, Any, Literal
import uuid # For request IDs
from fastapi import FastAPI, File, UploadFile, HTTPException, status, Request # Added Request
from pydantic import BaseModel, Field
# PIL.Image is not directly used in this file after refactoring,
# but mdr_pdf_parser might use it, so keep if necessary for that.
# --- IMPORTS ---
from mdr_pdf_parser import (
MagicPDFProcessor,
MDRStructuredBlock, # Assuming this is the base type for the others
MDRTextBlock,
MDRTableBlock,
MDRFormulaBlock,
MDRFigureBlock,
MDRTextKind, # Used by MDRTextBlockModel
MDRTableFormat, # Used by MDRTableBlockModel
MDRRectangle, # Used by MDRRectangleModel
MDRTextSpan, # Used by MDRTextSpanModel
MDRExtractedTableFormat # For configuration
)
# --- Logging Configuration ---
def resolve_log_level(name, default=logging.INFO):
    """Translate a level name such as ``"debug"`` into a numeric logging level.

    Args:
        name: Level name; surrounding whitespace and letter case are ignored.
        default: Value returned when ``name`` does not name a logging level.

    Returns:
        The numeric level for ``name``, or ``default`` for unknown names.
    """
    candidate = getattr(logging, str(name).strip().upper(), None)
    # A bare getattr() also matches non-level attributes such as
    # logging.BASIC_FORMAT (a str), which basicConfig() would reject.
    return candidate if isinstance(candidate, int) else default


_requested_log_level = os.environ.get("LOG_LEVEL", "INFO")
_resolved_log_level = resolve_log_level(_requested_log_level, default=None)
# Compare against None explicitly: NOTSET is 0 and therefore falsy.
LOG_LEVEL = logging.INFO if _resolved_log_level is None else _resolved_log_level
# Canonical name of the level actually in effect, so that the level that is
# reported and the level that is configured cannot disagree.
LOG_LEVEL_STR = logging.getLevelName(LOG_LEVEL)
logging.basicConfig(
    level=LOG_LEVEL,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
)
logger = logging.getLogger("mdr_fastapi_service")
if _resolved_log_level is None:
    logger.warning(f"Invalid LOG_LEVEL '{_requested_log_level}'. Defaulting to INFO.")
# --- Configuration ---
# Every setting is taken from the environment; the second argument to
# os.getenv() is the fallback used when the variable is unset.
MODEL_DIR = os.getenv("MDR_MODEL_DIR", "/models")
DEVICE = os.getenv("MDR_DEVICE", "cuda")
TABLE_FORMAT_STR = os.getenv("MDR_TABLE_FORMAT", "MARKDOWN")
# Optional directory handed to the processor for its debug output (None when unset).
DEBUG_DIR_PATH_STR = os.getenv("MDR_DEBUG_DIR_PATH")

# Look the table format up by its upper-cased member name; an unknown name
# is reported and replaced by DISABLE.
try:
    TABLE_FORMAT = MDRExtractedTableFormat[TABLE_FORMAT_STR.upper()]
except KeyError:
    logger.warning(f"Invalid MDR_TABLE_FORMAT '{TABLE_FORMAT_STR}'. Defaulting to DISABLE.")
    TABLE_FORMAT = MDRExtractedTableFormat.DISABLE

# --- Global Processor Variable ---
# Assigned by the startup handler; remains None when initialization fails.
mdr_processor: Optional[MagicPDFProcessor] = None
# --- API Models (Pydantic) ---
# Pydantic models that mirror the internal MDR block types in API responses.
class MDRPointModel(BaseModel):
    """A 2D point expressed as a pair of float coordinates."""

    x: float
    y: float
class MDRRectangleModel(BaseModel):
    """Four-point shape with one ``MDRPointModel`` per ``lt``/``rt``/``lb``/``rb`` field."""

    lt: MDRPointModel
    rt: MDRPointModel
    lb: MDRPointModel
    rb: MDRPointModel

    @classmethod
    def from_mdr_rectangle(cls, rect: MDRRectangle):
        """Build the API model from an internal ``MDRRectangle``.

        Each of ``rect.lt``, ``rect.rt``, ``rect.lb`` and ``rect.rb`` is read
        by index: element 0 becomes ``x`` and element 1 becomes ``y``.
        """
        points = {
            corner: MDRPointModel(x=getattr(rect, corner)[0], y=getattr(rect, corner)[1])
            for corner in ("lt", "rt", "lb", "rb")
        }
        return cls(**points)
class MDRTextSpanModel(BaseModel):
    """API representation of one text span: its content, rank and rectangle."""

    content: str
    rank: float
    rect: MDRRectangleModel

    @classmethod
    def from_mdr_text_span(cls, span: MDRTextSpan):
        """Copy ``content`` and ``rank`` from ``span`` and convert its ``rect``."""
        payload = {"content": span.content, "rank": span.rank}
        payload["rect"] = MDRRectangleModel.from_mdr_rectangle(span.rect)
        return cls(**payload)
class MDRBasicBlockModel(BaseModel):
    """Fields shared by every block model returned by the API."""

    # Subclasses narrow this to a Literal naming the concrete block kind.
    block_type: str
    rect: MDRRectangleModel
    # Each instance gets its own empty list when no spans are supplied.
    texts: List[MDRTextSpanModel] = Field(default_factory=list)
    font_size: float
class MDRTextBlockModel(MDRBasicBlockModel):
    """API model for a text block."""

    block_type: Literal["TextBlock"] = "TextBlock"
    kind: str
    has_paragraph_indentation: bool
    last_line_touch_end: bool
    # Declared again without the default that the base class provides.
    texts: List[MDRTextSpanModel]

    @classmethod
    def from_mdr_text_block(cls, block: MDRTextBlock):
        """Convert an internal ``MDRTextBlock``; ``kind`` is stored as the enum member's name."""
        bounds = MDRRectangleModel.from_mdr_rectangle(block.rect)
        spans = list(map(MDRTextSpanModel.from_mdr_text_span, block.texts))
        return cls(
            rect=bounds,
            texts=spans,
            font_size=block.font_size,
            kind=block.kind.name,
            has_paragraph_indentation=block.has_paragraph_indentation,
            last_line_touch_end=block.last_line_touch_end,
        )
class MDRTableBlockModel(MDRBasicBlockModel):
    """API model for a table block, carrying the table content and its format name."""

    block_type: Literal["TableBlock"] = "TableBlock"
    content: str
    format: str

    @classmethod
    def from_mdr_table_block(cls, block: MDRTableBlock):
        """Convert an internal ``MDRTableBlock``; ``format`` is stored as the enum member's name."""
        bounds = MDRRectangleModel.from_mdr_rectangle(block.rect)
        span_models = []
        for span in block.texts:
            span_models.append(MDRTextSpanModel.from_mdr_text_span(span))
        return cls(
            rect=bounds,
            texts=span_models,
            font_size=block.font_size,
            content=block.content,
            format=block.format.name,
        )
class MDRFormulaBlockModel(MDRBasicBlockModel):
    """API model for a formula block; ``content`` may be absent."""

    block_type: Literal["FormulaBlock"] = "FormulaBlock"
    content: Optional[str] = None

    @classmethod
    def from_mdr_formula_block(cls, block: MDRFormulaBlock):
        """Convert an internal ``MDRFormulaBlock`` into its API model."""
        fields = {
            "rect": MDRRectangleModel.from_mdr_rectangle(block.rect),
            "texts": [MDRTextSpanModel.from_mdr_text_span(item) for item in block.texts],
            "font_size": block.font_size,
            "content": block.content,
        }
        return cls(**fields)
class MDRFigureBlockModel(MDRBasicBlockModel):
    """API model for a figure block; it adds no fields beyond the shared ones."""

    block_type: Literal["FigureBlock"] = "FigureBlock"

    @classmethod
    def from_mdr_figure_block(cls, block: MDRFigureBlock):
        """Convert an internal ``MDRFigureBlock`` into its API model."""
        bounds = MDRRectangleModel.from_mdr_rectangle(block.rect)
        spans = list(MDRTextSpanModel.from_mdr_text_span(entry) for entry in block.texts)
        return cls(rect=bounds, texts=spans, font_size=block.font_size)
# Any one of the concrete block models; used as the element type of the
# /process-pdf/ response.
MDRStructuredBlockModelAPI = Union[
    MDRTextBlockModel,
    MDRTableBlockModel,
    MDRFormulaBlockModel,
    MDRFigureBlockModel,
]

# --- FastAPI App ---
app = FastAPI(
    version="1.0.0",
    title="MagicDataReadiness PDF Processor",
    description="API service to extract structured content from PDF files.",
)
# --- Helper Functions ---
def _convert_block_to_api_model(block: MDRStructuredBlock) -> Optional[MDRStructuredBlockModelAPI]:
    """Map an internal MDR block onto the matching API model.

    Returns ``None`` (after logging a warning) when ``block`` is not an
    instance of any of the known block types.
    """
    # Checked in this order; the first matching type wins.
    converters = (
        (MDRTextBlock, MDRTextBlockModel.from_mdr_text_block),
        (MDRTableBlock, MDRTableBlockModel.from_mdr_table_block),
        (MDRFormulaBlock, MDRFormulaBlockModel.from_mdr_formula_block),
        (MDRFigureBlock, MDRFigureBlockModel.from_mdr_figure_block),
    )
    for internal_type, convert in converters:
        if isinstance(block, internal_type):
            return convert(block)
    logger.warning(f"Unknown block type encountered: {type(block)}. Skipping conversion.")
    return None
# --- Application Lifecycle Events ---
@app.on_event("startup")
async def startup_event():
    """Log the active configuration and create the global ``MagicPDFProcessor``.

    A failed initialization is logged at CRITICAL level (with traceback) and
    leaves ``mdr_processor`` as ``None``; it does not abort application startup.
    """
    global mdr_processor
    logger.info("Application startup sequence initiated.")
    logger.info("--- Configuration ---")
    logger.info(f"  MDR_MODEL_DIR: {MODEL_DIR}")
    logger.info(f"  MDR_DEVICE: {DEVICE}")
    logger.info(f"  MDR_TABLE_FORMAT: {TABLE_FORMAT.name}")
    logger.info(f"  MDR_DEBUG_DIR_PATH: {DEBUG_DIR_PATH_STR if DEBUG_DIR_PATH_STR else 'Not set'}")
    logger.info(f"  LOG_LEVEL: {LOG_LEVEL_STR}")
    logger.info("---------------------")
    logger.info("Initializing MagicPDFProcessor...")
    # perf_counter() is monotonic, so the measured duration cannot be skewed
    # by wall-clock adjustments the way time.time() differences can.
    init_start_time = time.perf_counter()
    try:
        mdr_processor = MagicPDFProcessor(
            device=DEVICE,
            model_dir_path=MODEL_DIR,
            extract_table_format=TABLE_FORMAT,
            debug_dir_path=DEBUG_DIR_PATH_STR,  # Pass the actual path or None
        )
    except Exception as e:
        # exc_info=True already records the traceback; mdr_processor stays
        # None and the endpoints answer with 503 in that state.
        logger.critical(f"Failed to initialize MagicPDFProcessor: {e}", exc_info=True)
    else:
        init_duration = time.perf_counter() - init_start_time
        logger.info(f"MagicPDFProcessor initialized successfully ({init_duration:.2f}s)")
@app.on_event("startup")  # Separate event to check after initialization attempt
async def startup_event_check():
    """Report whether the processor is available once startup has attempted to create it."""
    if mdr_processor is not None:
        logger.info("MagicDataReadiness Service is ready and processor is available.")
        return
    # The service is deliberately allowed to start in a degraded state:
    # /health and /process-pdf will fail while the processor is missing.
    # Raising RuntimeError here instead would stop FastAPI from starting.
    logger.error("MagicPDFProcessor is not initialized. Service cannot function correctly.")
# --- API Endpoints ---
@app.get("/health", summary="Health Check")
async def health_check():
    """Simple health check endpoint."""
    # Healthy only while the global processor exists; otherwise answer 503.
    if mdr_processor is not None:
        return {"status": "ok", "message": "MagicPDFProcessor is running."}
    logger.warning("/health endpoint called but processor is not initialized.")
    raise HTTPException(
        status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
        detail="Processor not initialized",
    )
# Size of the pieces in which an upload is copied to disk.
_UPLOAD_CHUNK_SIZE = 1024 * 1024


def _remove_temp_file(temp_path: Path, request_id: str) -> None:
    """Delete ``temp_path`` if it exists, logging the outcome instead of raising."""
    if temp_path.exists():
        try:
            os.remove(temp_path)
            logger.info(f"RID-{request_id}: Cleaned up temporary file: {temp_path}")
        except OSError as e:
            logger.warning(f"RID-{request_id}: Could not remove temporary file {temp_path}: {e}")
    else:
        logger.info(f"RID-{request_id}: Temporary file {temp_path} not found for cleanup (may have failed to save).")


async def _save_upload_to_temp_file(file: UploadFile, request_id: str) -> Path:
    """Copy ``file`` chunk by chunk into a uniquely named temporary PDF and return its path.

    If anything fails after the temporary file has been created, the partial
    file is removed before the exception is re-raised.
    """
    temp_dir = Path("./temp_uploads")  # Consider making this configurable
    temp_dir.mkdir(parents=True, exist_ok=True)
    temp_path: Optional[Path] = None
    try:
        # delete=False: the processor opens the file by path after this handle is closed.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf", dir=temp_dir, prefix=f"req_{request_id}_") as temp_file:
            temp_path = Path(temp_file.name)
            while True:
                chunk = await file.read(_UPLOAD_CHUNK_SIZE)
                if not chunk:
                    break
                temp_file.write(chunk)
    except Exception:
        if temp_path is not None:
            _remove_temp_file(temp_path, request_id)
        raise
    return temp_path


@app.post("/process-pdf/",
          response_model=List[MDRStructuredBlockModelAPI],
          summary="Process a PDF file",
          description="Upload a PDF file to extract structured blocks (text, tables, figures, formulas).")
async def process_pdf_endpoint(request: Request, file: UploadFile = File(..., description="The PDF file to process.")):
    """
    Handles PDF file upload, processing, and returns extracted blocks.

    The upload is written to a temporary file, passed to the global processor
    by path, and every resulting block is converted to its API model. The
    temporary file is removed whether or not processing succeeds.

    Raises:
        HTTPException: 503 when the processor is not initialized, 400 when the
            file name does not end in ``.pdf``, 500 when saving or processing fails.
    """
    request_id = str(uuid.uuid4())
    client_host = request.client.host if request.client else "unknown"
    logger.info(f"RID-{request_id}: Received /process-pdf request from {client_host} for file: '{file.filename}' (type: {file.content_type}, size: {file.size})")
    if mdr_processor is None:
        logger.error(f"RID-{request_id}: Processor not initialized. Cannot process request.")
        raise HTTPException(status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Processor not initialized")
    if not file.filename or not file.filename.lower().endswith(".pdf"):
        logger.warning(f"RID-{request_id}: Invalid file type uploaded: '{file.filename}'")
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file type. Please upload a PDF.")

    # Save uploaded file temporarily (perf_counter() is monotonic, unlike time.time()).
    save_start_time = time.perf_counter()
    try:
        temp_pdf_path_obj = await _save_upload_to_temp_file(file, request_id)
    except Exception as e:
        logger.error(f"RID-{request_id}: Failed to save uploaded file '{file.filename}': {e}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Failed to save uploaded file: {e}") from e
    save_duration = time.perf_counter() - save_start_time
    logger.info(f"RID-{request_id}: File '{file.filename}' saved temporarily to '{temp_pdf_path_obj}' ({save_duration:.2f}s)")

    extracted_blocks_api: List[MDRStructuredBlockModelAPI] = []
    processing_start_time = time.perf_counter()
    try:
        logger.info(f"RID-{request_id}: Starting PDF processing for '{temp_pdf_path_obj}'...")

        # Define a progress callback for the processor
        def log_progress(completed_pages: int, total_pages: int):
            if total_pages > 0:
                percent_done = (completed_pages / total_pages) * 100
                logger.info(f"RID-{request_id}: Processor progress - Page {completed_pages}/{total_pages} ({percent_done:.1f}%) scanned.")
            else:
                logger.info(f"RID-{request_id}: Processor progress - Page {completed_pages} scanned (total pages unknown or zero).")

        # NOTE(review): this call is synchronous, so the event loop is blocked
        # until the document is done. Moving it to a worker thread would let
        # requests overlap - confirm the processor is thread-safe first.
        all_blocks_internal = list(mdr_processor.process_document(
            pdf_input=str(temp_pdf_path_obj),
            report_progress=log_progress,  # Pass the progress logger
        ))
        collection_duration = time.perf_counter() - processing_start_time
        logger.info(f"RID-{request_id}: Extracted {len(all_blocks_internal)} raw blocks from processor ({collection_duration:.2f}s).")

        conversion_start_time = time.perf_counter()
        for i, block_internal in enumerate(all_blocks_internal):
            logger.debug(f"RID-{request_id}: Converting internal block {i+1}/{len(all_blocks_internal)} of type {type(block_internal)} to API model.")
            api_block = _convert_block_to_api_model(block_internal)
            # None marks an unknown block type, which is skipped.
            if api_block is not None:
                extracted_blocks_api.append(api_block)
        conversion_duration = time.perf_counter() - conversion_start_time
        logger.info(f"RID-{request_id}: Converted {len(extracted_blocks_api)} blocks to API models ({conversion_duration:.2f}s).")

        total_processing_duration = time.perf_counter() - processing_start_time
        logger.info(f"RID-{request_id}: PDF processing finished in {total_processing_duration:.2f}s. Returning {len(extracted_blocks_api)} blocks.")
    except Exception as e:
        logger.error(f"RID-{request_id}: Error during PDF processing for '{temp_pdf_path_obj}': {e}", exc_info=True)
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"An error occurred during PDF processing: {e}") from e
    finally:
        # Clean up the temporary file
        _remove_temp_file(temp_pdf_path_obj, request_id)
    return extracted_blocks_api
@app.get("/", summary="Root Endpoint")
async def read_root():
    """Provides basic information about the API."""
    # Snapshot of the module-level settings this process was started with.
    active_configuration = {
        "model_directory": MODEL_DIR,
        "target_device": DEVICE,  # This is the configured device, actual might differ if fallback
        "table_format": TABLE_FORMAT.name,
        "log_level": LOG_LEVEL_STR,
        "processor_debug_output": DEBUG_DIR_PATH_STR or "Disabled",
    }
    info = {
        "message": "Welcome to the MagicDataReadiness PDF Processor API!",
        "docs_url": "/docs",  # FastAPI default
        "redoc_url": "/redoc",  # FastAPI default
        "health_url": "/health",
    }
    info["active_configuration"] = active_configuration
    return info
# --- Main execution for local testing (optional) ---
if __name__ == "__main__":
    # This block is for running the app directly with uvicorn for local development.
    # It's not strictly necessary if you always run with `uvicorn main:app`.
    import uvicorn

    logger.info("Starting Uvicorn server for local development...")
    # Pass the numeric level that logging was configured with. A lower-cased
    # LOG_LEVEL_STR is not always a name uvicorn knows (e.g. "warn", "fatal",
    # or an invalid value that fell back to INFO), which would abort startup.
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level=LOG_LEVEL)