Spaces:
Sleeping
Sleeping
import os | |
import numpy as np | |
import cv2 | |
import torch | |
import gradio as gr | |
from PIL import Image | |
from transformers import TrOCRProcessor, VisionEncoderDecoderModel | |
from kraken import blla, binarization | |
# Инициализация модели и процессора | |
print("Загрузка модели OCR...") | |
model_name = "Futyn-Maker/trocr-base-ru-notebooks" | |
processor = TrOCRProcessor.from_pretrained(model_name) | |
model = VisionEncoderDecoderModel.from_pretrained(model_name) | |
# Проверка доступности GPU | |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") | |
model.to(device) | |
print(f"Использование устройства: {device}") | |
def segment_image(image): | |
""" | |
Сегментирует изображение на строки с помощью Kraken | |
""" | |
# Конвертация в бинарное изображение | |
bw_img = binarization.nlbin(image, threshold=0.5, escale=2.0, border=0.1, high=0.9) | |
# Сегментация на строки | |
lines = blla.segment(bw_img, text_direction='horizontal-lr') | |
# Сортировка и объединение близких строк | |
sorted_lines = sorted(lines.lines, key=lambda line: line.baseline[0][1]) # Сортировка по y-координате | |
merged_lines = [] | |
if sorted_lines: | |
current_line = sorted_lines[0] | |
for next_line in sorted_lines[1:]: | |
current_y = current_line.baseline[0][1] | |
next_y = next_line.baseline[0][1] | |
if abs(next_y - current_y) < 15: | |
current_line.baseline.extend(next_line.baseline) | |
else: | |
merged_lines.append(current_line) | |
current_line = next_line | |
merged_lines.append(current_line) | |
else: | |
merged_lines = sorted_lines | |
# Извлечение областей строк | |
line_images = [] | |
for line in merged_lines: | |
baseline = np.array(line.baseline) | |
x0 = int(np.min(baseline[:, 0])) # Минимальная x-координата | |
y0 = int(np.min(baseline[:, 1])) # Минимальная y-координата | |
x1 = int(np.max(baseline[:, 0])) # Максимальная x-координата | |
y1 = int(np.max(baseline[:, 1])) # Максимальная y-координата | |
# Добавление отступа для лучшего распознавания | |
padding = 30 | |
y0 = max(0, y0 - padding) | |
y1 = min(image.height, y1 + padding) | |
# Вырезаем область строки | |
line_image = image.crop((x0, y0, x1, y1)) | |
line_images.append(line_image) | |
return line_images | |
def recognize_text(image): | |
""" | |
Распознает текст на изображении, сегментированном на строки | |
""" | |
# Сегментация изображения на строки | |
line_images = segment_image(image) | |
if not line_images: | |
return "Не удалось обнаружить строки текста на изображении." | |
# Распознавание текста для каждой строки | |
recognized_lines = [] | |
for line_image in line_images: | |
# Подготовка изображения для модели | |
pixel_values = processor(line_image, return_tensors="pt").pixel_values | |
pixel_values = pixel_values.to(device) | |
# Распознавание текста | |
with torch.no_grad(): | |
generated_ids = model.generate( | |
pixel_values, | |
max_length=256, | |
num_beams=4, | |
early_stopping=True | |
) | |
# Декодирование результата | |
line_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0] | |
recognized_lines.append(line_text) | |
# Объединение всех строк в один текст | |
full_text = "\n".join(recognized_lines) | |
return full_text | |
def save_text_to_file(text): | |
""" | |
Сохраняет распознанный текст в файл | |
""" | |
with open("recognized_text.txt", "w", encoding="utf-8") as f: | |
f.write(text) | |
return "recognized_text.txt" | |
def process_image(input_image): | |
""" | |
Основная функция для обработки изображения | |
""" | |
# Конвертация в PIL Image, если необходимо | |
if not isinstance(input_image, Image.Image): | |
input_image = Image.fromarray(input_image) | |
# Распознавание текста | |
recognized_text = recognize_text(input_image) | |
# Сохранение результата в файл | |
output_file = save_text_to_file(recognized_text) | |
return recognized_text, output_file | |
# Создание интерфейса Gradio | |
with gr.Blocks(title="Распознавание рукописного текста") as demo: | |
gr.Markdown("# Распознавание рукописного текста") | |
gr.Markdown("Загрузите изображение с рукописным текстом для распознавания.") | |
with gr.Row(): | |
input_image = gr.Image(type="pil", label="Изображение") | |
with gr.Row(): | |
submit_btn = gr.Button("Распознать текст") | |
with gr.Row(): | |
text_output = gr.Textbox(label="Распознанный текст", lines=10) | |
file_output = gr.File(label="Скачать текстовый файл") | |
submit_btn.click( | |
fn=process_image, | |
inputs=input_image, | |
outputs=[text_output, file_output] | |
) | |
if __name__ == "__main__": | |
demo.launch() | |