import json
import pickle
from typing import List

import cv2
import numpy as np
import regex as re  # the regex module (not re) supports fuzzy matching, e.g. {e<=6}
import requests
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from doctr.io import DocumentFile
from doctr.models import detection_predictor, recognition_predictor
from fastapi import HTTPException
from PIL import Image
from surya.detection import DetectionPredictor
from surya.recognition import RecognitionPredictor
from torchvision import models

from data_models import Citizenship

# Character sets
CHARACTER_NUM = "0123456789-"
CHARACTER_LETTER = ''' "()-./0123456789:?ABCDEFGHIKLMNOPQRSTUWYabcdefghijklmnoprstuvwyँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑॓॔क़ख़ग़ज़ड़ढ़फ़य़ॠॢ।॥०१२३४५६७८९॰ॱॲॻॼॽॾ^'''

# Model paths - these should be configurable
MODEL_PATHS = {
    'dev_digits': "models/devnagri_digits_20k_v2.pth",
    'roman_digits': "models/roman_digits_20k_v5.pth",
    'dev_letter': "models/small_devnagari_letter.pth",
    'classify_ne': "models/nepali_english_classifier.pth"
}

# Use GPU if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ResNetClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super().__init__()
        self.base_model = models.resnet50(weights='IMAGENET1K_V2')  # Pre-trained ResNet-50
        for param in self.base_model.parameters():
            param.requires_grad = False  # Freeze base model
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        return self.base_model(x)
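

# Sanity-check sketch for the frozen backbone above (hypothetical helper, not
# part of the serving path): with the ResNet-50 body frozen, only the two
# Linear layers of the new head should be trainable, i.e.
# 2048*128 + 128 + 128*2 + 2 = 262,530 parameters for the default num_classes=2.
def _count_trainable_params(model: nn.Module) -> int:
    """Count parameters that would receive gradients during fine-tuning."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)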


# Define the CRNN model
class CRNN(nn.Module):
    def __init__(self, num_classes, input_size=(1, 64, 256)):
        super().__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(input_size[0], 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 32x128
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 16x64
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),   # 8x32
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)    # 4x16
        )
        # Dimensions after conv: batch x 512 x 4 x 16
        feature_height = input_size[1] // 16  # four 2x2 pools: 64 / 2**4 = 4
        self.rnn = nn.LSTM(
            input_size=512 * feature_height,  # 512 * 4 = 2048
            hidden_size=128,
            num_layers=1,
            bidirectional=True,
            dropout=0.3,  # NOTE: has no effect with num_layers=1 (PyTorch warns)
            batch_first=True
        )
        self.fc = nn.Linear(256, num_classes)  # 256 = 2 * 128 for bidirectional

    def forward(self, x):
        x = self.conv_block(x)                # (B, 512, H=4, W=16)
        b, c, h, w = x.size()
        x = x.permute(0, 3, 1, 2)             # (B, W, C, H)
        x = x.contiguous().view(b, w, c * h)  # (B, seq_len, input_size)
        x, _ = self.rnn(x)                    # (B, seq_len, 256)
        x = self.fc(x)                        # (B, seq_len, num_classes)
        return x
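

# Shape sketch for the CRNN above (hypothetical helper, not called anywhere in
# the pipeline): a 1x64x256 crop passes four 2x2 poolings, leaving a 512x4x16
# feature map, so the LSTM sees 16 timesteps of 512*4 = 2048 features each.
def _crnn_shape_check(num_classes=len(CHARACTER_NUM)):
    model = CRNN(num_classes=num_classes)
    dummy = torch.zeros(1, 1, 64, 256)         # (batch, channels, height, width)
    out = model(dummy)
    assert out.shape == (1, 16, num_classes)   # (batch, seq_len, num_classes)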


class OCRModelManager:
    """
    Singleton class to manage OCR models and prevent repeated loading
    """
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance.models = {}
            cls._instance.char_maps = {}
            cls._instance.transforms = {}
            cls._instance.initialize_transforms()
            # Initialize the doctr recognition model once
            cls._instance.roman_letter_model = recognition_predictor(pretrained=True)
        return cls._instance

    def initialize_transforms(self):
        """Initialize standard transforms used across models"""
        self.transforms['standard'] = transforms.Compose([
            transforms.Resize((64, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def get_model(self, model_type, character_set):
        """Get or load a model based on type"""
        if model_type not in self.models:
            if model_type not in MODEL_PATHS:
                raise ValueError(f"Unknown model type: {model_type}")
            # Create character <-> ID mappings
            self.char_maps[model_type] = {
                'id_to_char': {i: c for i, c in enumerate(character_set)},
                'char_to_id': {c: i for i, c in enumerate(character_set)}
            }
            # Initialize and load model
            model = CRNN(num_classes=len(character_set))
            model.load_state_dict(torch.load(MODEL_PATHS[model_type], map_location=DEVICE))
            model.eval()  # Set to evaluation mode
            model = model.to(DEVICE)
            self.models[model_type] = model
        return self.models[model_type], self.char_maps[model_type]

    def preprocess_image(self, image_path, model_type):
        """Preprocess image based on model type"""
        image = Image.open(image_path).convert('L')
        # Apply specific preprocessing based on model type
        if model_type != 'dev_letter':
            # Binarize the image for digit models
            image = image.point(lambda x: 0 if x < 128 else 255, 'L')
        # Resize to model input size (width x height)
        image = image.resize((256, 64))
        # Invert colors for the dev_letter model
        if model_type == 'dev_letter':
            image = Image.eval(image, lambda x: 255 - x)
        # Apply transforms
        tensor_image = self.transforms['standard'](image).unsqueeze(0).to(DEVICE)
        return tensor_image

    def predict(self, image_path, model_type, character_set):
        """Make a prediction using the specified model"""
        # Get or load model
        model, char_map = self.get_model(model_type, character_set)
        # Preprocess image
        tensor_image = self.preprocess_image(image_path, model_type)
        # Run inference
        with torch.no_grad():
            output = model(tensor_image)
            output = output.permute(1, 0, 2)     # (seq_len, batch_size, num_classes)
            _, predicted = output.max(2)
            predicted = predicted.permute(1, 0)  # (batch_size, seq_len)
        # Greedy decode: map the argmax id at each timestep straight back to a
        # character (no CTC blank/repeat collapsing is applied here)
        predicted_str = ''.join([char_map['id_to_char'][i] for i in predicted[0].cpu().numpy()])
        return predicted_str

    def predict_roman_letter(self, image_path):
        """Predict using the doctr model for Roman letters"""
        img = DocumentFile.from_images(image_path)
        result = self.roman_letter_model(img)
        # result is a list of (text, confidence) pairs, one per input crop
        return result[0][0]


# Initialize the model manager as a singleton
ocr_manager = OCRModelManager()
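
# The singleton contract above can be checked directly:
#   assert OCRModelManager() is ocr_manager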


# Simplified API functions
def dev_number(image_path):
    """Recognize Devanagari digits in an image"""
    return ocr_manager.predict(image_path, 'dev_digits', CHARACTER_NUM)


def roman_number(image_path):
    """Recognize Roman digits in an image"""
    return ocr_manager.predict(image_path, 'roman_digits', CHARACTER_NUM)


def dev_letter(image_path):
    """Recognize Devanagari letters in an image"""
    return ocr_manager.predict(image_path, 'dev_letter', CHARACTER_LETTER)


def roman_letter(image_path):
    """Recognize Roman letters in an image"""
    return ocr_manager.predict_roman_letter(image_path)
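

# Hedged usage sketch for the wrappers above (hypothetical crop paths; the
# weights in MODEL_PATHS must be present on disk). Each helper takes the path
# of one cropped field and returns the recognised string; note the digit
# models emit ASCII digits because CHARACTER_NUM is ASCII:
#
#   dev_number("crops/citizenship_no_dev.png")   # -> e.g. "12-34-56"
#   roman_letter("crops/name_roman.png")         # -> e.g. "KATHMANDU"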


def predict_ne(image_path, device=None):
    """Classify which script/language a cropped field is written in."""
    # Honor an explicit device argument; otherwise fall back to auto-detection
    device = torch.device(device) if device else DEVICE
    model = ResNetClassifier(num_classes=4).to(device)
    transform = transforms.Compose([
        transforms.Resize(256),       # Resize shorter side to 256
        transforms.CenterCrop(224),   # Crop center 224x224 patch
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)
    # Load model weights/state_dict
    model.load_state_dict(torch.load('models/dev_roman_classifier.pth', map_location=device))
    model.eval()
    # Load the label encoder
    with open('models/dev_roman_label_encoder.pkl', 'rb') as f:
        le = pickle.load(f)
    with torch.no_grad():
        output = model(image_tensor)
        _, predicted = torch.max(output, 1)
    return le.inverse_transform([predicted.item()])[0]


doctr_detector = None
surya_recognition_predictor = None
surya_detection_predictor = None


def initialize_detector():
    """Lazily create the doctr and surya predictors exactly once."""
    global doctr_detector, surya_recognition_predictor, surya_detection_predictor
    if doctr_detector is None:
        doctr_detector = detection_predictor(
            'db_mobilenet_v3_large', pretrained=True,
            assume_straight_pages=True, preserve_aspect_ratio=True
        )
    if surya_recognition_predictor is None:
        surya_recognition_predictor = RecognitionPredictor()
    if surya_detection_predictor is None:
        surya_detection_predictor = DetectionPredictor()
    return doctr_detector, surya_recognition_predictor, surya_detection_predictor


def get_cleaned_boxes(out, page):
    """Scale normalized detection boxes to pixels and filter out unwanted ones."""
    h, w, _ = page.shape
    x_thresh = 0.7 * w
    y_thresh = 0.3 * h
    cleaned_boxes = []
    for box in out[0]['words']:
        coords = np.array(box[:4])  # normalized [xmin, ymin, xmax, ymax]
        coords *= np.array([w, h, w, h])
        x1, y1, x2, y2 = coords
        # Skip boxes in the top-right region of the page
        if x1 > x_thresh and y1 < y_thresh:
            continue
        # Skip tiny boxes (area below 100 px^2)
        if (x2 - x1) * (y2 - y1) < 100:
            continue
        cleaned_boxes.append(coords.astype('int'))
    return cleaned_boxes


def merge_boxes_same_line(boxes, y_thresh=5, x_thresh=60):
    """Align detected boxes into rows, then merge horizontally adjacent boxes
    on the same text line. Unoptimized, but the box counts here are small."""
    if len(boxes) == 0:
        return np.array([])
    # Sort boxes by y first, then by x
    boxes = sorted(boxes, key=lambda b: (b[1], b[0]))
    # Give all boxes within a threshold the same y coordinates so that boxes
    # on one text line sort as a single row
    row_threshold = 15
    aligned_boxes = []
    current_row = []
    current_y = boxes[0][1]
    for box in boxes:
        x1, y1, x2, y2 = box
        if abs(y1 - current_y) <= row_threshold:
            current_row.append(box)
        else:
            # Align all y1 and y2 in the row
            avg_y1 = int(np.mean([b[1] for b in current_row]))
            avg_y2 = int(np.mean([b[3] for b in current_row]))
            aligned_boxes.extend([(b[0], avg_y1, b[2], avg_y2) for b in current_row])
            current_row = [box]
            current_y = y1
    # Handle the last row
    if current_row:
        avg_y1 = int(np.mean([b[1] for b in current_row]))
        avg_y2 = int(np.mean([b[3] for b in current_row]))
        aligned_boxes.extend([(b[0], avg_y1, b[2], avg_y2) for b in current_row])
    # After aligning all boxes on the y axis, re-sort them
    aligned_boxes = sorted(aligned_boxes, key=lambda b: (b[1], b[0]))
    # Merge adjacent boxes within the given thresholds
    merged = []
    p_x1, p_y1, p_x2, p_y2 = aligned_boxes[0]
    for i in range(1, len(aligned_boxes)):
        x1, y1, x2, y2 = aligned_boxes[i]
        if abs(p_y1 - y1) < y_thresh and abs(x1 - p_x2) < x_thresh:
            p_x1 = min(p_x1, x1)
            p_y1 = min(p_y1, y1)
            p_x2 = max(p_x2, x2)
            p_y2 = max(p_y2, y2)
        else:
            merged.append([p_x1, p_y1, p_x2, p_y2])
            p_x1, p_y1, p_x2, p_y2 = x1, y1, x2, y2
    merged.append([p_x1, p_y1, p_x2, p_y2])
    return np.array(merged)
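

# Toy illustration of merge_boxes_same_line (hypothetical coordinates): two
# boxes on the same text line, 5 px apart horizontally, first have their y
# coordinates averaged onto a common row and are then merged into one box:
#
#   merge_boxes_same_line([(10, 100, 60, 120), (65, 102, 120, 121)])
#   # -> array([[ 10, 101, 120, 120]])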


def ocr_citizenship(image_path: str) -> List[List[str]]:
    """OCR a citizenship certificate page into lines of text fragments."""
    doctr_detector, surya_recognition_predictor, surya_detection_predictor = initialize_detector()
    page = cv2.imread(image_path)
    page = cv2.convertScaleAbs(page, alpha=1.5, beta=0)  # boost contrast
    page = cv2.resize(page, (720, 480))
    out = doctr_detector([page])
    cleaned_boxes = get_cleaned_boxes(out, page)
    merged = merge_boxes_same_line(cleaned_boxes)
    # Fuzzy match (up to 6 edits) for the certificate title
    # "नेपाली नागरिकताको प्रमाणपत्र" (Nepali Citizenship Certificate)
    pattern = r'(नेपाली\s*नागरिकताको\s*प्रमाणपत्र){e<=6}'
    prev_y = 0
    start = False
    first_start = True
    y_thresh = 5
    full_result = []
    line_result = []
    for boxes in merged[3:]:  # skip the first three merged boxes
        x1, y1, x2, y2 = boxes
        crop = page[y1:y2, x1:x2]
        pil_image = Image.fromarray(crop).convert('L')
        # OCR PART
        langs = ["en", "ne"]
        predictions = surya_recognition_predictor(
            images=[pil_image], langs=[langs], det_predictor=surya_detection_predictor
        )
        text_combo = ''
        for text_line in predictions[0].text_lines:
            text_combo = text_combo + " " + text_line.text.strip()
        text_combo = text_combo.strip()
        # OCR PART END
        # Ignore everything until the certificate title has been seen
        if not start:
            match = re.search(pattern, text_combo)
            if match:
                start = True
            continue
        if first_start:
            first_start = False
            prev_y = boxes[1]
        # A jump in y starts a new output line
        if y1 - prev_y > y_thresh:
            full_result.append(line_result)
            line_result = []
        line_result.append(text_combo)
        prev_y = boxes[1]
    # Don't drop the trailing line
    if line_result:
        full_result.append(line_result)
    return full_result


PARSE_PROMPT = "You are a parsing agent. Your task is to generate a json response from the given text corpus."


def create_local_model(message, base_model):
    """POST the chat messages to the remote Ollama proxy and parse the JSON reply."""
    try:
        ollama_endpoint = "api/chat"
        url = f"https://aioverlords-amnil-internal-ollama.hf.space/proxy/{ollama_endpoint}"
        # Data to send in the POST request; the pydantic schema constrains the output format
        data = {
            "data": {
                "model": "aisingapore/Llama-SEA-LION-v3-8B-IT",
                "messages": message,
                "stream": False,
                "format": base_model.model_json_schema()
            }
        }
        response = requests.post(url, json=data)
        # Check the response
        if response.status_code == 200:
            print("Request Success:", response.json())
            return json.loads(response.json()["message"]["content"])
        else:
            print("Request Error:", response.status_code, response.text)
            raise HTTPException(status_code=response.status_code, detail=response.text)
    except HTTPException as http_exc:
        raise http_exc
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def perform_citizenship_ocr(image_path):
    """Run the citizenship OCR pipeline and parse the result into a Citizenship JSON."""
    try:
        unparsed_result = ocr_citizenship(image_path)
        message = [
            {"role": "system", "content": PARSE_PROMPT},
            {"role": "user", "content": f"Given Text: \n{unparsed_result}"},
        ]
        return create_local_model(message, Citizenship)
    except HTTPException:
        # Preserve upstream HTTP errors instead of masking them as 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
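

if __name__ == "__main__":
    # Hedged end-to-end smoke test (hypothetical image path): requires the
    # local model weights, the doctr/surya model downloads, and network access
    # to the Ollama proxy used by create_local_model.
    parsed = perform_citizenship_ocr("samples/citizenship_front.jpg")
    print(json.dumps(parsed, ensure_ascii=False, indent=2))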