# OCR-SMALL / utils.py
# Commit a1c0d1f by AnkitShrestha: "Add internal ollama parsing to citizenship ocr"
from doctr.models import detection_predictor, recognition_predictor
from doctr.io import DocumentFile
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor
from PIL import Image
# from functools import lru_cache
from torchvision import models
from typing import List
from fastapi import HTTPException
from data_models import Citizenship
import json
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import numpy as np
import cv2
import regex as re
import requests
# import os
import pickle
# Character sets
# Each string is the label vocabulary of one CRNN head: a character's index in
# the string is its class id, and len(string) is used as num_classes when the
# model is built (see OCRModelManager.get_model).
CHARACTER_NUM = "0123456789-"
# Note the leading space (class id 0) and the double quote that follows it.
# NOTE(review): the exact code points and their order presumably have to match
# the vocabulary the dev_letter checkpoint was trained with — confirm before editing.
CHARACTER_LETTER = ''' "()-./0123456789:?ABCDEFGHIKLMNOPQRSTUWYabcdefghijklmnoprstuvwyँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑॓॔क़ख़ग़ज़ड़ढ़फ़य़ॠॢ।॥०१२३४५६७८९॰ॱॲॻॼॽॾ^'''
# Checkpoint files keyed by the model_type names OCRModelManager.get_model accepts.
# Paths are relative, so they resolve against the process working directory.
# TODO: make these configurable instead of hard-coding them.
# NOTE(review): 'classify_ne' is not referenced elsewhere in this file;
# predict_ne loads its own checkpoint path.
MODEL_PATHS = dict(
    dev_digits="models/devnagri_digits_20k_v2.pth",
    roman_digits="models/roman_digits_20k_v5.pth",
    dev_letter="models/small_devnagari_letter.pth",
    classify_ne="models/nepali_english_classifier.pth",
)

# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
class ResNetClassifier(nn.Module):
    """ResNet-50 backbone with a small trainable classification head.

    All backbone parameters are frozen (``requires_grad=False``). The original
    ``fc`` layer is then replaced by a fresh ``Linear -> ReLU -> Linear`` head,
    which stays trainable.

    Args:
        num_classes: Number of output logits.
        pretrained: If True (the default, same as before), the backbone starts
            from torchvision's ``IMAGENET1K_V2`` weights, which torchvision may
            have to download. Pass False when a complete fine-tuned
            ``state_dict`` will be loaded on top anyway; this skips that download.
    """

    def __init__(self, num_classes=2, pretrained=True):
        super().__init__()
        weights = 'IMAGENET1K_V2' if pretrained else None
        self.base_model = models.resnet50(weights=weights)
        # Freeze the backbone; the head created below is new and trainable.
        for param in self.base_model.parameters():
            param.requires_grad = False
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes),
        )

    def forward(self, x):
        """Return the class logits produced by the backbone and head."""
        return self.base_model(x)
# Define the CRNN model
class CRNN(nn.Module):
    """Convolutional-recurrent recogniser that outputs logits per image column.

    There are four Conv/BatchNorm/ReLU/MaxPool stages, each halving height and
    width (16x downsampling overall). Each remaining column becomes one time
    step of a single-layer bidirectional LSTM, followed by a linear classifier.

    Args:
        num_classes: Number of logits produced per time step.
        input_size: ``(channels, height, width)`` of the input images.
            ``channels`` and ``height`` fix the layer shapes; the width only
            determines the sequence length (``width // 16``).
    """

    def __init__(self, num_classes, input_size=(1, 64, 256)):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(input_size[0], 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 64x256 -> 32x128 (default input)
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> 16x64
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> 8x32
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # -> 4x16
        )
        # After the conv stack: (batch, 512, height // 16, width // 16).
        feature_height = input_size[1] // 16
        self.rnn = nn.LSTM(
            input_size=512 * feature_height,  # 2048 for the default 64-px height
            hidden_size=128,
            num_layers=1,
            bidirectional=True,
            # nn.LSTM applies dropout only *between* stacked layers. With a single
            # layer the previous 0.3 had no effect and only made PyTorch warn.
            dropout=0.0,
            batch_first=True,
        )
        self.fc = nn.Linear(256, num_classes)  # 2 directions x 128 hidden units

    def forward(self, x):
        """Map ``(B, C, H, W)`` images to ``(B, W // 16, num_classes)`` logits."""
        x = self.conv_block(x)  # (B, 512, H // 16, W // 16)
        b, c, h, w = x.size()
        x = x.permute(0, 3, 1, 2)  # (B, W', C, H'): one time step per column
        x = x.contiguous().view(b, w, c * h)  # (B, seq_len, 512 * H')
        x, _ = self.rnn(x)  # (B, seq_len, 256)
        return self.fc(x)  # (B, seq_len, num_classes)
class OCRModelManager:
    """
    Singleton class to manage OCR models and prevent repeated loading.

    Every ``OCRModelManager()`` call returns the same instance. Each CRNN model
    is loaded lazily on its first use and cached in ``self.models`` under its
    ``model_type``; the matching character maps go in ``self.char_maps``. The
    doctr recognition model for Roman letters is created eagerly when the
    singleton is first constructed.
    """

    _instance = None

    def __new__(cls):
        if cls._instance is None:
            # Build the instance completely before publishing it on the class.
            # If creating the doctr model raises, the next call retries instead
            # of returning a half-initialised object with no roman_letter_model.
            instance = super().__new__(cls)
            instance.models = {}
            instance.char_maps = {}
            instance.transforms = {}
            instance.initialize_transforms()
            # Initialize doctr model once
            instance.roman_letter_model = recognition_predictor(pretrained=True)
            cls._instance = instance
        return cls._instance

    def initialize_transforms(self):
        """Initialize the standard transform shared by the CRNN models.

        ``'standard'`` resizes to 64x256 (height x width), converts the image to
        a tensor, and maps pixel values from [0, 1] to [-1, 1].
        """
        self.transforms['standard'] = transforms.Compose([
            transforms.Resize((64, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ])

    def get_model(self, model_type, character_set):
        """Return the CRNN for ``model_type``, loading it on first use.

        The first call loads the checkpoint at ``MODEL_PATHS[model_type]`` into
        a ``CRNN`` with ``len(character_set)`` classes. Later calls return the
        cached model. The cache is keyed by ``model_type`` only, so
        ``character_set`` is ignored once that model has been loaded.

        Returns:
            ``(model, char_map)``, where ``char_map`` holds the ``'id_to_char'``
            and ``'char_to_id'`` dicts.

        Raises:
            ValueError: If ``model_type`` is not a key of ``MODEL_PATHS``.
        """
        if model_type not in self.models:
            if model_type not in MODEL_PATHS:
                raise ValueError(f"Unknown model type: {model_type}")
            # Create character to ID mapping
            self.char_maps[model_type] = {
                'id_to_char': dict(enumerate(character_set)),
                'char_to_id': {c: i for i, c in enumerate(character_set)},
            }
            model = CRNN(num_classes=len(character_set))
            model.load_state_dict(torch.load(MODEL_PATHS[model_type], map_location=DEVICE))
            model.eval()  # inference mode: BatchNorm uses its running statistics
            self.models[model_type] = model.to(DEVICE)
        return self.models[model_type], self.char_maps[model_type]

    def preprocess_image(self, image_path, model_type):
        """Load an image and convert it to a ``(1, 1, 64, 256)`` tensor on ``DEVICE``.

        The image is converted to grayscale. Digit models (any model type other
        than ``'dev_letter'``) get a hard threshold at 128. For
        ``'dev_letter'`` the image is colour-inverted instead.
        """
        # Context manager so the file handle is released once the pixels are read.
        with Image.open(image_path) as source:
            image = source.convert('L')
        if model_type != 'dev_letter':
            # Binarize the image for digit models
            image = image.point(lambda x: 0 if x < 128 else 255, 'L')
        # Resize to model input size (PIL takes (width, height))
        image = image.resize((256, 64))
        if model_type == 'dev_letter':
            # Invert colors for dev_letter model
            image = Image.eval(image, lambda x: 255 - x)
        return self.transforms['standard'](image).unsqueeze(0).to(DEVICE)

    def predict(self, image_path, model_type, character_set):
        """Recognise the text in ``image_path`` with the CRNN for ``model_type``.

        Decoding is greedy. The top-scoring class at each time step is mapped
        through ``id_to_char``, and every step's character is concatenated,
        with no de-duplication or blank removal.
        """
        model, char_map = self.get_model(model_type, character_set)
        tensor_image = self.preprocess_image(image_path, model_type)
        with torch.no_grad():
            output = model(tensor_image)  # (batch, seq_len, num_classes)
        # Best class per time step; reducing the class axis leaves (batch, seq_len).
        _, predicted = output.max(2)
        id_to_char = char_map['id_to_char']
        return ''.join(id_to_char[i] for i in predicted[0].cpu().tolist())

    def predict_roman_letter(self, image_path):
        """Predict using the doctr model for Roman letters.

        Returns ``result[0][0]`` from the doctr recognition predictor.
        NOTE(review): this is presumably the text of the first
        (text, confidence) pair; confirm against the installed doctr version.
        """
        img = DocumentFile.from_images(image_path)
        result = self.roman_letter_model(img)
        return result[0][0]
# Initialize the model manager as a singleton.
# This runs at import time, and the first OCRModelManager() call also builds the
# doctr recognition model, so importing this module pays that cost up front.
ocr_manager = OCRModelManager()
# Simplified API functions: thin wrappers around the shared ``ocr_manager``.
def dev_number(image_path):
    """Return the Devanagari digit string read from *image_path*.

    Uses the ``'dev_digits'`` CRNN with the ``CHARACTER_NUM`` vocabulary.
    """
    return ocr_manager.predict(
        image_path,
        model_type='dev_digits',
        character_set=CHARACTER_NUM,
    )
def roman_number(image_path):
    """Return the Roman (Latin) digit string read from *image_path*.

    Uses the ``'roman_digits'`` CRNN with the ``CHARACTER_NUM`` vocabulary.
    """
    return ocr_manager.predict(
        image_path,
        model_type='roman_digits',
        character_set=CHARACTER_NUM,
    )
def dev_letter(image_path):
    """Return the Devanagari text read from *image_path*.

    Uses the ``'dev_letter'`` CRNN with the ``CHARACTER_LETTER`` vocabulary.
    """
    return ocr_manager.predict(
        image_path,
        model_type='dev_letter',
        character_set=CHARACTER_LETTER,
    )
def roman_letter(image_path):
    """Return the Roman-letter text read from *image_path* by the doctr recognizer."""
    manager = ocr_manager
    return manager.predict_roman_letter(image_path)
# (model, label_encoder) pairs keyed by device string. With this cache,
# predict_ne builds the ResNet-50 and reads the weights and the pickle once per
# process instead of on every call.
_NE_CLASSIFIER_CACHE = {}

# Preprocessing for the classifier: shorter side to 256, centre 224x224 crop,
# then per-channel normalisation with the ImageNet mean/std values.
_NE_TRANSFORM = transforms.Compose([
    transforms.Resize(256),  # Resize shorter side to 256
    transforms.CenterCrop(224),  # Crop center 224x224 patch
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def _load_ne_classifier(device):
    """Return the cached ``(model, label_encoder)`` for *device*, loading them on first use."""
    key = str(device)
    if key not in _NE_CLASSIFIER_CACHE:
        # ResNetClassifier fetches its ImageNet backbone weights here (once per
        # process now); the fine-tuned state dict below then overwrites them.
        model = ResNetClassifier(num_classes=4).to(device)
        # loading model weights/state_dict
        model.load_state_dict(torch.load('models/dev_roman_classifier.pth', map_location=device))
        model.eval()
        # SECURITY: pickle.load can execute arbitrary code. This path must only
        # ever point at a trusted, locally shipped file.
        with open('models/dev_roman_label_encoder.pkl', 'rb') as f:
            label_encoder = pickle.load(f)
        _NE_CLASSIFIER_CACHE[key] = (model, label_encoder)
    return _NE_CLASSIFIER_CACHE[key]


def predict_ne(image_path, device="cpu"):
    """Classify the image at *image_path* with the ``dev_roman`` ResNet classifier.

    Runs a 4-class ``ResNetClassifier`` and maps the arg-max class back to a
    label with the pickled label encoder. NOTE(review): the encoder is assumed
    to be a fitted scikit-learn-style encoder; only ``inverse_transform`` is used.

    Args:
        image_path: Path of the image to classify.
        device: Ignored, as it always has been. The model runs on CUDA when
            available, otherwise on the CPU. Kept so existing callers still work.

    Returns:
        The decoded label for the predicted class.
    """
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, label_encoder = _load_ne_classifier(device)
    with Image.open(image_path) as img:
        image = img.convert('RGB')
    image_tensor = _NE_TRANSFORM(image).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(image_tensor)
        _, predicted = torch.max(output, 1)
    return label_encoder.inverse_transform([predicted.item()])[0]
# Lazily created predictors; initialize_detector() fills these module-level slots.
doctr_detector = surya_recognition_predictor = surya_detection_predictor = None
def initialize_detector():
    """Create any missing shared predictors and return all three.

    A predictor is constructed only if its module-level slot is still ``None``,
    so repeated calls reuse the same objects.

    Returns:
        ``(doctr_detector, surya_recognition_predictor, surya_detection_predictor)``.
    """
    global doctr_detector, surya_recognition_predictor, surya_detection_predictor

    if doctr_detector is None:
        doctr_detector = detection_predictor(
            'db_mobilenet_v3_large',
            pretrained=True,
            assume_straight_pages=True,
            preserve_aspect_ratio=True,
        )

    if surya_recognition_predictor is None:
        surya_recognition_predictor = RecognitionPredictor()

    if surya_detection_predictor is None:
        surya_detection_predictor = DetectionPredictor()

    detectors = (doctr_detector, surya_recognition_predictor, surya_detection_predictor)
    return detectors
def get_cleaned_boxes(out, page, *, x_frac=0.7, y_frac=0.3, min_area=100):
    """Scale doctr word boxes to pixel coordinates and drop unwanted ones.

    Args:
        out: doctr detection output. Each entry of ``out[0]['words']`` is read,
            and its first four values are used as normalised
            ``(x1, y1, x2, y2)``. NOTE(review): assumes doctr's
            (xmin, ymin, xmax, ymax, score) row layout.
        page: Image array the detector ran on. Only its height and width are
            used, so both 2-D (grayscale) and 3-D (colour) arrays work.
        x_frac, y_frac: A box is dropped when its top-left corner lies right of
            ``x_frac * width`` *and* above ``y_frac * height`` (the page's
            top-right region).
        min_area: Boxes with a pixel area below this are dropped as noise.

    Returns:
        List of integer ``np.ndarray`` boxes ``[x1, y1, x2, y2]`` in pixels.
    """
    h, w = page.shape[:2]
    scale = np.array([w, h, w, h])
    # The thresholds depend only on the page size, so compute them once.
    x_limit = x_frac * w
    y_limit = y_frac * h

    cleaned_boxes = []
    for box in out[0]['words']:
        coords = np.array(box[:4])  # copy: don't scale the detector output in place
        coords *= scale
        x1, y1, x2, y2 = coords
        if x1 > x_limit and y1 < y_limit:
            continue
        if (x2 - x1) * (y2 - y1) < min_area:
            continue
        cleaned_boxes.append(coords.astype('int'))
    return cleaned_boxes
# Group word boxes into text rows, then merge horizontally adjacent boxes in each row.
def _align_box_rows(boxes, row_thresh):
    """Snap each row of boxes to that row's mean top and bottom edge.

    Boxes are visited in (y1, x1) order. A box joins the current row while its
    y1 is within ``row_thresh`` px of the y1 of the row's first box.
    """
    ordered = sorted(boxes, key=lambda b: (b[1], b[0]))
    aligned = []

    def close_row(row):
        avg_y1 = int(np.mean([b[1] for b in row]))
        avg_y2 = int(np.mean([b[3] for b in row]))
        aligned.extend((b[0], avg_y1, b[2], avg_y2) for b in row)

    row = []
    row_y = ordered[0][1]
    for box in ordered:
        if abs(box[1] - row_y) <= row_thresh:
            row.append(box)
        else:
            close_row(row)
            row = [box]
            row_y = box[1]
    close_row(row)  # always non-empty: it holds at least the last box seen
    return aligned


def _merge_adjacent_boxes(aligned, y_thresh, x_thresh):
    """Merge consecutive (y1, x1)-sorted boxes that share a row and are close horizontally."""
    merged = []
    p_x1, p_y1, p_x2, p_y2 = aligned[0]
    for x1, y1, x2, y2 in aligned[1:]:
        # NOTE(review): because of abs(), a box that overlaps the current one by
        # x_thresh px or more starts a new box instead of being merged.
        if abs(p_y1 - y1) < y_thresh and abs(x1 - p_x2) < x_thresh:
            p_x1, p_y1 = min(p_x1, x1), min(p_y1, y1)
            p_x2, p_y2 = max(p_x2, x2), max(p_y2, y2)
        else:
            merged.append([p_x1, p_y1, p_x2, p_y2])
            p_x1, p_y1, p_x2, p_y2 = x1, y1, x2, y2
    merged.append([p_x1, p_y1, p_x2, p_y2])
    return merged


def merge_boxes_same_line(boxes, y_thresh=5, x_thresh=60, row_thresh=15):
    """Combine word boxes on the same text line into line boxes.

    Args:
        boxes: Sequence of ``[x1, y1, x2, y2]`` pixel boxes (e.g. from
            ``get_cleaned_boxes``).
        y_thresh: Two boxes merge only if their top edges (after row
            alignment) differ by less than this.
        x_thresh: Two boxes merge only if ``|x1 - previous x2|`` is less than this.
        row_thresh: A box is aligned into a row if its top edge is within this
            many px of the row's first box.

    Returns:
        ``np.ndarray`` of shape ``(n, 4)``, with merged boxes in the (y1, x1)
        order they were traversed. The shape is ``(0, 4)`` when *boxes* is empty.
    """
    if len(boxes) == 0:
        return np.empty((0, 4), dtype=int)
    # Boxes in one row now share y1, so re-sorting orders each row left to right.
    aligned = sorted(_align_box_rows(boxes, row_thresh), key=lambda b: (b[1], b[0]))
    return np.array(_merge_adjacent_boxes(aligned, y_thresh, x_thresh))
def _ocr_crop_text(crop, recognizer, detector, langs):
    """Run surya recognition on one RGB crop and return its lines joined by single spaces."""
    pil_image = Image.fromarray(crop).convert('L')
    predictions = recognizer(images=[pil_image], langs=[langs], det_predictor=detector)
    return " ".join(line.text.strip() for line in predictions[0].text_lines).strip()


def ocr_citizenship(image_path: str) -> List[List[str]]:
    """OCR a citizenship certificate image and group the text into lines.

    Pipeline:
    1. Boost the contrast (alpha=1.5) and resize the image to 720x480.
    2. Detect word boxes with doctr, then clean them and merge them into line boxes.
    3. Read each box with surya.
    4. Ignore boxes until one fuzzily matches the title
       "नेपाली नागरिकताको प्रमाणपत्र" (up to 6 edits). The title box itself is
       also skipped.
    5. Collect the text of every later box. A new line starts when a box's top
       edge is more than 5 px below the previous box's.

    Returns:
        One list of box texts per detected line.

    Raises:
        ValueError: If OpenCV cannot read *image_path*.
    """
    doctr_detector, surya_recognition_predictor, surya_detection_predictor = initialize_detector()

    page = cv2.imread(image_path)
    if page is None:
        # cv2.imread returns None, rather than raising, for missing/unreadable files.
        raise ValueError(f"Could not read image: {image_path}")
    page = cv2.convertScaleAbs(page, alpha=1.5, beta=0)
    page = cv2.resize(page, (720, 480))
    # OpenCV decodes to BGR, but doctr and PIL's Image.fromarray both treat
    # 3-channel arrays as RGB, so convert once here.
    page = cv2.cvtColor(page, cv2.COLOR_BGR2RGB)

    out = doctr_detector([page])
    merged = merge_boxes_same_line(get_cleaned_boxes(out, page))

    pattern = r'(नेपाली\s*नागरिकताको\s*प्रमाणपत्र){e<=6}'
    title_re = re.compile(pattern)
    langs = ["en", 'ne']
    y_thresh = 5

    started = False
    prev_y = None
    full_result = []
    line_result = []
    # NOTE(review): the first three merged boxes are always skipped, presumably
    # fixed header lines above the title. Confirm this holds for all layouts.
    for x1, y1, x2, y2 in merged[3:]:
        text = _ocr_crop_text(page[y1:y2, x1:x2], surya_recognition_predictor,
                              surya_detection_predictor, langs)
        if not started:
            # Skip everything up to and including the title box.
            if title_re.search(text):
                started = True
            continue
        if prev_y is None:
            prev_y = y1
        if y1 - prev_y > y_thresh:
            full_result.append(line_result)
            line_result = []
        line_result.append(text)
        prev_y = y1
    # A line is only flushed when the next one begins, so flush the final line
    # here; otherwise the last row of the certificate would be lost.
    if line_result:
        full_result.append(line_result)
    return full_result
# System prompt for the LLM that turns raw OCR lines into schema-shaped JSON.
PARSE_PROMPT = (
    "You are a parsing agent. "
    "Your task is to generate a json response from the given text corpus."
)
def create_local_model(message, base_model, *, model_name="aisingapore/Llama-SEA-LION-v3-8B-IT",
                       timeout=300):
    """Send a chat request to the internal Ollama proxy and return its JSON reply.

    The request asks for a non-streamed reply whose ``format`` is
    ``base_model.model_json_schema()``, so the output is constrained to JSON
    matching that schema.

    Args:
        message: Chat messages in the ``messages`` shape of Ollama's
            ``api/chat``, e.g. a list of ``{"role": ..., "content": ...}`` dicts.
        base_model: Class providing ``model_json_schema()``; presumably a
            Pydantic v2 model such as ``Citizenship``.
        model_name: Model tag the proxy should run (keyword-only).
        timeout: Seconds ``requests.post`` waits for the proxy before giving up
            (keyword-only), so a stalled proxy cannot block the caller forever.

    Returns:
        The value decoded with ``json.loads`` from the reply's
        ``message.content``. It is not validated against *base_model*.

    Raises:
        HTTPException: With the proxy's status code for a non-200 reply, or
            with status 500 for any other failure (network error, timeout,
            unexpected response shape, invalid JSON content).
    """
    import logging  # local import keeps this block self-contained

    logger = logging.getLogger(__name__)
    try:
        ollama_endpoint = "api/chat"
        url = f"https://aioverlords-amnil-internal-ollama.hf.space/proxy/{ollama_endpoint}"
        # Data to send in the POST request
        data = {
            "data": {
                "model": model_name,
                "messages": message,
                "stream": False,
                "format": base_model.model_json_schema(),
            }
        }
        response = requests.post(url, json=data, timeout=timeout)
        if response.status_code != 200:
            logger.error("Request Error: %s %s", response.status_code, response.text)
            raise HTTPException(status_code=response.status_code, detail=response.text)
        # Log only the status: the reply body holds the parsed personal data.
        logger.debug("Request Success: status %s", response.status_code)
        return json.loads(response.json()["message"]["content"])
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
def perform_citizenship_ocr(image_path):
    """OCR a citizenship certificate image and parse the text against the ``Citizenship`` schema.

    Returns:
        The JSON value returned by ``create_local_model`` for the ``Citizenship`` schema.

    Raises:
        HTTPException: Propagated unchanged (keeping its status code) when
            raised by the OCR or parsing step; any other error is wrapped as
            status 500.
    """
    try:
        unparsed_result = ocr_citizenship(image_path)
        message = [
            {"role": "system", "content": PARSE_PROMPT},
            {"role": "user", "content": f"Given Text: \n{unparsed_result}"},
        ]
        return create_local_model(message, Citizenship)
    except HTTPException:
        # Already carries the right status code/detail; re-wrapping it as a 500
        # would hide the upstream error.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e