yuragoithf committed on
Commit
a476151
·
verified ·
1 Parent(s): d8cdd7b

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +56 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gc
import time

import cv2
import gradio as gr
import numpy as np
import torch
from craft_text_detector import Craft
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
9
+
10
# Run everything on the CPU: torch's default device is pinned to it and the
# CRAFT text detector is constructed with CUDA disabled.
torch.set_default_device("cpu")
craft = Craft(cuda=False, crop_type="box", output_dir=None)

# The processor and the encoder-decoder model are loaded from the same small
# handwritten TrOCR checkpoint, chosen because it is light enough for CPU use.
_TROCR_CHECKPOINT = "microsoft/trocr-small-handwritten"
processor = TrOCRProcessor.from_pretrained(_TROCR_CHECKPOINT)
model = VisionEncoderDecoderModel.from_pretrained(_TROCR_CHECKPOINT)
17
+
18
def _bounding_rect(box, width, height):
    """Return the axis-aligned crop rectangle that encloses *box*.

    Args:
        box: Sequence of ``(x, y)`` corner points of a detected text region.
        width: Width in pixels of the image the box refers to.
        height: Height in pixels of the image the box refers to.

    Returns:
        ``(left, top, right, bottom)`` as whole pixels clamped to the image,
        or ``None`` when the clamped rectangle has no area.
    """
    points = np.asarray(box, dtype=float).reshape(-1, 2)
    if points.size == 0:
        return None

    # Use the extremes of *all* corners rather than corners 0 and 2 only, so a
    # rotated quadrilateral never yields right < left (which Pillow rejects)
    # and never cuts off part of the text.
    left, top = np.floor(points.min(axis=0))
    right, bottom = np.ceil(points.max(axis=0))

    left = max(0, int(left))
    top = max(0, int(top))
    right = min(width, int(right))
    bottom = min(height, int(bottom))

    if right <= left or bottom <= top:
        return None
    return left, top, right, bottom


def recognize_handwritten(image):
    """Detect text regions in *image* and transcribe each one with TrOCR.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when the form is submitted without an upload.

    Returns:
        A string with the recognized text (regions joined by single spaces)
        followed by the elapsed processing time in seconds, or a short
        message when no image was provided.
    """
    if image is None:
        return "No image provided. Please upload an image."

    start_time = time.perf_counter()

    # Normalise to three channels so both the colour conversion below and the
    # TrOCR processor always receive an RGB image.
    rgb_image = image.convert("RGB")

    # NOTE(review): the detector is fed a BGR array, as before; confirm which
    # channel order craft_text_detector expects for ndarray input.
    bgr_array = cv2.cvtColor(np.array(rgb_image), cv2.COLOR_RGB2BGR)
    boxes = craft.detect_text(image=bgr_array)["boxes"]

    texts = []
    for box in boxes:
        bounds = _bounding_rect(box, rgb_image.width, rgb_image.height)
        if bounds is None:
            # Degenerate or fully out-of-image region: nothing to recognize.
            continue
        crop = rgb_image.crop(bounds)
        pixel_values = processor(crop, return_tensors="pt").pixel_values
        with torch.no_grad():
            generated_ids = model.generate(pixel_values)
        texts.append(
            processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        )

    text_data = " ".join(texts)
    time_difference = time.perf_counter() - start_time

    return f"Recognized text: {text_data}\nTime: {time_difference} seconds"
41
+
42
# Create Gradio interface: one PIL image in, one block of text out.
interface = gr.Interface(
    fn=recognize_handwritten,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Handwritten Text Recognition",
    description="Upload an image containing handwritten text to recognize it.",
)

# Launch the app. The cleanup lives in ``finally`` so the detector weights are
# released even when launch() exits through an exception.
try:
    interface.launch()
finally:
    # Craft loads both the CRAFT network and (by default) the refiner network;
    # unload both before forcing a garbage-collection pass.
    craft.unload_craftnet_model()
    craft.unload_refinenet_model()
    gc.collect()
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ torch
2
+ transformers
3
+ craft-text-detector
4
+ opencv-python
5
+ Pillow
6
+ gradio