from doctr.models import detection_predictor, recognition_predictor
from doctr.io import DocumentFile
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor
from PIL import Image
from torchvision import models
from typing import List
from fastapi import HTTPException
from data_models import Citizenship
import json
import torchvision.transforms as transforms
import torch
import torch.nn as nn
import numpy as np
import cv2
import regex as re
import requests
import pickle

# Character sets
CHARACTER_NUM = "0123456789-"
CHARACTER_LETTER = ''' "()-./0123456789:?ABCDEFGHIKLMNOPQRSTUWYabcdefghijklmnoprstuvwyँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑॓॔क़ख़ग़ज़ड़ढ़फ़य़ॠॢ।॥०१२३४५६७८९॰ॱॲॻॼॽॾ^'''

# Model paths - these should be configurable
MODEL_PATHS = {
    'dev_digits': "models/devnagri_digits_20k_v2.pth",
    'roman_digits': "models/roman_digits_20k_v5.pth",
    'dev_letter': "models/small_devnagari_letter.pth",
    'classify_ne': "models/nepali_english_classifier.pth"
}

# Use GPU if available
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class ResNetClassifier(nn.Module):
    def __init__(self, num_classes=2):
        super(ResNetClassifier, self).__init__()
        self.base_model = models.resnet50(weights='IMAGENET1K_V2')  # Pre-trained ResNet-50
        for param in self.base_model.parameters():
            param.requires_grad = False  # Freeze base model
        num_ftrs = self.base_model.fc.in_features
        self.base_model.fc = nn.Sequential(
            nn.Linear(num_ftrs, 128),
            nn.ReLU(),
            nn.Linear(128, num_classes)
        )

    def forward(self, x):
        return self.base_model(x)


# Define the CRNN model
class CRNN(nn.Module):
    def __init__(self, num_classes, input_size=(1, 64, 256)):
        super(CRNN, self).__init__()
        self.conv_block = nn.Sequential(
            nn.Conv2d(input_size[0], 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32x128
            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16x64
            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 8x32
            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2)  # 4x16
        )
        # Dimensions after conv: batch x 512 x 4 x 16
        feature_height = input_size[1] // 16  # four 2x2 pools: 64 / 2^4 = 4
        self.rnn = nn.LSTM(
            input_size=512 * feature_height,  # 512 * 4 = 2048
            hidden_size=128,
            num_layers=1,
            bidirectional=True,
            batch_first=True  # note: inter-layer dropout has no effect with num_layers=1
        )
        self.fc = nn.Linear(256, num_classes)  # 256 for bidirectional

    def forward(self, x):
        x = self.conv_block(x)                # (B, 512, H=4, W=16)
        b, c, h, w = x.size()
        x = x.permute(0, 3, 1, 2)             # (B, W, C, H)
        x = x.contiguous().view(b, w, c * h)  # (B, seq_len, input_size)
        x, _ = self.rnn(x)                    # (B, seq_len, 256)
        x = self.fc(x)                        # (B, seq_len, num_classes)
        return x
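
# Illustrative shape check (a sketch, not part of the original pipeline): a
# 64x256 grayscale crop passes four 2x2 max-pools, so the LSTM sees
# 256/16 = 16 timesteps of 512*4 = 2048 features, and the head emits
# per-timestep class logits.
def _crnn_shape_demo():
    model = CRNN(num_classes=len(CHARACTER_NUM))
    dummy = torch.zeros(1, 1, 64, 256)  # (batch, channels, height, width)
    logits = model(dummy)               # -> torch.Size([1, 16, 11])
    assert logits.shape == (1, 16, len(CHARACTER_NUM))
    return logits.shape
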
class OCRModelManager:
    """
    Singleton class to manage OCR models and prevent repeated loading
    """
    _instance = None

    def __new__(cls):
        if cls._instance is None:
            cls._instance = super(OCRModelManager, cls).__new__(cls)
            cls._instance.models = {}
            cls._instance.char_maps = {}
            cls._instance.transforms = {}
            cls._instance.initialize_transforms()
            # Initialize doctr model once
            cls._instance.roman_letter_model = recognition_predictor(pretrained=True)
        return cls._instance

    def initialize_transforms(self):
        """Initialize standard transforms used across models"""
        self.transforms['standard'] = transforms.Compose([
            transforms.Resize((64, 256)),
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,))
        ])

    def get_model(self, model_type, character_set):
        """Get or load a model based on type"""
        if model_type not in self.models:
            if model_type not in MODEL_PATHS:
                raise ValueError(f"Unknown model type: {model_type}")

            # Create character <-> ID mappings
            self.char_maps[model_type] = {
                'id_to_char': {i: c for i, c in enumerate(character_set)},
                'char_to_id': {c: i for i, c in enumerate(character_set)}
            }

            # Initialize and load model
            model = CRNN(num_classes=len(character_set))
            model.load_state_dict(torch.load(MODEL_PATHS[model_type], map_location=DEVICE))
            model.eval()  # Set to evaluation mode
            model = model.to(DEVICE)
            self.models[model_type] = model

        return self.models[model_type], self.char_maps[model_type]

    def preprocess_image(self, image_path, model_type):
        """Preprocess image based on model type"""
        image = Image.open(image_path).convert('L')

        # Apply specific preprocessing based on model type:
        # binarize the image for the digit models
        if model_type != 'dev_letter':
            image = image.point(lambda x: 0 if x < 128 else 255, 'L')

        # Resize to model input size
        image = image.resize((256, 64))

        # Invert colors for the dev_letter model
        if model_type == 'dev_letter':
            image = Image.eval(image, lambda x: 255 - x)

        # Apply transforms
        tensor_image = self.transforms['standard'](image).unsqueeze(0).to(DEVICE)
        return tensor_image

    def predict(self, image_path, model_type, character_set):
        """Make a prediction using the specified model"""
        # Get or load model
        model, char_map = self.get_model(model_type, character_set)

        # Preprocess image
        tensor_image = self.preprocess_image(image_path, model_type)

        # Run inference
        with torch.no_grad():
            output = model(tensor_image)
            output = output.permute(1, 0, 2)  # (seq_len, batch_size, num_classes)
            _, predicted = output.max(2)
            predicted = predicted.permute(1, 0)  # (batch_size, seq_len)

        # Convert token IDs to a string
        predicted_str = ''.join([char_map['id_to_char'][i] for i in predicted[0].cpu().numpy()])
        return predicted_str

    def predict_roman_letter(self, image_path):
        """Predict using the doctr model for Roman letters"""
        img = DocumentFile.from_images(image_path)
        result = self.roman_letter_model(img)
        return result[0][0]


# Initialize the model manager as a singleton
ocr_manager = OCRModelManager()


# Simplified API functions
def dev_number(image_path):
    """Recognize Devanagari digits in an image"""
    return ocr_manager.predict(image_path, 'dev_digits', CHARACTER_NUM)


def roman_number(image_path):
    """Recognize Roman digits in an image"""
    return ocr_manager.predict(image_path, 'roman_digits', CHARACTER_NUM)


def dev_letter(image_path):
    """Recognize Devanagari letters in an image"""
    return ocr_manager.predict(image_path, 'dev_letter', CHARACTER_LETTER)


def roman_letter(image_path):
    """Recognize Roman letters in an image"""
    return ocr_manager.predict_roman_letter(image_path)
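
# Minimal usage sketch for the wrappers above (the crop paths are
# hypothetical; in the real flow the crops come from the detector below):
def _field_ocr_demo(number_crop="crops/number.png", name_crop="crops/name_en.png"):
    return {
        "number_dev": dev_number(number_crop),   # Devanagari digit string
        "name_roman": roman_letter(name_crop),   # Latin-script text via doctr
    }
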
def predict_ne(image_path, device=None):
    """Classify a crop by script, using the pickled label encoder for names."""
    # Honor an explicit device argument; otherwise auto-select
    device = torch.device(device) if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = ResNetClassifier(num_classes=4).to(device)

    transform = transforms.Compose([
        transforms.Resize(256),        # Resize shorter side to 256
        transforms.CenterCrop(224),    # Crop center 224x224 patch
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225])
    ])
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Load model weights/state_dict
    model.load_state_dict(torch.load('models/dev_roman_classifier.pth', map_location=device))
    model.eval()

    # Load label encoder
    with open('models/dev_roman_label_encoder.pkl', 'rb') as f:
        le = pickle.load(f)

    with torch.no_grad():
        output = model(image_tensor)
        _, predicted = torch.max(output, 1)
    return le.inverse_transform([predicted.item()])[0]


doctr_detector = None
surya_recognition_predictor = None
surya_detection_predictor = None


def initialize_detector():
    """Lazily create the doctr detector and surya predictors once."""
    global doctr_detector, surya_recognition_predictor, surya_detection_predictor
    if doctr_detector is None:
        doctr_detector = detection_predictor('db_mobilenet_v3_large',
                                             pretrained=True,
                                             assume_straight_pages=True,
                                             preserve_aspect_ratio=True)
    if surya_recognition_predictor is None:
        surya_recognition_predictor = RecognitionPredictor()
    if surya_detection_predictor is None:
        surya_detection_predictor = DetectionPredictor()
    return doctr_detector, surya_recognition_predictor, surya_detection_predictor


def get_cleaned_boxes(out, page):
    """Scale normalized doctr boxes to pixels; drop photo-region and tiny boxes."""
    h, w, _ = page.shape
    cleaned_boxes = []
    for box in out[0]['words']:
        coords = np.array(box[:4])  # (xmin, ymin, xmax, ymax), normalized
        coords *= np.array([w, h, w, h])
        x1, y1, x2, y2 = coords
        # Skip boxes in the top-right corner (photo region)
        x_thresh = 0.7 * w
        y_thresh = 0.3 * h
        if x1 > x_thresh and y1 < y_thresh:
            continue
        # Skip very small boxes
        if (x2 - x1) * (y2 - y1) < 100:
            continue
        cleaned_boxes.append(coords.astype('int'))
    return cleaned_boxes


# The most inefficient code in existence
def merge_boxes_same_line(boxes, y_thresh=5, x_thresh=60):
    # Sort boxes first by y and then by x
    boxes = sorted(boxes, key=lambda b: (b[1], b[0]))

    # Try to give all boxes within a threshold the same y coordinates
    # so they sort cleanly into rows
    row_threshold = 15
    aligned_boxes = []
    current_row = []
    current_y = boxes[0][1]
    for box in boxes:
        x1, y1, x2, y2 = box
        if abs(y1 - current_y) <= row_threshold:
            current_row.append(box)
        else:
            # Align all y1 and y2 in the row
            avg_y1 = int(np.mean([b[1] for b in current_row]))
            avg_y2 = int(np.mean([b[3] for b in current_row]))
            aligned_boxes.extend([(b[0], avg_y1, b[2], avg_y2) for b in current_row])
            current_row = [box]
            current_y = y1

    # Handle the last row
    if current_row:
        avg_y1 = int(np.mean([b[1] for b in current_row]))
        avg_y2 = int(np.mean([b[3] for b in current_row]))
        aligned_boxes.extend([(b[0], avg_y1, b[2], avg_y2) for b in current_row])

    # After aligning all boxes on the y axis, re-sort them
    aligned_boxes = sorted(aligned_boxes, key=lambda b: (b[1], b[0]))

    # Merge adjacent boxes within the given thresholds
    merged = []
    p_x1, p_y1, p_x2, p_y2 = aligned_boxes[0]
    for i in range(1, len(aligned_boxes)):
        x1, y1, x2, y2 = aligned_boxes[i]
        if abs(p_y1 - y1) < y_thresh and abs(x1 - p_x2) < x_thresh:
            p_x1 = min(p_x1, x1)
            p_y1 = min(p_y1, y1)
            p_x2 = max(p_x2, x2)
            p_y2 = max(p_y2, y2)
        else:
            merged.append([p_x1, p_y1, p_x2, p_y2])
            p_x1, p_y1, p_x2, p_y2 = x1, y1, x2, y2
    merged.append([p_x1, p_y1, p_x2, p_y2])
    return np.array(merged)
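
# Sketch of the detect -> clean -> merge stages that ocr_citizenship (below)
# runs before recognition; mirrors its preprocessing for a given image path:
def _detect_lines_demo(image_path):
    detector, _, _ = initialize_detector()
    page = cv2.imread(image_path)
    page = cv2.convertScaleAbs(page, alpha=1.5, beta=0)  # boost contrast
    page = cv2.resize(page, (720, 480))
    out = detector([page])
    cleaned = get_cleaned_boxes(out, page)
    return merge_boxes_same_line(cleaned)  # one merged box per text line
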
def ocr_citizenship(image_path: str) -> List[List[str]]:
    doctr_detector, surya_recognition_predictor, surya_detection_predictor = initialize_detector()

    page = cv2.imread(image_path)
    page = cv2.convertScaleAbs(page, alpha=1.5, beta=0)
    page = cv2.resize(page, (720, 480))

    out = doctr_detector([page])
    cleaned_boxes = get_cleaned_boxes(out, page)
    merged = merge_boxes_same_line(cleaned_boxes)

    # Fuzzy match (up to 6 edits) for the certificate title:
    # "नेपाली नागरिकताको प्रमाणपत्र" (Nepali Citizenship Certificate)
    pattern = r'(नेपाली\s*नागरिकताको\s*प्रमाणपत्र){e<=6}'

    prev_y = 0
    start = False
    first_start = True
    y_thresh = 5
    full_result = []
    line_result = []
    for boxes in merged[3:]:  # skip the first few boxes before the title search
        x1, y1, x2, y2 = boxes[0], boxes[1], boxes[2], boxes[3]
        crop = page[y1:y2, x1:x2]
        pil_image = Image.fromarray(crop).convert('L')

        # OCR PART
        langs = ["en", 'ne']
        predictions = surya_recognition_predictor(images=[pil_image],
                                                  langs=[langs],
                                                  det_predictor=surya_detection_predictor)
        text_combo = ''
        for text_line in predictions[0].text_lines:
            text_combo = text_combo + " " + text_line.text.strip()
        text_combo = text_combo.strip()
        # OCR PART END

        # Skip everything up to and including the certificate title
        if not start:
            match = re.search(pattern, text_combo)
            if match:
                start = True
            continue
        if first_start:
            first_start = False
            prev_y = boxes[1]
        # A jump in y starts a new output line
        if y1 - prev_y > y_thresh:
            full_result.append(line_result)
            line_result = []
        line_result.append(text_combo)
        prev_y = boxes[1]

    # Flush the last line (it would otherwise be dropped)
    if line_result:
        full_result.append(line_result)
    return full_result


PARSE_PROMPT = "You are a parsing agent. Your task is to generate a json response from the given text corpus."


def create_local_model(message, base_model):
    try:
        ollama_endpoint = "api/chat"
        url = f"https://aioverlords-amnil-internal-ollama.hf.space/proxy/{ollama_endpoint}"

        # Data to send in the POST request
        data = {
            "data": {
                "model": "aisingapore/Llama-SEA-LION-v3-8B-IT",
                "messages": message,
                "stream": False,
                "format": base_model.model_json_schema()
            }
        }
        response = requests.post(url, json=data)

        # Check the response
        if response.status_code == 200:
            print("Request Success:", response.json())
            return json.loads(response.json()["message"]["content"])
        else:
            print("Request Error:", response.status_code, response.text)
            raise HTTPException(status_code=response.status_code, detail=response.text)
    except HTTPException as http_exc:
        raise http_exc
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


def perform_citizenship_ocr(image_path):
    try:
        unparsed_result = ocr_citizenship(image_path)
        message = [
            {"role": "system", "content": PARSE_PROMPT},
            {"role": "user", "content": f"Given Text: \n{unparsed_result}"},
        ]
        return create_local_model(message, Citizenship)
    except HTTPException:
        raise  # keep the original status code instead of wrapping it in a 500
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
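
# Hedged smoke test (the sample path is hypothetical; requires the weights in
# MODEL_PATHS on disk and network access to the parsing endpoint):
if __name__ == "__main__":
    sample = "samples/citizenship_front.jpg"  # hypothetical path
    parsed = perform_citizenship_ocr(sample)
    print(json.dumps(parsed, ensure_ascii=False, indent=2))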