Futyn-Maker
Add the app
2602ab3
import os
import numpy as np
import cv2
import torch
import gradio as gr
from PIL import Image
from transformers import TrOCRProcessor, VisionEncoderDecoderModel
from kraken import blla, binarization
# Инициализация модели и процессора
print("Загрузка модели OCR...")
model_name = "Futyn-Maker/trocr-base-ru-notebooks"
processor = TrOCRProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)
# Проверка доступности GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
print(f"Использование устройства: {device}")
def segment_image(image):
"""
Сегментирует изображение на строки с помощью Kraken
"""
# Конвертация в бинарное изображение
bw_img = binarization.nlbin(image, threshold=0.5, escale=2.0, border=0.1, high=0.9)
# Сегментация на строки
lines = blla.segment(bw_img, text_direction='horizontal-lr')
# Сортировка и объединение близких строк
sorted_lines = sorted(lines.lines, key=lambda line: line.baseline[0][1]) # Сортировка по y-координате
merged_lines = []
if sorted_lines:
current_line = sorted_lines[0]
for next_line in sorted_lines[1:]:
current_y = current_line.baseline[0][1]
next_y = next_line.baseline[0][1]
if abs(next_y - current_y) < 15:
current_line.baseline.extend(next_line.baseline)
else:
merged_lines.append(current_line)
current_line = next_line
merged_lines.append(current_line)
else:
merged_lines = sorted_lines
# Извлечение областей строк
line_images = []
for line in merged_lines:
baseline = np.array(line.baseline)
x0 = int(np.min(baseline[:, 0])) # Минимальная x-координата
y0 = int(np.min(baseline[:, 1])) # Минимальная y-координата
x1 = int(np.max(baseline[:, 0])) # Максимальная x-координата
y1 = int(np.max(baseline[:, 1])) # Максимальная y-координата
# Добавление отступа для лучшего распознавания
padding = 30
y0 = max(0, y0 - padding)
y1 = min(image.height, y1 + padding)
# Вырезаем область строки
line_image = image.crop((x0, y0, x1, y1))
line_images.append(line_image)
return line_images
def recognize_text(image):
"""
Распознает текст на изображении, сегментированном на строки
"""
# Сегментация изображения на строки
line_images = segment_image(image)
if not line_images:
return "Не удалось обнаружить строки текста на изображении."
# Распознавание текста для каждой строки
recognized_lines = []
for line_image in line_images:
# Подготовка изображения для модели
pixel_values = processor(line_image, return_tensors="pt").pixel_values
pixel_values = pixel_values.to(device)
# Распознавание текста
with torch.no_grad():
generated_ids = model.generate(
pixel_values,
max_length=256,
num_beams=4,
early_stopping=True
)
# Декодирование результата
line_text = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
recognized_lines.append(line_text)
# Объединение всех строк в один текст
full_text = "\n".join(recognized_lines)
return full_text
def save_text_to_file(text):
"""
Сохраняет распознанный текст в файл
"""
with open("recognized_text.txt", "w", encoding="utf-8") as f:
f.write(text)
return "recognized_text.txt"
def process_image(input_image):
"""
Основная функция для обработки изображения
"""
# Конвертация в PIL Image, если необходимо
if not isinstance(input_image, Image.Image):
input_image = Image.fromarray(input_image)
# Распознавание текста
recognized_text = recognize_text(input_image)
# Сохранение результата в файл
output_file = save_text_to_file(recognized_text)
return recognized_text, output_file
# Создание интерфейса Gradio
with gr.Blocks(title="Распознавание рукописного текста") as demo:
gr.Markdown("# Распознавание рукописного текста")
gr.Markdown("Загрузите изображение с рукописным текстом для распознавания.")
with gr.Row():
input_image = gr.Image(type="pil", label="Изображение")
with gr.Row():
submit_btn = gr.Button("Распознать текст")
with gr.Row():
text_output = gr.Textbox(label="Распознанный текст", lines=10)
file_output = gr.File(label="Скачать текстовый файл")
submit_btn.click(
fn=process_image,
inputs=input_image,
outputs=[text_output, file_output]
)
if __name__ == "__main__":
demo.launch()