UniquePratham committed on
Commit
afedbd6
·
verified ·
1 Parent(s): 8b34af2

Update ocr_cpu.py

Browse files
Files changed (1) hide show
  1. ocr_cpu.py +70 -25
ocr_cpu.py CHANGED
@@ -1,27 +1,54 @@
 
 
1
  import os
2
  import torch
3
- from transformers import AutoModel, AutoTokenizer
 
 
 
 
 
4
 
5
- # Load model and tokenizer
6
- model_name = "srimanth-d/GOT_CPU" # Using GOT model on CPU
7
- tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True, return_tensors='pt')
 
 
8
 
9
- # Load the model
10
- model = AutoModel.from_pretrained(
11
- model_name,
12
  trust_remote_code=True,
13
  low_cpu_mem_usage=True,
14
  use_safetensors=True,
15
- pad_token_id=tokenizer.eos_token_id,
16
  )
17
 
18
- # Ensure the model is in evaluation mode and loaded on CPU
19
- device = torch.device("cpu")
20
- model = model.eval()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
 
22
- # OCR function to extract text
23
  def extract_text_got(uploaded_file):
24
- """Use GOT-OCR2.0 model to extract text from the uploaded image."""
 
 
25
  temp_file_path = 'temp_image.jpg'
26
 
27
  try:
@@ -38,7 +65,7 @@ def extract_text_got(uploaded_file):
38
  for ocr_type in ocr_types:
39
  with torch.no_grad():
40
  print(f"Running OCR with type: {ocr_type}")
41
- outputs = model.chat(tokenizer, temp_file_path, ocr_type=ocr_type)
42
 
43
  if isinstance(outputs, list) and outputs[0].strip():
44
  return outputs[0].strip() # Return the result if successful
@@ -56,22 +83,40 @@ def extract_text_got(uploaded_file):
56
  os.remove(temp_file_path)
57
  print(f"Temporary file {temp_file_path} removed.")
58
 
59
- # Function to clean extracted text using AI
 
 
 
60
  def clean_text_with_ai(extracted_text):
61
  """
62
- Cleans extracted text by leveraging an AI model to intelligently remove extra spaces.
63
  """
64
  try:
65
- # Prepare the input for the AI model
66
- inputs = tokenizer(extracted_text, return_tensors="pt").to(device)
67
-
68
- # Generate cleaned text using the AI model
 
 
 
69
  with torch.no_grad():
70
- outputs = model.generate(**inputs, max_new_tokens=100) # Adjust max_new_tokens as needed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
 
72
- # Decode the generated output
73
- cleaned_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
74
-
75
- return cleaned_text.strip() # Return the cleaned text
76
  except Exception as e:
77
  return f"Error during AI text cleaning: {str(e)}"
 
1
+ # ocr_cpu.py
2
+
3
  import os
4
  import torch
5
+ from transformers import AutoModel, AutoTokenizer, AutoModelForCausalLM
6
+ import re
7
+
8
+ # -----------------------------
9
+ # OCR Model Initialization
10
+ # -----------------------------
11
 
12
+ # Load OCR model and tokenizer
13
+ ocr_model_name = "srimanth-d/GOT_CPU" # Using GOT model on CPU
14
+ ocr_tokenizer = AutoTokenizer.from_pretrained(
15
+ ocr_model_name, trust_remote_code=True, return_tensors='pt'
16
+ )
17
 
18
+ # Load the OCR model
19
+ ocr_model = AutoModel.from_pretrained(
20
+ ocr_model_name,
21
  trust_remote_code=True,
22
  low_cpu_mem_usage=True,
23
  use_safetensors=True,
24
+ pad_token_id=ocr_tokenizer.eos_token_id,
25
  )
26
 
27
+ # Ensure the OCR model is in evaluation mode and loaded on CPU
28
+ ocr_device = torch.device("cpu")
29
+ ocr_model = ocr_model.eval().to(ocr_device)
30
+
31
+ # -----------------------------
32
+ # Text Cleaning Model Initialization
33
+ # -----------------------------
34
+
35
+ # Load Text Cleaning model and tokenizer
36
+ clean_model_name = "gpt2" # You can choose a different model if preferred
37
+ clean_tokenizer = AutoTokenizer.from_pretrained(clean_model_name)
38
+ clean_model = AutoModelForCausalLM.from_pretrained(clean_model_name)
39
+
40
+ # Ensure the Text Cleaning model is in evaluation mode and loaded on CPU
41
+ clean_device = torch.device("cpu")
42
+ clean_model = clean_model.eval().to(clean_device)
43
+
44
+ # -----------------------------
45
+ # OCR Function
46
+ # -----------------------------
47
 
 
48
  def extract_text_got(uploaded_file):
49
+ """
50
+ Use GOT-OCR2.0 model to extract text from the uploaded image.
51
+ """
52
  temp_file_path = 'temp_image.jpg'
53
 
54
  try:
 
65
  for ocr_type in ocr_types:
66
  with torch.no_grad():
67
  print(f"Running OCR with type: {ocr_type}")
68
+ outputs = ocr_model.chat(ocr_tokenizer, temp_file_path, ocr_type=ocr_type)
69
 
70
  if isinstance(outputs, list) and outputs[0].strip():
71
  return outputs[0].strip() # Return the result if successful
 
83
  os.remove(temp_file_path)
84
  print(f"Temporary file {temp_file_path} removed.")
85
 
86
+ # -----------------------------
87
+ # Text Cleaning Function
88
+ # -----------------------------
89
+
90
  def clean_text_with_ai(extracted_text):
91
  """
92
+ Cleans extracted text by leveraging a language model to intelligently remove extra spaces and correct formatting.
93
  """
94
  try:
95
+ # Define the prompt for cleaning
96
+ prompt = f"Please clean the following text by removing extra spaces and ensuring proper formatting:\n\n{extracted_text}\n\nCleaned Text:"
97
+
98
+ # Tokenize the input prompt
99
+ inputs = clean_tokenizer.encode(prompt, return_tensors="pt").to(clean_device)
100
+
101
+ # Generate the cleaned text
102
  with torch.no_grad():
103
+ outputs = clean_model.generate(
104
+ inputs,
105
+ max_length=500, # Adjust as needed
106
+ temperature=0.7,
107
+ top_p=0.9,
108
+ do_sample=True,
109
+ eos_token_id=clean_tokenizer.eos_token_id,
110
+ pad_token_id=clean_tokenizer.eos_token_id
111
+ )
112
+
113
+ # Decode the generated text
114
+ cleaned_text = clean_tokenizer.decode(outputs[0], skip_special_tokens=True)
115
+
116
+ # Extract the cleaned text after the prompt
117
+ cleaned_text = cleaned_text.split("Cleaned Text:")[-1].strip()
118
+
119
+ return cleaned_text
120
 
 
 
 
 
121
  except Exception as e:
122
  return f"Error during AI text cleaning: {str(e)}"