|
import modal |
|
from io import BytesIO |
|
from pathlib import Path |
|
from fastapi import File, UploadFile, Form |
|
# Container image for the detector. The apt packages provide system shared
# libraries (libGL, GLib) — presumably for OpenCV, which ultralytics depends on.
image = (
    modal.Image.debian_slim(python_version="3.10")
    # `libgl1` ships libGL.so.1 directly; `libgl1-mesa-glx` is only a
    # transitional dummy package on recent Debian and is absent from newer
    # Debian/Ubuntu releases, which makes the image build fail there.
    .apt_install(["libgl1", "libglib2.0-0"])
    .pip_install(
        "ultralytics>=8.2.85",
        "doclayout-yolo==0.0.2",
        "huggingface-hub",
        "fastapi",
    )
)
|
# Directory inside the container where the shared volume is mounted, and the
# weights file that LayoutDetection.load_model reads from it.
volume_path = Path("/root/data")
model_path = volume_path.joinpath("path2doclayout_yolo_ft.pt")

# Persistent Modal volume holding the downloaded weights (created on first use).
volume = modal.Volume.from_name("yolo-layout-detection", create_if_missing=True)

# Every function/class in this app runs on `image` with the volume mounted.
app = modal.App("yolo-layout-detection-temp", image=image, volumes={volume_path: volume})
|
@app.function()
def download_model():
    """Populate the shared volume with the layout-model weights.

    Pulls every file whose name matches ``path2*`` from the
    ``opendatalab/pdf-extract-kit-1.0`` Hugging Face repository into
    ``volume_path``, using up to 20 parallel download workers.
    """
    import huggingface_hub

    download_options = dict(
        repo_id="opendatalab/pdf-extract-kit-1.0",
        local_dir=volume_path,
        allow_patterns="path2*",
        max_workers=20,
    )
    huggingface_hub.snapshot_download(**download_options)
|
# Per-task detector class ids as (target ids, caption ids). These are the ids
# the endpoint has always filtered on. NOTE(review): assumed to follow the
# DocLayout-YOLO label set (3 figure, 4 figure caption, 5 table, 6 table
# caption, 7 table footnote) — confirm against the checkpoint's `model.names`.
_TASK_CLASSES = {
    "figure": (frozenset({3}), frozenset()),
    "table": (frozenset({5}), frozenset()),
    "figurecaption": (frozenset({3}), frozenset({4})),
    "tablecaption": (frozenset({5}), frozenset({6, 7})),
}


def _split_detections(boxes, classes, scores, target_ids, caption_ids):
    """Partition parallel box/class/score lists into (targets, captions) for one task.

    Each kept detection becomes ``{"box": box, "score": score}``; detections whose
    class is in neither id set are dropped.
    """
    targets, captions = [], []
    for box, cls, score in zip(boxes, classes, scores):
        detection = {"box": box, "score": score}
        # `cls` arrives as a float (e.g. 3.0); it still hashes/compares equal to the int id.
        if cls in target_ids:
            targets.append(detection)
        elif cls in caption_ids:
            captions.append(detection)
    return targets, captions


def _caption_distance(target_box, caption_box):
    """Heuristic target/caption distance for xyxy boxes.

    Sum of the left-edge offset and the gap between the target's bottom (y2) and
    the caption's top (y1). NOTE(review): this favours captions placed *below*
    the target; a caption above it is penalised by the target's height.
    """
    return abs(target_box[0] - caption_box[0]) + abs(target_box[3] - caption_box[1])


def _merge_with_captions(targets, captions):
    """Pair every target with its nearest caption and return their union boxes.

    ``captions`` must be non-empty. Ties go to the first caption in list order;
    several targets may share one caption. Merged entries get score 1.0.
    """
    merged = []
    for target in targets:
        caption = min(captions, key=lambda c: _caption_distance(target["box"], c["box"]))
        t_box, c_box = target["box"], caption["box"]
        union_box = [
            min(t_box[0], c_box[0]),
            min(t_box[1], c_box[1]),
            max(t_box[2], c_box[2]),
            max(t_box[3], c_box[3]),
        ]
        merged.append({"box": union_box, "score": 1.0})
    return merged


@app.cls(gpu="a10g")
class LayoutDetection:
    """GPU-backed document layout detector exposed as a POST web endpoint."""

    @modal.enter()
    def load_model(self):
        """Load the YOLOv10 layout model from ``model_path`` once per container start."""
        from doclayout_yolo import YOLOv10

        self.model = YOLOv10(model_path)

    @modal.web_endpoint(method="POST", docs=True)
    async def predict(self, img: UploadFile = File(...), task: str = Form(...)):
        """Detect figures or tables (optionally merged with captions) in a page image.

        Args:
            img: Uploaded image file (multipart form field ``img``).
            task: One of ``"figure"``, ``"table"``, ``"figurecaption"``,
                ``"tablecaption"``.

        Returns:
            A list of ``{"box": [x1, y1, x2, y2], "score": float}`` dicts. For the
            caption tasks, when any caption is detected in a result, each target
            is merged with its nearest caption into a union box with score 1.0;
            otherwise the raw target detections are returned.

        Raises:
            HTTPException: 400 if ``task`` is unknown or the upload is not a
                decodable image.
        """
        from fastapi import HTTPException
        from PIL import Image, UnidentifiedImageError

        # Validate before touching the GPU; an unknown task used to yield [] silently.
        if task not in _TASK_CLASSES:
            raise HTTPException(
                status_code=400,
                detail=f"Unknown task {task!r}; expected one of {sorted(_TASK_CLASSES)}",
            )
        target_ids, caption_ids = _TASK_CLASSES[task]

        img_bytes = await img.read()
        try:
            pil_img = Image.open(BytesIO(img_bytes))
        except UnidentifiedImageError as err:
            raise HTTPException(
                status_code=400, detail="Uploaded file is not a readable image"
            ) from err

        # NOTE(review): synchronous GPU inference inside an async handler blocks
        # the event loop while it runs.
        results = self.model.predict(pil_img)

        figs = []
        for result in results:
            boxes = result.boxes
            targets, captions = _split_detections(
                boxes.xyxy.cpu().tolist(),
                boxes.cls.cpu().tolist(),
                boxes.conf.cpu().tolist(),
                target_ids,
                caption_ids,
            )
            # Accumulate across results instead of rebinding, so no result's
            # detections are discarded.
            if captions:
                figs.extend(_merge_with_captions(targets, captions))
            else:
                figs.extend(targets)
        return figs