AnkitShrestha committed on
Commit 76e8a07 · 1 Parent(s): 415c25f

Add citizenship OCR endpoint

Files changed (3)
  1. main.py +11 -1
  2. requirements.txt +3 -1
  3. utils.py +139 -168
main.py CHANGED
@@ -71,7 +71,7 @@ from pydantic import BaseModel
 import shutil
 
 # Import from optimized utils
-from utils import dev_number, roman_number, dev_letter, roman_letter, predict_ne
+from utils import dev_number, roman_number, dev_letter, roman_letter, predict_ne, ocr_citizenship_utils
 
 app = FastAPI(
     title="OCR API",
@@ -193,6 +193,16 @@ async def classify_ne(image: UploadFile = File(...)):
 
     # Implement the logic as per your requirements
     return JSONResponse(content={"predicted": prediction})
+
+@app.post("/ocr_citizenship/")
+async def ocr_citizenship(image: UploadFile = File(...)):
+    """OCR the provided Nepali Citizenship card"""
+    image_path = await save_upload_file_tmp(image)
+    prediction = ocr_citizenship_utils(
+        image_path=image_path,
+    )
+
+    return JSONResponse(content=prediction)
 # Health check endpoint
 @app.get("/health")
 async def health_check():
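A minimal client-side sketch of how the new endpoint could be called once the app is running; the base URL, the port, and the sample file name are assumptions, not part of this commit:

# Hypothetical client call for the new /ocr_citizenship/ endpoint.
# Assumes the FastAPI app is served at http://localhost:8000 and that
# "citizenship_sample.jpg" is a local test image.
import requests

with open("citizenship_sample.jpg", "rb") as f:
    resp = requests.post(
        "http://localhost:8000/ocr_citizenship/",
        files={"image": ("citizenship_sample.jpg", f, "image/jpeg")},
    )
resp.raise_for_status()
print(resp.json())  # nested list: one list of text segments per card line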
requirements.txt CHANGED
@@ -8,4 +8,6 @@ fastapi
 uvicorn
 pydantic
 python-multipart
-scikit-learn==1.6.1
+scikit-learn==1.6.1
+opencv-python-headless
+surya-ocr
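Both new dependencies serve the citizenship OCR path: opencv-python-headless provides cv2 for the preprocessing in utils.py, and surya-ocr provides the recognition and detection predictors. A quick sanity-check sketch, assuming the environment above is installed:

# Verify the new dependencies import cleanly; cv2 comes from
# opencv-python-headless, and surya-ocr also pulls in torch.
import cv2
from surya.recognition import RecognitionPredictor
from surya.detection import DetectionPredictor

print(cv2.__version__)
# Instantiating the predictors downloads model weights on first use,
# so this check deliberately stops at the imports.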
utils.py CHANGED
@@ -1,175 +1,21 @@
-# import torch
-# import torch.nn as nn
-# from PIL import Image
-# import numpy as np
-# import matplotlib.pyplot as plt
-# import torchvision.transforms as transforms
-# from doctr.io import DocumentFile
-# from doctr.models import recognition_predictor
-
-# character_num = "0123456789-"
-# character_letter = ''' "()-./0123456789:?ABCDEFGHIKLMNOPQRSTUWYabcdefghijklmnoprstuvwyँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑॓॔क़ख़ग़ज़ड़ढ़फ़य़ॠॢ।॥०१२३४५६७८९॰ॱॲॻॼॽॾ^''' #"()-./0123456789:?ABCDEFGHIKLMNOPQRSTUWYabcdefghijklmnoprstuvwyँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑॓॔क़ख़ग़ज़ड़ढ़फ़य़ॠॢ।॥०१२३४५६७८९॰ॱॲॻॼॽॾ^"
-
-# model_dev_digits_path = "models/devnagri_digits_20k_v2.pth"
-# model_roman_digits_path = "models/roman_digits_20k_v5.pth"
-# dev_letter_path = "models/small_devnagari_letter.pth"
-
-# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# # Define the CRNN model
-# class CRNN(nn.Module):
-#     def __init__(self, num_classes, input_size=(1, 64, 256)):
-#         super(CRNN, self).__init__()
-
-#         self.conv_block = nn.Sequential(
-#             nn.Conv2d(input_size[0], 64, kernel_size=3, stride=1, padding=1),
-#             nn.BatchNorm2d(64),
-#             nn.ReLU(),
-#             nn.MaxPool2d(kernel_size=2, stride=2), # 64x128
-
-#             nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
-#             nn.BatchNorm2d(128),
-#             nn.ReLU(),
-#             nn.MaxPool2d(kernel_size=2, stride=2), # 32x64
-
-#             nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
-#             nn.BatchNorm2d(256),
-#             nn.ReLU(),
-#             nn.MaxPool2d(kernel_size=2, stride=2), # 16x32
-
-#             nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
-#             nn.BatchNorm2d(512),
-#             nn.ReLU(),
-#             nn.MaxPool2d(kernel_size=2, stride=2) # 8x16
-#         )
-
-#         # Dimensions after conv: batch x 512 x 8 x 16
-#         feature_height = input_size[1] // 16 # 64 -> 4 pools → 64/2^4 = 4
-
-#         self.rnn = nn.LSTM(
-#             input_size=512 * feature_height, # 512 * 4 = 2048
-#             hidden_size=128,
-#             num_layers=1,
-#             bidirectional=True,
-#             dropout=0.3,
-#             batch_first=True
-#         )
-
-#         self.fc = nn.Linear(256, num_classes) # 256*2 = 512
-
-#     def forward(self, x):
-#         x = self.conv_block(x) # (B, 512, H=4, W=16)
-#         b, c, h, w = x.size()
-#         x = x.permute(0, 3, 1, 2) # (B, W, C, H)
-#         x = x.contiguous().view(b, w, c * h) # (B, seq_len, input_size)
-
-#         x, _ = self.rnn(x) # (B, seq_len, 512)
-#         x = self.fc(x) # (B, seq_len, num_classes)
-#         return x
-
-# # Initialize the model
-# def model_init(character, model_path):
-#     # Initialize the model with the number of classes
-#     model = CRNN(num_classes=len(character))
-#     model.load_state_dict(torch.load(model_path, map_location=device))
-#     model = model.to(device)
-#     return model
-
-# def predict_image(image_path, character, model_path):
-#     image = Image.open(image_path).convert('L')
-
-#     # if value < 128, set to 0, else set to 255
-#     if model_path != dev_letter_path:
-#         image = image.point(lambda x: 0 if x < 128 else 255, 'L')
-#     image = image.resize((256, 64)) # Resize to match the input size of the model
-#     image = np.array(image)
-#     image = np.expand_dims(image, axis=0)[0] # Add channel dimension
-#     # to pil image
-#     # print(image)
-#     image = Image.fromarray(image).convert('L')
-
-#     if model_path == dev_letter_path:
-#         image = Image.eval(image, lambda x: 255 - x)
-
-#     # plt.imshow(image, cmap='gray')
-#     # plt.axis('off')
-#     # plt.show()
-#     transform = transforms.Compose([
-#         transforms.Resize((64, 256)),
-#         transforms.ToTensor(),
-#         transforms.Normalize((0.5,), (0.5,))
-#     ])
-#     image = transform(image).unsqueeze(0).to(device) # Add batch dimension and move to GPU
-#     # Load the model weights
-#     model = model_init(character, model_path)
-#     # token to string
-#     # tokens to ids
-#     id_to_char = {i: c for i, c in enumerate(character)}
-
-#     def get_string_from_token(token):
-#         """
-#         Convert a list of character IDs back to the corresponding string.
-#         """
-#         return ''.join([id_to_char[i] for i in token])
-
-#     with torch.no_grad():
-#         output = model(image)
-#         output = output.permute(1, 0, 2) # (seq_len, batch_size, num_classes)
-#         _, predicted = output.max(2)
-#         predicted = predicted.permute(1, 0) # (batch_size, seq_len)
-#         predicted_str = get_string_from_token(predicted[0].cpu().numpy())
-#     return predicted_str
-
-# def dev_number(image):
-#     # Load the model
-#     model_path = model_dev_digits_path
-#     character = character_num
-#     # Predict the image
-#     predicted_str = predict_image(image, character, model_path)
-#     return predicted_str
-
-# def roman_number(image):
-#     # Load the model
-#     model_path = model_roman_digits_path
-#     character = character_num
-#     # Predict the image
-#     predicted_str = predict_image(image, character, model_path)
-#     return predicted_str
-
-# def dev_letter(image):
-#     # Load the model
-#     model_path = dev_letter_path
-#     character = character_letter
-#     # Predict the image
-#     predicted_str = predict_image(image, character, model_path)
-#     return predicted_str
-
-
-# # roman_letter
-# # Load OCR model once at startup
-# model = recognition_predictor(pretrained=True)
-
-# def roman_letter(image):
-#     # Load image using doctr
-#     img = DocumentFile.from_images(image)
-#     # Perform OCR
-#     result = model(img)
-#     # Return result as JSON
-#     return result
-
-
+from doctr.models import detection_predictor, recognition_predictor
+from doctr.io import DocumentFile
+from surya.recognition import RecognitionPredictor
+from surya.detection import DetectionPredictor
+from PIL import Image
+# from functools import lru_cache
+from torchvision import models
+from typing import List
+import torchvision.transforms as transforms
 import torch
 import torch.nn as nn
-from PIL import Image
 import numpy as np
-import torchvision.transforms as transforms
-from doctr.io import DocumentFile
-from torchvision import models
-from doctr.models import recognition_predictor
-import os
-from functools import lru_cache
+import cv2
+import regex as re
+# import os
 import pickle
 
+
 # Character sets
 CHARACTER_NUM = "0123456789-"
 CHARACTER_LETTER = ''' "()-./0123456789:?ABCDEFGHIKLMNOPQRSTUWYabcdefghijklmnoprstuvwyँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑॓॔क़ख़ग़ज़ड़ढ़फ़य़ॠॢ।॥०१२३४५६७८९॰ॱॲॻॼॽॾ^''' #"()-./0123456789:?ABCDEFGHIKLMNOPQRSTUWYabcdefghijklmnoprstuvwyँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑॓॔क़ख़ग़ज़ड़ढ़फ़य़ॠॢ।॥०१२३४५६७८९॰ॱॲॻॼॽॾ^"
@@ -397,4 +243,129 @@ def predict_ne(image_path, device="cpu"):
     with torch.no_grad():
         output = model(image_tensor)
         _, predicted = torch.max(output, 1)
-    return le.inverse_transform([predicted.item()])[0]
+    return le.inverse_transform([predicted.item()])[0]
+
+# Lazily-initialized singletons: the doctr text detector plus the surya
+# recognition/detection predictors. Named *_surya so they do not shadow the
+# doctr detection_predictor/recognition_predictor imports above, and so they
+# match the call sites in ocr_citizenship_utils.
+doctr_detector = None
+recognition_predictor_surya = None
+detection_predictor_surya = None
+
+def initialize_detector():
+    global doctr_detector, recognition_predictor_surya, detection_predictor_surya
+    if doctr_detector is None:
+        doctr_detector = detection_predictor('db_mobilenet_v3_large', pretrained=True, assume_straight_pages=True, preserve_aspect_ratio=True)
+    if recognition_predictor_surya is None:
+        recognition_predictor_surya = RecognitionPredictor()
+    if detection_predictor_surya is None:
+        detection_predictor_surya = DetectionPredictor()
+    return doctr_detector, recognition_predictor_surya, detection_predictor_surya
+
+def get_cleaned_boxes(out, page):
+    """Scale doctr's normalized word boxes to pixels and drop noisy ones."""
+    h, w, _ = page.shape
+    cleaned_boxes = []
+    for box in out[0]['words']:
+        coords = np.array(box[:4])  # box corners (normalized)
+        coords *= np.array([w, h, w, h])
+        x1, y1, x2, y2 = coords
+        x_thresh = 0.7 * page.shape[1]
+        y_thresh = 0.3 * page.shape[0]
+        # Drop boxes in the top-right corner of the card
+        if x1 > x_thresh and y1 < y_thresh:
+            continue
+        # Drop boxes smaller than 100 px^2
+        if (x2 - x1) * (y2 - y1) < 100:
+            continue
+        cleaned_boxes.append(coords.astype('int'))
+    return cleaned_boxes
+
+# The most inefficient code in existence
+def merge_boxes_same_line(boxes, y_thresh=5, x_thresh=60):
+    # Sort boxes first by y and then by x
+    boxes = sorted(boxes, key=lambda b: (b[1], b[0]))
+    # Snap boxes within a threshold to the same y coordinate so that
+    # sorting groups them into rows
+    row_threshold = 15
+
+    aligned_boxes = []
+    current_row = []
+    current_y = boxes[0][1]
+
+    for box in boxes:
+        x1, y1, x2, y2 = box
+        if abs(y1 - current_y) <= row_threshold:
+            current_row.append(box)
+        else:
+            # Align all y1 and y2 in the row
+            avg_y1 = int(np.mean([b[1] for b in current_row]))
+            avg_y2 = int(np.mean([b[3] for b in current_row]))
+            aligned_boxes.extend([(b[0], avg_y1, b[2], avg_y2) for b in current_row])
+            current_row = [box]
+            current_y = y1
+
+    # Handle the last row
+    if current_row:
+        avg_y1 = int(np.mean([b[1] for b in current_row]))
+        avg_y2 = int(np.mean([b[3] for b in current_row]))
+        aligned_boxes.extend([(b[0], avg_y1, b[2], avg_y2) for b in current_row])
+    # After aligning all boxes on the y axis, re-sort them
+    aligned_boxes = sorted(aligned_boxes, key=lambda b: (b[1], b[0]))
+
+    # Merge horizontally adjacent boxes on the same row
+    merged = []
+    p_x1, p_y1, p_x2, p_y2 = aligned_boxes[0]
+    for i in range(1, len(aligned_boxes)):
+        x1, y1, x2, y2 = aligned_boxes[i]
+        if abs(p_y1 - y1) < y_thresh and abs(x1 - p_x2) < x_thresh:
+            p_x1 = min(p_x1, x1)
+            p_y1 = min(p_y1, y1)
+            p_x2 = max(p_x2, x2)
+            p_y2 = max(p_y2, y2)
+        else:
+            merged.append([p_x1, p_y1, p_x2, p_y2])
+            p_x1, p_y1, p_x2, p_y2 = x1, y1, x2, y2
+
+    merged.append([p_x1, p_y1, p_x2, p_y2])
+
+    return np.array(merged)
+
+def ocr_citizenship_utils(image_path: str) -> List[List[str]]:
+    doctr_detector, recognition_predictor_surya, detection_predictor_surya = initialize_detector()
+    page = cv2.imread(image_path)
+    page = cv2.convertScaleAbs(page, alpha=1.5, beta=0)  # boost contrast
+    page = cv2.resize(page, (720, 480))
+    out = doctr_detector([page])
+    cleaned_boxes = get_cleaned_boxes(out, page)
+    merged = merge_boxes_same_line(cleaned_boxes)
+    # Fuzzy match (up to 6 edits) for the card header "नेपाली नागरिकताको प्रमाणपत्र"
+    pattern = r'(नेपाली\s*नागरिकताको\s*प्रमाणपत्र){e<=6}'
+    prev_y = 0
+    start = False
+    first_start = True
+    y_thresh = 5
+    full_result = []
+    line_result = []
+
+    for boxes in merged[3:]:
+        x1, y1, x2, y2 = boxes[0], boxes[1], boxes[2], boxes[3]
+        crop = page[y1:y2, x1:x2]
+        pil_image = Image.fromarray(crop)
+
+        # OCR PART
+        langs = ["en", 'ne']
+        predictions = recognition_predictor_surya([pil_image], [langs], detection_predictor_surya)
+        text_combo = ''
+        for text_line in predictions[0].text_lines:
+            text_combo = text_combo + " " + text_line.text.strip()
+        text_combo = text_combo.strip()
+        # OCR PART END
+
+        # Skip everything above the card header; collect text only after it
+        if not start:
+            match = re.search(pattern, text_combo)
+            if match:
+                start = True
+            continue
+        if first_start:
+            first_start = False
+            prev_y = boxes[1]
+        # A jump in y means a new physical line on the card
+        if y1 - prev_y > y_thresh:
+            full_result.append(line_result)
+            line_result = []
+        line_result.append(text_combo)
+        prev_y = boxes[1]
+
+    # Flush the last collected line
+    if line_result:
+        full_result.append(line_result)
+
+    return full_result
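Of the new helpers, merge_boxes_same_line is the easiest to sanity-check in isolation, since it is pure box geometry. A minimal sketch on synthetic boxes, assuming utils.py is importable in the current environment (its module-level imports pull in torch, doctr, and surya):

# Synthetic word boxes (x1, y1, x2, y2): two on one text line, one below.
from utils import merge_boxes_same_line

boxes = [
    (10, 100, 60, 120),   # first word
    (70, 103, 140, 123),  # same line: small y offset, x gap < x_thresh
    (10, 160, 90, 180),   # next line
]
print(merge_boxes_same_line(boxes))
# Expected: the first two boxes merge into one spanning x=10..140;
# the third box survives unchanged.

The full pipeline, ocr_citizenship_utils(image_path=...), additionally needs the doctr and surya model weights, which are downloaded on first use.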