MultiThread
Browse files
- main.py +24 -24
- poster/figures.py +16 -4
- poster/poster.py +58 -12
main.py
CHANGED
@@ -39,45 +39,45 @@ def generate_paper_poster(
|
|
39 |
figures_cap_cache = f"{pdf_stem}_figures_cap.json"
|
40 |
|
41 |
figures = []
|
42 |
-
figures_cap = []
|
43 |
print("开始提取图片...")
|
44 |
if os.path.exists(figures_cache) and os.path.exists(figures_cap_cache):
|
45 |
print(f"使用缓存的图片: {figures_cache}")
|
46 |
with open(figures_cache, "r") as f:
|
47 |
figures = json.load(f)
|
48 |
-
with open(figures_cap_cache, "r") as f:
|
49 |
-
|
50 |
else:
|
51 |
figures_img = extract_figures(url, pdf, task="figure")
|
52 |
figures_table = extract_figures(url, pdf, task="table")
|
53 |
-
img_caption = extract_figures(url, pdf, task="figurecaption")
|
54 |
-
table_caption = extract_figures(url, pdf, task="tablecaption")
|
55 |
-
threshold = 0.
|
56 |
-
while True:
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
figures_cap = [
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
]
|
67 |
-
print(f"{threshold:.2f} 提取到 {len(figures)} / {len(figures_cap)} 张图像")
|
68 |
-
if len(figures) == len(figures_cap):
|
69 |
-
|
70 |
-
threshold -= 0.05
|
71 |
|
72 |
with open(figures_cache, "w") as f:
|
73 |
json.dump(figures, f, ensure_ascii=False)
|
74 |
-
with open(figures_cap_cache, "w") as f:
|
75 |
-
|
76 |
|
77 |
while True:
|
78 |
try:
|
79 |
result = generate_poster_v3(
|
80 |
-
vendor, model, text_prompt, figures_prompt, pdf,
|
81 |
)
|
82 |
|
83 |
poster = result["image_based_poster"]
|
|
|
39 |
figures_cap_cache = f"{pdf_stem}_figures_cap.json"
|
40 |
|
41 |
figures = []
|
42 |
+
# figures_cap = []
|
43 |
print("开始提取图片...")
|
44 |
if os.path.exists(figures_cache) and os.path.exists(figures_cap_cache):
|
45 |
print(f"使用缓存的图片: {figures_cache}")
|
46 |
with open(figures_cache, "r") as f:
|
47 |
figures = json.load(f)
|
48 |
+
# with open(figures_cap_cache, "r") as f:
|
49 |
+
# figures_cap = json.load(f)
|
50 |
else:
|
51 |
figures_img = extract_figures(url, pdf, task="figure")
|
52 |
figures_table = extract_figures(url, pdf, task="table")
|
53 |
+
# img_caption = extract_figures(url, pdf, task="figurecaption")
|
54 |
+
# table_caption = extract_figures(url, pdf, task="tablecaption")
|
55 |
+
threshold = 0.75
|
56 |
+
# while True:
|
57 |
+
figures = [
|
58 |
+
image
|
59 |
+
for image, score in figures_img + figures_table
|
60 |
+
if score >= threshold
|
61 |
+
]
|
62 |
+
# figures_cap = [
|
63 |
+
# image
|
64 |
+
# for image, score in img_caption + table_caption
|
65 |
+
# if score >= threshold
|
66 |
+
# ]
|
67 |
+
# print(f"{threshold:.2f} 提取到 {len(figures)} / {len(figures_cap)} 张图像")
|
68 |
+
# if len(figures) == len(figures_cap):
|
69 |
+
# break
|
70 |
+
# threshold -= 0.05
|
71 |
|
72 |
with open(figures_cache, "w") as f:
|
73 |
json.dump(figures, f, ensure_ascii=False)
|
74 |
+
# with open(figures_cap_cache, "w") as f:
|
75 |
+
# json.dump(figures_cap, f, ensure_ascii=False)
|
76 |
|
77 |
while True:
|
78 |
try:
|
79 |
result = generate_poster_v3(
|
80 |
+
vendor, model, text_prompt, figures_prompt, pdf, figures, figures
|
81 |
)
|
82 |
|
83 |
poster = result["image_based_poster"]
|
poster/figures.py
CHANGED
@@ -2,6 +2,7 @@ import base64
|
|
2 |
import requests
|
3 |
import os
|
4 |
from pathlib import Path
|
|
|
5 |
|
6 |
from io import BytesIO
|
7 |
from PIL import Image
|
@@ -31,14 +32,25 @@ def _extract_figures(
|
|
31 |
|
32 |
|
33 |
def extract_figures(
|
34 |
-
url: str, pdf: str, task: str = "figure"
|
35 |
) -> list[tuple[str, float]]:
|
36 |
loader = ImagePDFLoader(pdf)
|
37 |
images = loader.load()
|
38 |
|
39 |
figures = []
|
40 |
-
|
41 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
42 |
|
43 |
base64_figures = []
|
44 |
for figure, score in figures:
|
@@ -52,7 +64,7 @@ def extract_figures(
|
|
52 |
|
53 |
|
54 |
if __name__ == "__main__":
|
55 |
-
url = ""
|
56 |
pdf = "1.pdf"
|
57 |
|
58 |
output_dir = Path("output")
|
|
|
2 |
import requests
|
3 |
import os
|
4 |
from pathlib import Path
|
5 |
+
from concurrent.futures import ThreadPoolExecutor, as_completed
|
6 |
|
7 |
from io import BytesIO
|
8 |
from PIL import Image
|
|
|
32 |
|
33 |
|
34 |
def extract_figures(
|
35 |
+
url: str, pdf: str, task: str = "figure", max_workers: int = 4
|
36 |
) -> list[tuple[str, float]]:
|
37 |
loader = ImagePDFLoader(pdf)
|
38 |
images = loader.load()
|
39 |
|
40 |
figures = []
|
41 |
+
|
42 |
+
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
43 |
+
future_to_image = {
|
44 |
+
executor.submit(_extract_figures, url, image, task): image
|
45 |
+
for image in images
|
46 |
+
}
|
47 |
+
|
48 |
+
for future in as_completed(future_to_image):
|
49 |
+
try:
|
50 |
+
result = future.result()
|
51 |
+
figures.extend(result)
|
52 |
+
except Exception as exc:
|
53 |
+
print(f'图像处理时发生错误: {exc}')
|
54 |
|
55 |
base64_figures = []
|
56 |
for figure, score in figures:
|
|
|
64 |
|
65 |
|
66 |
if __name__ == "__main__":
|
67 |
+
url = "https://kr4t0n--yolo-layout-detection-temp-layoutdetection-predict.modal.run"
|
68 |
pdf = "1.pdf"
|
69 |
|
70 |
output_dir = Path("output")
|
poster/poster.py
CHANGED
@@ -6,6 +6,7 @@ import re
|
|
6 |
import subprocess
|
7 |
import time
|
8 |
import cairosvg
|
|
|
9 |
|
10 |
from PIL import Image
|
11 |
from pdf2image import convert_from_path
|
@@ -388,8 +389,12 @@ def generate_html_v2(vendor: str, model: str, poster: BaseModel, figures: list[s
|
|
388 |
/ poster_total_size
|
389 |
)
|
390 |
|
391 |
-
max_attempts =
|
392 |
-
attempt =
|
|
|
|
|
|
|
|
|
393 |
|
394 |
while True:
|
395 |
body = re.search(r"```html\n(.*?)\n```", output, re.DOTALL).group(1)
|
@@ -401,6 +406,13 @@ def generate_html_v2(vendor: str, model: str, poster: BaseModel, figures: list[s
|
|
401 |
section_sizes = get_sizes("section", html_with_figures)
|
402 |
|
403 |
proportion = calculate_blank_proportion(poster_sizes, section_sizes)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
404 |
if proportion <= 0.1:
|
405 |
print(
|
406 |
f"Attempted {attempt} times, remaining {proportion:.0%} blank spaces."
|
@@ -409,10 +421,16 @@ def generate_html_v2(vendor: str, model: str, poster: BaseModel, figures: list[s
|
|
409 |
|
410 |
attempt += 1
|
411 |
if attempt > max_attempts:
|
412 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
413 |
|
414 |
react = [
|
415 |
-
# AIMessage(""),
|
416 |
HumanMessage(
|
417 |
content=f"""# Previous Body
|
418 |
{body}
|
@@ -514,10 +532,16 @@ def generate_poster_v3(
|
|
514 |
model=model,
|
515 |
temperature=1,
|
516 |
max_tokens=8000,
|
517 |
-
# model_kwargs={
|
518 |
-
# "extra_body": {"chat_template_kwargs": {"enable_thinking": False}}
|
519 |
-
# },
|
520 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
521 |
loader = PyMuPDFLoader(pdf)
|
522 |
pages = loader.load()
|
523 |
paper_content = "\n".join([page.page_content for page in pages])
|
@@ -629,16 +653,38 @@ Paper content:
|
|
629 |
figures_with_descriptions = f.read()
|
630 |
else:
|
631 |
figure_chain = figures_description_prompt | (mllm if use_claude else llm)
|
632 |
-
|
|
|
|
|
633 |
figure_description_response = figure_chain.invoke({"image_data": figure})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
634 |
figures_with_descriptions += f"""
|
635 |
<figure_{i}>
|
636 |
-
{
|
637 |
</figure_{i}>
|
638 |
"""
|
639 |
-
figure_list.append(
|
640 |
-
|
641 |
-
|
|
|
|
|
642 |
if use_claude:
|
643 |
with open(figures_description_cache, "w") as f:
|
644 |
f.write(figures_with_descriptions)
|
|
|
6 |
import subprocess
|
7 |
import time
|
8 |
import cairosvg
|
9 |
+
from concurrent.futures import ThreadPoolExecutor
|
10 |
|
11 |
from PIL import Image
|
12 |
from pdf2image import convert_from_path
|
|
|
389 |
/ poster_total_size
|
390 |
)
|
391 |
|
392 |
+
max_attempts = 6
|
393 |
+
attempt = 0
|
394 |
+
|
395 |
+
min_proportion = float('inf')
|
396 |
+
min_html = None
|
397 |
+
min_html_with_figures = None
|
398 |
|
399 |
while True:
|
400 |
body = re.search(r"```html\n(.*?)\n```", output, re.DOTALL).group(1)
|
|
|
406 |
section_sizes = get_sizes("section", html_with_figures)
|
407 |
|
408 |
proportion = calculate_blank_proportion(poster_sizes, section_sizes)
|
409 |
+
|
410 |
+
print(f"当前比例: {proportion:.0%}")
|
411 |
+
if proportion < min_proportion:
|
412 |
+
min_proportion = proportion
|
413 |
+
min_html = html
|
414 |
+
min_html_with_figures = html_with_figures
|
415 |
+
|
416 |
if proportion <= 0.1:
|
417 |
print(
|
418 |
f"Attempted {attempt} times, remaining {proportion:.0%} blank spaces."
|
|
|
421 |
|
422 |
attempt += 1
|
423 |
if attempt > max_attempts:
|
424 |
+
if min_proportion <= 0.2:
|
425 |
+
print(
|
426 |
+
f"Reached max attempts ({max_attempts}), returning best result with {min_proportion:.0%} blank spaces."
|
427 |
+
)
|
428 |
+
return {"html": min_html, "html_with_figures": min_html_with_figures}
|
429 |
+
else:
|
430 |
+
raise ValueError(f"Invalid blank spaces: {min_proportion:.0%}")
|
431 |
+
|
432 |
|
433 |
react = [
|
|
|
434 |
HumanMessage(
|
435 |
content=f"""# Previous Body
|
436 |
{body}
|
|
|
532 |
model=model,
|
533 |
temperature=1,
|
534 |
max_tokens=8000,
|
|
|
|
|
|
|
535 |
)
|
536 |
+
elif vendor == "azure":
|
537 |
+
llm = AzureChatOpenAI(
|
538 |
+
azure_deployment=model,
|
539 |
+
temperature=1,
|
540 |
+
max_tokens=8000,
|
541 |
+
)
|
542 |
+
else:
|
543 |
+
raise ValueError(f"Unsupported vendor: {vendor}")
|
544 |
+
|
545 |
loader = PyMuPDFLoader(pdf)
|
546 |
pages = loader.load()
|
547 |
paper_content = "\n".join([page.page_content for page in pages])
|
|
|
653 |
figures_with_descriptions = f.read()
|
654 |
else:
|
655 |
figure_chain = figures_description_prompt | (mllm if use_claude else llm)
|
656 |
+
|
657 |
+
def process_single_figure(figure_data):
|
658 |
+
figure, index = figure_data
|
659 |
figure_description_response = figure_chain.invoke({"image_data": figure})
|
660 |
+
return {
|
661 |
+
"index": index,
|
662 |
+
"figure": figure,
|
663 |
+
"description": figure_description_response.content
|
664 |
+
}
|
665 |
+
|
666 |
+
figure_data_list = [(figure, i) for i, figure in enumerate(figures)]
|
667 |
+
|
668 |
+
with ThreadPoolExecutor(max_workers=4) as executor:
|
669 |
+
results = list(tqdm(
|
670 |
+
executor.map(process_single_figure, figure_data_list),
|
671 |
+
total=len(figure_data_list),
|
672 |
+
desc=f"处理图片 {pdf}"
|
673 |
+
))
|
674 |
+
|
675 |
+
for result in results:
|
676 |
+
i = result["index"]
|
677 |
+
print(f"处理图片 {i} 完成")
|
678 |
figures_with_descriptions += f"""
|
679 |
<figure_{i}>
|
680 |
+
{result["description"]}
|
681 |
</figure_{i}>
|
682 |
"""
|
683 |
+
figure_list.append({
|
684 |
+
"figure": result["figure"],
|
685 |
+
"description": result["description"]
|
686 |
+
})
|
687 |
+
|
688 |
if use_claude:
|
689 |
with open(figures_description_cache, "w") as f:
|
690 |
f.write(figures_with_descriptions)
|