|
ABOUT = """ |
|
# TB-OCR Preview 0.1 Unofficial Demo |
|
|
|
This is an unofficial demo of [yifeihu/TB-OCR-preview-0.1](https://huggingface.co/yifeihu/TB-OCR-preview-0.1). |
|
|
|
Overview of TB-OCR: |
|
|
|
> TB-OCR-preview (Text Block OCR), created by [Yifei Hu](https://x.com/hu_yifei), is an end-to-end OCR model handling text, math latex, and markdown formats all at once. The model takes a block of text as the input and returns clean markdown output. Headers are marked with `##`. Math expressions are guaranteed to be wrapped in brackets `\( inline math \) \[ display math \]` for easier parsing. This model does not require line-detection or math formula detection. |
|
|
|
(From the [model card](https://huggingface.co/yifeihu/TB-OCR-preview-0.1)) |
|
""" |
|
|
|
|
|
import spaces
import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoProcessor

|
model_id = "yifeihu/TB-OCR-preview-0.1"

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

if not torch.cuda.is_available():
    ABOUT += (
        "\n\n### ⚠️ This demo is running on CPU ⚠️\n\n"
        "Inference on CPU is very slow. Consider duplicating the Space or "
        "running it locally to skip the queue and get faster responses."
    )
|
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    device_map=DEVICE,
    trust_remote_code=True,
    torch_dtype="auto",
)
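
# torch_dtype="auto" loads the checkpoint in its stored dtype (likely bfloat16
# here), which can be slow on CPU-only hardware. A possible tweak for CPU runs
# (a sketch, not applied here) is to force float32 instead:
#
#   model = AutoModelForCausalLM.from_pretrained(
#       model_id,
#       device_map=DEVICE,
#       trust_remote_code=True,
#       torch_dtype=torch.float32 if DEVICE.type == "cpu" else "auto",
#   )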
|
|
|
# num_crops controls how many crops the image processor takes from large
# inputs; 16 matches the value used in the model card example.
processor = AutoProcessor.from_pretrained(
    model_id,
    trust_remote_code=True,
    num_crops=16,
)
|
@spaces.GPU
def phi_ocr(image_path):
    # gr.Image(type="filepath") passes a local file path to the uploaded image.
    question = "Convert the text to markdown format."
    image = Image.open(image_path)
    prompt_message = [{
        "role": "user",
        "content": f"<|image_1|>\n{question}",
    }]

    prompt = processor.tokenizer.apply_chat_template(
        prompt_message, tokenize=False, add_generation_prompt=True
    )
    inputs = processor(prompt, [image], return_tensors="pt").to(DEVICE)

    # Greedy decoding (do_sample=False); temperature is ignored in that case.
    generation_args = {
        "max_new_tokens": 1024,
        "temperature": 0.1,
        "do_sample": False,
    }

    generate_ids = model.generate(
        **inputs,
        eos_token_id=processor.tokenizer.eos_token_id,
        **generation_args,
    )

    # Strip the prompt tokens and decode only the newly generated ones.
    generate_ids = generate_ids[:, inputs["input_ids"].shape[1]:]
    response = processor.batch_decode(
        generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
    )[0]

    # The model emits an "<image_end>" marker; keep only the text before it.
    response = response.split("<image_end>")[0]

    return response
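
# Optional local smoke test, a small sketch not present in the upstream model
# card example: `python app.py path/to/image.png` prints the OCR markdown for
# one image and exits instead of launching the web UI.
if __name__ == "__main__":
    import sys

    if len(sys.argv) > 1:
        print(phi_ocr(sys.argv[1]))
        raise SystemExit(0)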
|
|
|
with gr.Blocks() as demo:
    gr.Markdown(ABOUT)
    with gr.Row():
        with gr.Column():
            img = gr.Image(label="Input image", type="filepath")
            btn = gr.Button("OCR")
        with gr.Column():
            out = gr.Markdown()
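            # The model wraps math in \( ... \) and \[ ... \] (see the model
            # card quote in ABOUT). Recent Gradio versions with
            # latex_delimiters support can render that math if the component
            # is given matching delimiters; a possible tweak (a sketch, not
            # enabled here) would be:
            #   out = gr.Markdown(latex_delimiters=[
            #       {"left": "\\(", "right": "\\)", "display": False},
            #       {"left": "\\[", "right": "\\]", "display": True},
            #   ])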
|
btn.click(phi_ocr, inputs=img, outputs=out) |
|
demo.queue().launch() |