File size: 2,414 Bytes
04aed77
 
 
 
d4bdb14
04aed77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d4bdb14
04aed77
 
 
 
 
d4bdb14
 
 
 
 
 
 
 
 
 
 
 
 
04aed77
 
 
 
 
 
 
 
 
 
 
 
 
d4bdb14
04aed77
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
import base64
import requests
import os
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

from io import BytesIO
from PIL import Image
from retry import retry

from .loader import ImagePDFLoader


@retry(tries=3, delay=1, backoff=2)
def _extract_figures(
    url: str, img: Image.Image, task: str = "figure", timeout: float = 120.0
) -> list[tuple[Image.Image, float]]:
    """Send one page image to the detection service and crop detected regions.

    The image is PNG-encoded and POSTed as the multipart file field ``img``
    together with the form field ``task``. The service is expected to answer
    with a JSON list of objects carrying ``box`` and ``score`` keys. Each
    ``box`` is passed straight to ``Image.crop``. It is presumably
    ``(left, upper, right, lower)`` in pixel coordinates; confirm this
    against the service.

    The whole call is attempted up to 3 times in total on any exception. It
    waits 1 s and then 2 s between attempts. The last exception is re-raised
    if every attempt fails.

    Args:
        url: Endpoint of the detection service.
        img: Page image to analyse.
        task: Detection task name sent as the ``task`` form field.
        timeout: Per-request timeout in seconds, forwarded to ``requests``.
            It exists so that a stalled connection cannot block the calling
            (worker) thread indefinitely.

    Returns:
        ``(cropped_image, score)`` tuples, in the order the service listed
        the detections.

    Raises:
        requests.HTTPError: If the service returns an error status on the
            final attempt.
        requests.RequestException: For connection errors or timeouts on the
            final attempt.
    """
    with BytesIO() as buffer:
        img.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()

    files = [("img", ("image.png", png_bytes, "image/png"))]
    payload = {"task": task}
    rsp = requests.post(url, data=payload, files=files, timeout=timeout)
    rsp.raise_for_status()

    return [(img.crop(det["box"]), det["score"]) for det in rsp.json()]


def extract_figures(
    url: str, pdf: str, task: str = "figure", max_workers: int = 4
) -> list[tuple[str, float]]:
    """Detect and crop figures from every page of a PDF via a remote service.

    Each image produced by ``ImagePDFLoader(pdf).load()`` is sent to ``url``
    through ``_extract_figures`` on a thread pool. The resulting crops are
    PNG-encoded and returned as base64 strings.

    Args:
        url: Endpoint of the layout-detection service.
        pdf: PDF path handed to ``ImagePDFLoader``.
        task: Detection task name forwarded to the service.
        max_workers: Maximum number of pages processed concurrently.

    Returns:
        ``(base64_png, score)`` tuples. They are ordered by the sequence in
        which the loader yielded the images (presumably page order). Within
        one image they follow the order the service returned. Images whose
        processing raised are skipped after an error message is printed.
    """
    loader = ImagePDFLoader(pdf)
    images = loader.load()

    figures: list[tuple[Image.Image, float]] = []

    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Keep futures in submission order (instead of as_completed) so the
        # result order is deterministic and follows the loader's page order.
        futures = [
            executor.submit(_extract_figures, url, image, task) for image in images
        ]

        for future in futures:
            try:
                figures.extend(future.result())
            except Exception as exc:
                # Best effort: one failing page (after its retries) is
                # reported and skipped instead of aborting the whole PDF.
                print(f'图像处理时发生错误: {exc}')

    base64_figures = []
    for figure, score in figures:
        with BytesIO() as buffer:
            figure.save(buffer, format="PNG")
            encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
        base64_figures.append((encoded, score))

    return base64_figures


if __name__ == "__main__":
    import sys

    # Both inputs can be overridden from the environment / command line.
    # The defaults are the original hard-coded values.
    url = os.environ.get(
        "LAYOUT_DETECTION_URL",
        "https://kr4t0n--yolo-layout-detection-temp-layoutdetection-predict.modal.run",
    )
    pdf = sys.argv[1] if len(sys.argv) > 1 else "1.pdf"

    output_dir = Path("output")
    output_dir.mkdir(exist_ok=True)

    base64_figures = extract_figures(url, pdf, task="figurecaption")

    print(f"提取到 {len(base64_figures)} 张图像")

    for i, (b64_str, _score) in enumerate(base64_figures, start=1):
        # extract_figures already PNG-encodes each crop, so the decoded bytes
        # form a complete PNG file. Write them directly instead of decoding
        # and re-encoding the image through PIL.
        output_path = output_dir / f"figure_{i}.png"
        output_path.write_bytes(base64.b64decode(b64_str))
        print(f"图像已保存到: {output_path}")

    print(f"所有图像已保存到 {output_dir} 目录")