# P2P / poster / figures.py
# Author: ASC8384 — last commit d4bdb14 ("MultiThread")
import base64
import requests
import os
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed
from io import BytesIO
from PIL import Image
from retry import retry
from .loader import ImagePDFLoader
@retry(tries=3, delay=1, backoff=2)
def _extract_figures(
    url: str, img: Image.Image, task: str = "figure", timeout: float = 120.0
) -> list[tuple[Image.Image, float]]:
    """Send one page image to the layout-detection service and crop its hits.

    The page is PNG-encoded and POSTed as the multipart file field ``img``,
    together with a ``task`` form field. The JSON response is iterated as a
    list of objects carrying ``box`` and ``score`` keys. Each ``box`` is passed
    straight to ``Image.crop``, so it is presumably ``(left, upper, right,
    lower)`` in pixel coordinates of *img* (confirm against the service).

    On any exception (HTTP error status, timeout, connection error, malformed
    JSON) the whole call is retried. There are at most 3 attempts, pausing
    1 s and then 2 s between them.

    Args:
        url: Endpoint of the layout-detection service.
        img: Rendered page image to analyse.
        task: Detection task name forwarded to the service.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes. Without it a stalled connection blocks the caller forever.
            The default is generous because a cold-starting service may be
            slow to answer.

    Returns:
        ``(cropped_image, score)`` pairs in the order the service returned them.

    Raises:
        The last exception (e.g. ``requests.HTTPError``) once all attempts fail.
    """
    with BytesIO() as buffer:
        img.save(buffer, format="PNG")
        png_bytes = buffer.getvalue()
    rsp = requests.post(
        url,
        data={"task": task},
        files=[("img", ("image.png", png_bytes, "image/png"))],
        timeout=timeout,
    )
    rsp.raise_for_status()
    return [(img.crop(item["box"]), item["score"]) for item in rsp.json()]
def extract_figures(
    url: str, pdf: str, task: str = "figure", max_workers: int = 4
) -> list[tuple[str, float]]:
    """Detect figures on every page of a PDF and return them as base64 PNGs.

    Pages are produced by ``ImagePDFLoader(pdf).load()``. Each page is sent to
    the detection service concurrently through ``_extract_figures``.

    Args:
        url: Endpoint of the layout-detection service.
        pdf: Path to the PDF file.
        task: Detection task name forwarded to the service.
        max_workers: Number of worker threads querying the service.

    Returns:
        ``(base64_png, score)`` pairs. They are ordered by page, and within a
        page in the order the service returned them. Pages whose request
        ultimately failed are reported on stdout and skipped (best effort).
    """
    images = ImagePDFLoader(pdf).load()
    figures: list[tuple[Image.Image, float]] = []
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Keep futures in page order. Collecting them with as_completed()
        # would interleave pages in whatever order the requests finish,
        # making the output order (and downstream file names) nondeterministic.
        futures = [
            executor.submit(_extract_figures, url, image, task) for image in images
        ]
        for future in futures:
            try:
                figures.extend(future.result())
            except Exception as exc:
                # Deliberate best effort: report the failing page, keep the rest.
                print(f'图像处理时发生错误: {exc}')
    return [(_to_base64_png(figure), score) for figure, score in figures]


def _to_base64_png(img: Image.Image) -> str:
    """Encode *img* as PNG and return the base64 text as a UTF-8 string."""
    with BytesIO() as buffer:
        img.save(buffer, format="PNG")
        return base64.b64encode(buffer.getvalue()).decode("utf-8")
if __name__ == "__main__":
    import argparse

    # Every default reproduces the original hard-coded demo run, so running
    # the script without arguments behaves as before.
    parser = argparse.ArgumentParser(
        description="Extract figures from a PDF via a layout-detection service."
    )
    parser.add_argument(
        "--url",
        default="https://kr4t0n--yolo-layout-detection-temp-layoutdetection-predict.modal.run",
        help="layout-detection service endpoint",
    )
    parser.add_argument("--pdf", default="1.pdf", help="input PDF path")
    parser.add_argument(
        "--task", default="figurecaption", help="detection task sent to the service"
    )
    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("output"),
        help="directory receiving figure_<n>.png files",
    )
    args = parser.parse_args()

    output_dir = args.output_dir
    output_dir.mkdir(parents=True, exist_ok=True)
    base64_figures = extract_figures(args.url, args.pdf, task=args.task)
    print(f"提取到 {len(base64_figures)} 张图像")
    for index, (b64_str, _score) in enumerate(base64_figures, start=1):
        output_path = output_dir / f"figure_{index}.png"
        # extract_figures already returns PNG-encoded data. Write the decoded
        # bytes directly instead of re-opening and re-encoding them via PIL.
        output_path.write_bytes(base64.b64decode(b64_str))
        print(f"图像已保存到: {output_path}")
    print(f"所有图像已保存到 {output_dir} 目录")