yuragoithf committed on
Commit
1e268f8
·
verified ·
1 Parent(s): 7937b6d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +12 -34
app.py CHANGED
@@ -1,43 +1,25 @@
1
- import gc
2
  import torch
3
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
4
- from craft_text_detector import Craft
5
  from PIL import Image
6
- import cv2
7
- import time
8
  import gradio as gr
9
 
10
- # Force CPU usage, disable CUDA
11
  torch.set_default_device('cpu')
12
- craft = Craft(output_dir=None, crop_type="box", cuda=False)
13
 
14
- # Load smaller model suitable for CPU
15
- processor = TrOCRProcessor.from_pretrained('microsoft/trocr-small-handwritten')
16
- model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-small-handwritten')
17
 
18
  def recognize_handwritten(image):
19
- start_time = time.time()
 
 
20
 
21
- # Convert Gradio image to OpenCV format
22
- image = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
23
- result = craft.detect_text(image=image)
24
- boxes = result["boxes"]
25
- pil_image = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))
26
- texts = []
27
 
28
- for box in boxes:
29
- crop = pil_image.crop([box[0][0], box[0][1], box[2][0], box[2][1]])
30
- pixel_values = processor(crop, return_tensors="pt").pixel_values
31
- with torch.no_grad():
32
- generated_ids = model.generate(pixel_values)
33
- text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
34
- texts.append(text)
35
-
36
- text_data = " ".join(texts)
37
- end_time = time.time()
38
- time_difference = end_time - start_time
39
-
40
- return f"Recognized text: {text_data}\nTime: {time_difference} seconds"
41
 
42
  # Create Gradio interface
43
  interface = gr.Interface(
@@ -49,8 +31,4 @@ interface = gr.Interface(
49
  )
50
 
51
  # Launch the app
52
- interface.launch()
53
-
54
- # Cleanup
55
- craft.unload_craftnet_model()
56
- gc.collect()
 
 
1
  import torch
2
  from transformers import TrOCRProcessor, VisionEncoderDecoderModel
 
3
  from PIL import Image
 
 
4
  import gradio as gr
5
 
6
+ # Force CPU usage
7
  torch.set_default_device('cpu')
 
8
 
9
+ # Load model and processor
10
+ processor = TrOCRProcessor.from_pretrained('microsoft/trocr-base-handwritten')
11
+ model = VisionEncoderDecoderModel.from_pretrained('microsoft/trocr-base-handwritten')
12
 
13
  def recognize_handwritten(image):
14
+ # Convert uploaded image to RGB
15
+ image = image.convert("RGB")
16
+ pixel_values = processor(images=image, return_tensors="pt").pixel_values
17
 
18
+ # Generate text
19
+ generated_ids = model.generate(pixel_values)
20
+ generated_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
 
 
 
21
 
22
+ return f"Recognized text: {generated_text}"
 
 
 
 
 
 
 
 
 
 
 
 
23
 
24
  # Create Gradio interface
25
  interface = gr.Interface(
 
31
  )
32
 
33
  # Launch the app
34
+ interface.launch()