File size: 4,237 Bytes
04aed77 bdc569e 04aed77 d4bdb14 04aed77 d4bdb14 04aed77 d4bdb14 04aed77 d4bdb14 04aed77 1ee1a5c 04aed77 d4bdb14 04aed77 bdc569e 04aed77 d32760a 04aed77 1ee1a5c 04aed77 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 |
import base64
import copy
from datetime import datetime
import json
import fire
import os
import pathlib
from poster.figures import extract_figures
from poster.poster import (
generate_html_v2,
generate_poster_v3,
replace_figures_in_poster,
replace_figures_size_in_poster,
)
# Substrings of exception messages that retrying cannot fix (content-policy
# rejections, inputs that exceed model/API size limits, malformed input);
# errors containing any of them are re-raised immediately.
_NON_RETRYABLE_ERROR_MARKERS = (
    "content management policy",
    "message larger than max",
    "exceeds the maximum length",
    "maximum context length",
    "Input is too long",
    "image exceeds 5 MB",
    "too many total text bytes",
    "Range of input length",
    "Invalid text",
)

# Minimum detection score for an extracted figure/table to be kept.
_FIGURE_SCORE_THRESHOLD = 0.75


def _is_non_retryable(exc: Exception) -> bool:
    """Return True if *exc*'s message matches a known non-retryable error."""
    message = str(exc)
    return any(marker in message for marker in _NON_RETRYABLE_ERROR_MARKERS)


def _load_or_extract_figures(url: str, pdf: str, cache_path: str) -> list:
    """Return the paper's figures, reading *cache_path* if present, else extracting and caching them.

    Extraction runs ``extract_figures`` for both figures and tables, keeps the
    entries whose score is at least ``_FIGURE_SCORE_THRESHOLD`` (figures
    first, then tables) and writes the kept entries to *cache_path* as JSON.
    Note the cache is not keyed on the threshold.
    """
    if os.path.exists(cache_path):
        print(f"使用缓存的图片: {cache_path}")
        with open(cache_path, "r", encoding="utf-8") as f:
            return json.load(f)

    # extract_figures yields (image, score) pairs, per the unpacking below.
    figure_hits = extract_figures(url, pdf, task="figure")
    table_hits = extract_figures(url, pdf, task="table")
    figures = [
        image
        for image, score in figure_hits + table_hits
        if score >= _FIGURE_SCORE_THRESHOLD
    ]
    # Explicit UTF-8: ensure_ascii=False may emit characters the platform's
    # default encoding cannot represent.
    with open(cache_path, "w", encoding="utf-8") as f:
        json.dump(figures, f, ensure_ascii=False)
    return figures


def generate_paper_poster(
    url: str,
    pdf: str,
    vendor: str = "openai",
    model: str = "gpt-4o-mini",
    text_prompt: str = "",
    figures_prompt: str = "",
    output: str = "poster.json",
    max_attempts: int = 3,
):
    """Generate a poster (structured poster plus HTML) for a paper PDF.

    Figures are extracted from the PDF once and cached next to it in
    ``<pdf stem>_figures.json``; later runs reuse that cache.

    Args:
        url: URL of the PDF file.
        pdf: Local path of the PDF file.
        vendor: Model vendor passed to the generation helpers, default "openai".
        model: Name of the model to use, default is gpt-4o-mini.
        text_prompt: Text prompt template.
        figures_prompt: Figures prompt template.
        output: Output file path, default is poster.json. Currently unused:
            writing results to disk is disabled; kept for CLI compatibility.
        max_attempts: Total number of poster-generation attempts for
            retryable errors (values below 1 are treated as 1).

    Returns:
        ``(poster, html)`` on success, or ``(None, None)`` once every attempt
        failed with a retryable error.

    Raises:
        Exception: re-raised unchanged when its message matches a known
            non-retryable error (policy rejection, size limits, invalid
            input). Errors raised during figure extraction are not retried.
    """
    # Strip only a trailing ".pdf"; str.replace would also remove ".pdf"
    # appearing elsewhere in the path (e.g. a directory named "x.pdf.d").
    pdf_stem = pdf[: -len(".pdf")] if pdf.endswith(".pdf") else pdf
    figures_cache = f"{pdf_stem}_figures.json"

    print("开始提取图片...")
    figures = _load_or_extract_figures(url, pdf, figures_cache)

    print("开始生成海报...")
    for _attempt in range(max(max_attempts, 1)):
        try:
            # The second `figures` argument presumably fills the slot formerly
            # used for figure captions (caption extraction is disabled) --
            # TODO confirm against generate_poster_v3's signature.
            result = generate_poster_v3(
                vendor, model, text_prompt, figures_prompt, pdf, figures, figures
            )
            poster = result["image_based_poster"]
            # Keep an untouched deep copy so the size-annotated variant is
            # built from the original poster, not from the output of
            # replace_figures_in_poster (which may mutate its input -- TODO
            # confirm).
            backup_poster = copy.deepcopy(poster)
            poster = replace_figures_in_poster(poster, figures)
            poster_size = replace_figures_size_in_poster(backup_poster, figures)
            print(f"[{datetime.now().strftime('%Y-%m-%d %H:%M:%S')}] Now generating HTML...")
            result = generate_html_v2(vendor, model, poster_size, figures)
            html = result["html_with_figures"]
            print("海报生成成功!")
            return poster, html
        except Exception as e:
            if _is_non_retryable(e):
                raise
            print(f"处理文件 {pdf} 时出错: {e}")
    return None, None
if __name__ == "__main__":
    # Run as a script: expose generate_paper_poster's parameters as
    # command-line flags via Python Fire.
    fire.Fire(component=generate_paper_poster)
|