|
import base64 |
|
import requests |
|
import os |
|
from pathlib import Path |
|
from concurrent.futures import ThreadPoolExecutor, as_completed |
|
|
|
from io import BytesIO |
|
from PIL import Image |
|
from retry import retry |
|
|
|
from .loader import ImagePDFLoader |
|
|
|
|
|
@retry(tries=3, delay=1, backoff=2)
def _extract_figures(
    url: str, img: Image.Image, task: str = "figure", timeout: float = 120.0
) -> list[tuple[Image.Image, float]]:
    """Run remote layout detection on one image and crop out each detection.

    The image is PNG-encoded and POSTed as the multipart file field ``img``
    together with the form field ``task``. The response is expected to be a
    JSON list of objects carrying ``"box"`` and ``"score"`` keys.

    Args:
        url: Endpoint of the layout-detection service.
        img: Image to analyse (one rendered PDF page when called from
            ``extract_figures``).
        task: Detection task name forwarded to the service.
        timeout: Seconds passed to ``requests`` as the request timeout.
            Without one, a stalled connection would block forever and the
            retry decorator would never get a chance to run.

    Returns:
        ``(cropped_image, score)`` pairs in the order the service returned them.

    Raises:
        requests.HTTPError: If the service answers with an error status on
            every attempt (the retry decorator re-raises the last error).
        requests.Timeout: If every attempt exceeds ``timeout``.
    """
    with BytesIO() as buffer:
        img.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()

    files = [("img", ("image.png", png_bytes, "image/png"))]
    rsp = requests.post(url, data={"task": task}, files=files, timeout=timeout)
    rsp.raise_for_status()

    # NOTE(review): assumes "box" is (left, upper, right, lower) in pixel
    # coordinates of `img`, as PIL's crop expects — confirm with the service.
    return [(img.crop(det["box"]), det["score"]) for det in rsp.json()]
|
|
|
|
|
def extract_figures(
    url: str, pdf: str, task: str = "figure", max_workers: int = 4
) -> list[tuple[str, float]]:
    """Detect figures in a PDF and return them as base64-encoded PNG strings.

    Every image produced by ``ImagePDFLoader(pdf).load()`` is sent to the
    detection service concurrently. Results are still collected in loader
    order (presumably page order — confirm against ``ImagePDFLoader``), so the
    output is deterministic regardless of which request finishes first.

    Args:
        url: Endpoint of the layout-detection service.
        pdf: Path of the PDF to process.
        task: Detection task name forwarded to the service.
        max_workers: Maximum number of concurrent requests.

    Returns:
        ``(base64_png, score)`` pairs, ordered by source image and then by
        the order the service returned detections for that image.
        Images whose request fails (after the per-request retries) are
        reported on stdout and skipped.
    """
    images = ImagePDFLoader(pdf).load()

    figures: list[tuple[Image.Image, float]] = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Keep futures in submission order: iterating with as_completed would
        # make the output order depend on network timing.
        futures = [
            executor.submit(_extract_figures, url, image, task) for image in images
        ]
        for future in futures:
            try:
                figures.extend(future.result())
            except Exception as exc:
                # Best effort: one failing page must not discard the others.
                print(f'图像处理时发生错误: {exc}')

    return [(_encode_png_base64(figure), score) for figure, score in figures]


def _encode_png_base64(image: "Image.Image") -> str:
    """Serialize *image* as PNG and return it base64-encoded as a str."""
    with BytesIO() as buffer:
        image.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
|
|
|
|
|
if __name__ == "__main__":
    # Demo: extract figure+caption regions from a local PDF and save each one
    # as a numbered PNG file under ./output.
    url = "https://kr4t0n--yolo-layout-detection-temp-layoutdetection-predict.modal.run"
    pdf = "1.pdf"

    output_dir = Path("output")
    output_dir.mkdir(exist_ok=True)

    base64_figures = extract_figures(url, pdf, task="figurecaption")
    print(f"提取到 {len(base64_figures)} 张图像")

    for i, (b64_str, _score) in enumerate(base64_figures, start=1):
        output_path = output_dir / f"figure_{i}.png"
        # extract_figures already PNG-encodes each figure, so write the decoded
        # bytes directly instead of decoding and re-encoding through PIL.
        output_path.write_bytes(base64.b64decode(b64_str))
        print(f"图像已保存到: {output_path}")

    print(f"所有图像已保存到 {output_dir} 目录")
|
|