rockerritesh committed on
Commit
1b5f903
·
verified ·
1 Parent(s): f70009a

Upload 4 files

Browse files
Files changed (4) hide show
  1. Dockerfile +60 -0
  2. main.py +62 -0
  3. requirements.txt +9 -0
  4. utils.py +164 -0
Dockerfile ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Use an official Python runtime as a parent image
FROM python:3.9-slim

# Set the working directory
WORKDIR /app

# Install necessary system dependencies.
# libgl1 is used instead of libgl1-mesa-glx: the latter is only a transitional
# package and is no longer shipped by newer Debian releases, which makes the
# install step fail on current python:*-slim bases. libgl1 provides libGL.so.1
# on both old and new bases.
RUN apt-get update && apt-get install -y \
        gcc \
        poppler-utils \
        cmake \
        libglib2.0-0 \
        libsm6 \
        libxext6 \
        libxrender-dev \
        libgl1 \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Set the CC environment variable to ensure TorchInductor uses the correct compiler
ENV CC=gcc

# Copy the requirements file and install dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir -r requirements.txt

# Create the cache/config directories (Hugging Face, XDG, Triton, TorchInductor,
# Matplotlib, Fontconfig, scratch data) in a single layer and make them
# world-writable so the app can write to them whatever user it runs as.
RUN mkdir -p \
        /app/cache \
        /app/config \
        /app/triton_cache \
        /app/torchinductor_cache \
        /mnt/data \
        /app/.cache \
        /app/matplotlib \
        /app/fontconfig \
    && chmod -R 777 \
        /app/cache \
        /app/config \
        /app/triton_cache \
        /app/torchinductor_cache \
        /mnt/data \
        /app/.cache \
        /app/matplotlib \
        /app/fontconfig

# Point Hugging Face, XDG, Triton, TorchInductor, Matplotlib and Fontconfig at
# the writable directories created above
ENV HF_HOME=/app/cache
ENV XDG_CACHE_HOME=/app/.cache
ENV XDG_CONFIG_HOME=/app/config
ENV TRITON_CACHE_DIR=/app/triton_cache
ENV TORCHINDUCTOR_CACHE_DIR=/app/torchinductor_cache
ENV MPLCONFIGDIR=/app/matplotlib
ENV FONTCONFIG_PATH=/app/fontconfig
ENV TORCH_HOME=/app/torchinductor_cache
ENV TRITON_CACHE=/app/triton_cache
ENV TOKENIZERS_PARALLELISM=false

# Copy the application code.
# NOTE(review): models/ and fonts/ must exist in the build context; they are not
# among the files added alongside this Dockerfile - confirm they are present.
COPY main.py .
COPY utils.py ./
COPY models /app/models
COPY fonts /app/fonts

# Expose the port FastAPI will run on
EXPOSE 7860

# Command to run the FastAPI app
CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "7860"]
main.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ! pip uninstall -y tensorflow
2
+ # ! pip install "python-doctr[torch,viz]"
3
+
4
+ from fastapi import FastAPI, UploadFile, File
5
+ from fastapi.responses import JSONResponse
6
+ from utils import dev_number, roman_number, dev_letter, roman_letter
7
+ import tempfile
8
+
9
+ app = FastAPI()
10
+
11
+
12
@app.post("/ocr_dev_number/")
async def extract_dev_number(image: UploadFile = File(...)):
    """Run the Devanagari-digit recogniser on an uploaded image.

    The upload is written to a temporary ``.png`` file because ``dev_number``
    takes a file path. The temporary file is always removed afterwards, even
    when prediction raises, so requests do not leak files on disk.

    Args:
        image: The uploaded image file.

    Returns:
        A JSON response of the form ``{"predicted_str": <prediction>}``.
    """
    import os  # local import keeps this handler self-contained

    content = await image.read()
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # NOTE(review): this is a synchronous call inside an async handler, so
        # it occupies the event loop for the duration of the prediction.
        predicted_str = dev_number(tmp_path)
    finally:
        # delete=False means nobody else cleans this file up
        os.remove(tmp_path)

    return JSONResponse(content={"predicted_str": predicted_str})
24
+
25
@app.post("/ocr_roman_number/")
async def extract_roman_number(image: UploadFile = File(...)):
    """Run the Roman-digit recogniser on an uploaded image.

    The upload is written to a temporary ``.png`` file because ``roman_number``
    takes a file path. The temporary file is always removed afterwards, even
    when prediction raises, so requests do not leak files on disk.

    Args:
        image: The uploaded image file.

    Returns:
        A JSON response of the form ``{"predicted_str": <prediction>}``.
    """
    import os  # local import keeps this handler self-contained

    content = await image.read()
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # NOTE(review): this is a synchronous call inside an async handler, so
        # it occupies the event loop for the duration of the prediction.
        predicted_str = roman_number(tmp_path)
    finally:
        # delete=False means nobody else cleans this file up
        os.remove(tmp_path)

    return JSONResponse(content={"predicted_str": predicted_str})
37
+
38
@app.post("/ocr_dev_letter/")
async def extract_dev_letter(image: UploadFile = File(...)):
    """Run the Devanagari-letter recogniser on an uploaded image.

    The upload is written to a temporary ``.png`` file because ``dev_letter``
    takes a file path. The temporary file is always removed afterwards, even
    when prediction raises, so requests do not leak files on disk.

    Args:
        image: The uploaded image file.

    Returns:
        A JSON response of the form ``{"predicted_str": <prediction>}``.
    """
    import os  # local import keeps this handler self-contained

    content = await image.read()
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # NOTE(review): this is a synchronous call inside an async handler, so
        # it occupies the event loop for the duration of the prediction.
        predicted_str = dev_letter(tmp_path)
    finally:
        # delete=False means nobody else cleans this file up
        os.remove(tmp_path)

    return JSONResponse(content={"predicted_str": predicted_str})
50
+
51
@app.post("/ocr_roman_letter/")
async def extract_roman_letter(image: UploadFile = File(...)):
    """Run the Roman-letter recogniser on an uploaded image.

    The upload is written to a temporary ``.png`` file because ``roman_letter``
    takes a file path. The temporary file is always removed afterwards, even
    when prediction raises, so requests do not leak files on disk.

    Args:
        image: The uploaded image file.

    Returns:
        A JSON response of the form ``{"predicted_str": <prediction>}``, where
        the prediction is whatever ``roman_letter`` returns.
    """
    import os  # local import keeps this handler self-contained

    content = await image.read()
    with tempfile.NamedTemporaryFile(delete=False, suffix=".png") as tmp:
        tmp.write(content)
        tmp_path = tmp.name

    try:
        # NOTE(review): this is a synchronous call inside an async handler, so
        # it occupies the event loop for the duration of the prediction.
        predicted_str = roman_letter(tmp_path)
    finally:
        # delete=False means nobody else cleans this file up
        os.remove(tmp_path)

    return JSONResponse(content={"predicted_str": predicted_str})
requirements.txt ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ python-doctr[torch,viz]
2
+ torch
3
+ torchvision
4
+ numpy
5
+ matplotlib
6
+ pillow
7
+ fastapi
8
+ uvicorn
9
+ pydantic
utils.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from PIL import Image
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ import torchvision.transforms as transforms
7
+ from doctr.io import DocumentFile
8
+ from doctr.models import recognition_predictor
9
+
10
+ character_num = "0123456789-"
11
+ character_letter = ''' "()-./0123456789:?ABCDEFGHIKLMNOPQRSTUWYabcdefghijklmnoprstuvwyँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑॓॔क़ख़ग़ज़ड़ढ़फ़य़ॠॢ।॥०१२३४५६७८९॰ॱॲॻॼॽॾ^''' #"()-./0123456789:?ABCDEFGHIKLMNOPQRSTUWYabcdefghijklmnoprstuvwyँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑॓॔क़ख़ग़ज़ड़ढ़फ़य़ॠॢ।॥०१२३४५६७८९॰ॱॲॻॼॽॾ^"
12
+
13
+ model_dev_digits_path = "models/devnagri_digits_20k_v2.pth"
14
+ model_roman_digits_path = "models/roman_digits_20k_v5.pth"
15
+ dev_letter_path = "models/small_devnagari_letter.pth"
16
+
17
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
18
+
19
class CRNN(nn.Module):
    """Convolutional-recurrent network mapping an image to per-column class scores.

    Four conv/batch-norm/ReLU/max-pool stages halve height and width each time
    (a factor of 16 overall). Every remaining column of the feature map becomes
    one time step of a bidirectional LSTM, and a linear layer scores each time
    step over ``num_classes`` classes.

    Args:
        num_classes: Number of output classes per time step.
        input_size: ``(channels, height, width)`` of the expected input. Only the
            channel count and the height are used to size the layers; the height
            should be a multiple of 16 so the LSTM input size matches the
            flattened feature columns.
    """

    def __init__(self, num_classes, input_size=(1, 64, 256)):
        super().__init__()

        in_channels, in_height = input_size[0], input_size[1]
        lstm_hidden = 128

        # Shape comments below assume the default 64x256 (H x W) input.
        # The layer order must not change: state-dict keys depend on it.
        self.conv_block = nn.Sequential(
            nn.Conv2d(in_channels, 64, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 32x128

            nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 16x64

            nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 8x32

            nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2, stride=2),  # 4x16
        )

        # Four 2x poolings shrink the height by 2**4: 64 -> 4
        feature_height = in_height // 16

        # No ``dropout`` argument: with num_layers=1 PyTorch applies no dropout
        # and only emits a warning, so omitting it leaves the computation and
        # the parameter names untouched.
        self.rnn = nn.LSTM(
            input_size=512 * feature_height,  # 512 * 4 = 2048
            hidden_size=lstm_hidden,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )

        # Bidirectional LSTM concatenates both directions: 2 * 128 = 256
        self.fc = nn.Linear(2 * lstm_hidden, num_classes)

    def forward(self, x):
        """Return class scores of shape ``(batch, seq_len, num_classes)``.

        ``seq_len`` is the width of the convolutional feature map (input
        width // 16).
        """
        x = self.conv_block(x)  # (B, 512, H/16, W/16)
        b, c, h, w = x.size()
        x = x.permute(0, 3, 1, 2)  # (B, W, C, H)
        x = x.contiguous().view(b, w, c * h)  # (B, seq_len, C*H)

        x, _ = self.rnn(x)  # (B, seq_len, 256)
        x = self.fc(x)  # (B, seq_len, num_classes)
        return x
69
+
70
# Models that have already been built and loaded, keyed by
# (checkpoint path, number of classes), so each checkpoint is read only once.
_MODEL_CACHE = {}


def model_init(character, model_path):
    """Build a CRNN for ``character`` and load its weights from ``model_path``.

    The model is moved to ``device`` and switched to evaluation mode so that
    batch normalisation uses its stored running statistics rather than the
    statistics of the single image being predicted. Loaded models are cached,
    so repeated calls with the same checkpoint do not re-read it from disk.

    Args:
        character: Sequence of output symbols; its length sets the number of
            classes of the network.
        model_path: Path of the ``state_dict`` checkpoint to load.

    Returns:
        The ready-to-use model. The same instance is returned on later calls
        with the same arguments, so callers should treat it as read-only.
    """
    cache_key = (str(model_path), len(character))
    model = _MODEL_CACHE.get(cache_key)
    if model is None:
        model = CRNN(num_classes=len(character))
        model.load_state_dict(torch.load(model_path, map_location=device))
        model = model.to(device)
        model.eval()
        _MODEL_CACHE[cache_key] = model
    return model
77
+
78
def predict_image(image_path, character, model_path):
    """Predict the character sequence for the image stored at ``image_path``.

    Preprocessing depends on which checkpoint is used:

    * for every model except the one at ``dev_letter_path`` the grayscale image
      is binarised (pixels below 128 become 0, the rest 255);
    * for the ``dev_letter_path`` model the image is instead inverted after
      resizing.

    The image is resized to 256x64 (width x height), scaled to ``[-1, 1]`` and
    passed through the model; the highest-scoring class of every time step is
    mapped to its symbol in ``character``.

    Args:
        image_path: Path of the image file to read.
        character: Sequence of symbols; index ``i`` is the symbol of class ``i``.
        model_path: Path of the checkpoint to load.

    Returns:
        One symbol per model time step, concatenated into a string.
        NOTE(review): repeated symbols and any blank/padding symbol are not
        collapsed here - confirm this matches how the models were trained.
    """
    is_letter_model = model_path == dev_letter_path

    # Context manager releases the underlying file handle right away;
    # convert() returns an independent in-memory image.
    with Image.open(image_path) as raw:
        image = raw.convert('L')

    if not is_letter_model:
        # if value < 128, set to 0, else set to 255
        image = image.point(lambda x: 0 if x < 128 else 255, 'L')

    # PIL takes (width, height); this is the model's 64x256 (H x W) input,
    # so no further resize is needed in the tensor transform below.
    image = image.resize((256, 64))

    if is_letter_model:
        image = Image.eval(image, lambda x: 255 - x)

    transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ])
    # Add the batch dimension and move to the inference device
    batch = transform(image).unsqueeze(0).to(device)

    model = model_init(character, model_path)
    # Inference must use stored batch-norm statistics, not per-image ones
    model.eval()

    with torch.no_grad():
        output = model(batch)  # (batch_size, seq_len, num_classes)
        _, predicted = output.max(2)  # (batch_size, seq_len)

    return ''.join(character[i] for i in predicted[0].tolist())
122
+
123
def dev_number(image):
    """Predict the digit string in ``image`` with the Devanagari-digits model.

    ``image`` is forwarded unchanged to ``predict_image`` together with the
    digit alphabet ``character_num`` and the checkpoint at
    ``model_dev_digits_path``.
    """
    return predict_image(image, character_num, model_dev_digits_path)
130
+
131
def roman_number(image):
    """Predict the digit string in ``image`` with the Roman-digits model.

    ``image`` is forwarded unchanged to ``predict_image`` together with the
    digit alphabet ``character_num`` and the checkpoint at
    ``model_roman_digits_path``.
    """
    return predict_image(image, character_num, model_roman_digits_path)
138
+
139
def dev_letter(image):
    """Predict the text in ``image`` with the Devanagari-letter model.

    ``image`` is forwarded unchanged to ``predict_image`` together with the
    alphabet ``character_letter`` and the checkpoint at ``dev_letter_path``.
    """
    return predict_image(image, character_letter, dev_letter_path)
146
+
147
+
148
# Roman-letter recognition uses doctr's pretrained recognition predictor.
# It is created once, when this module is imported, and reused by every call.
model = recognition_predictor(pretrained=True)


def roman_letter(image):
    """Run the doctr recognition predictor on ``image``.

    ``image`` is handed to ``DocumentFile.from_images`` and the loaded pages
    are passed to the module-level predictor. The predictor's output is
    returned as-is, without any conversion.
    """
    pages = DocumentFile.from_images(image)
    return model(pages)
159
+
160
+
161
+
162
+
163
+
164
+