ASC8384 commited on
Commit
d4bdb14
·
1 Parent(s): 42bb4dc

MultiThread

Browse files
Files changed (3) hide show
  1. main.py +24 -24
  2. poster/figures.py +16 -4
  3. poster/poster.py +58 -12
main.py CHANGED
@@ -39,45 +39,45 @@ def generate_paper_poster(
39
  figures_cap_cache = f"{pdf_stem}_figures_cap.json"
40
 
41
  figures = []
42
- figures_cap = []
43
  print("开始提取图片...")
44
  if os.path.exists(figures_cache) and os.path.exists(figures_cap_cache):
45
  print(f"使用缓存的图片: {figures_cache}")
46
  with open(figures_cache, "r") as f:
47
  figures = json.load(f)
48
- with open(figures_cap_cache, "r") as f:
49
- figures_cap = json.load(f)
50
  else:
51
  figures_img = extract_figures(url, pdf, task="figure")
52
  figures_table = extract_figures(url, pdf, task="table")
53
- img_caption = extract_figures(url, pdf, task="figurecaption")
54
- table_caption = extract_figures(url, pdf, task="tablecaption")
55
- threshold = 0.85
56
- while True:
57
- figures = [
58
- image
59
- for image, score in figures_img + figures_table
60
- if score >= threshold
61
- ]
62
- figures_cap = [
63
- image
64
- for image, score in img_caption + table_caption
65
- if score >= threshold
66
- ]
67
- print(f"{threshold:.2f} 提取到 {len(figures)} / {len(figures_cap)} 张图像")
68
- if len(figures) == len(figures_cap):
69
- break
70
- threshold -= 0.05
71
 
72
  with open(figures_cache, "w") as f:
73
  json.dump(figures, f, ensure_ascii=False)
74
- with open(figures_cap_cache, "w") as f:
75
- json.dump(figures_cap, f, ensure_ascii=False)
76
 
77
  while True:
78
  try:
79
  result = generate_poster_v3(
80
- vendor, model, text_prompt, figures_prompt, pdf, figures_cap, figures
81
  )
82
 
83
  poster = result["image_based_poster"]
 
39
  figures_cap_cache = f"{pdf_stem}_figures_cap.json"
40
 
41
  figures = []
42
+ # figures_cap = []
43
  print("开始提取图片...")
44
  if os.path.exists(figures_cache) and os.path.exists(figures_cap_cache):
45
  print(f"使用缓存的图片: {figures_cache}")
46
  with open(figures_cache, "r") as f:
47
  figures = json.load(f)
48
+ # with open(figures_cap_cache, "r") as f:
49
+ # figures_cap = json.load(f)
50
  else:
51
  figures_img = extract_figures(url, pdf, task="figure")
52
  figures_table = extract_figures(url, pdf, task="table")
53
+ # img_caption = extract_figures(url, pdf, task="figurecaption")
54
+ # table_caption = extract_figures(url, pdf, task="tablecaption")
55
+ threshold = 0.75
56
+ # while True:
57
+ figures = [
58
+ image
59
+ for image, score in figures_img + figures_table
60
+ if score >= threshold
61
+ ]
62
+ # figures_cap = [
63
+ # image
64
+ # for image, score in img_caption + table_caption
65
+ # if score >= threshold
66
+ # ]
67
+ # print(f"{threshold:.2f} 提取到 {len(figures)} / {len(figures_cap)} 张图像")
68
+ # if len(figures) == len(figures_cap):
69
+ # break
70
+ # threshold -= 0.05
71
 
72
  with open(figures_cache, "w") as f:
73
  json.dump(figures, f, ensure_ascii=False)
74
+ # with open(figures_cap_cache, "w") as f:
75
+ # json.dump(figures_cap, f, ensure_ascii=False)
76
 
77
  while True:
78
  try:
79
  result = generate_poster_v3(
80
+ vendor, model, text_prompt, figures_prompt, pdf, figures, figures
81
  )
82
 
83
  poster = result["image_based_poster"]
poster/figures.py CHANGED
@@ -2,6 +2,7 @@ import base64
2
  import requests
3
  import os
4
  from pathlib import Path
 
5
 
6
  from io import BytesIO
7
  from PIL import Image
@@ -31,14 +32,25 @@ def _extract_figures(
31
 
32
 
33
  def extract_figures(
34
- url: str, pdf: str, task: str = "figure"
35
  ) -> list[tuple[str, float]]:
36
  loader = ImagePDFLoader(pdf)
37
  images = loader.load()
38
 
39
  figures = []
40
- for image in images:
41
- figures.extend(_extract_figures(url, image, task))
 
 
 
 
 
 
 
 
 
 
 
42
 
43
  base64_figures = []
44
  for figure, score in figures:
@@ -52,7 +64,7 @@ def extract_figures(
52
 
53
 
54
  if __name__ == "__main__":
55
- url = ""
56
  pdf = "1.pdf"
57
 
58
  output_dir = Path("output")
 
2
  import requests
3
  import os
4
  from pathlib import Path
5
+ from concurrent.futures import ThreadPoolExecutor, as_completed
6
 
7
  from io import BytesIO
8
  from PIL import Image
 
32
 
33
 
34
  def extract_figures(
35
+ url: str, pdf: str, task: str = "figure", max_workers: int = 4
36
  ) -> list[tuple[str, float]]:
37
  loader = ImagePDFLoader(pdf)
38
  images = loader.load()
39
 
40
  figures = []
41
+
42
+ with ThreadPoolExecutor(max_workers=max_workers) as executor:
43
+ future_to_image = {
44
+ executor.submit(_extract_figures, url, image, task): image
45
+ for image in images
46
+ }
47
+
48
+ for future in as_completed(future_to_image):
49
+ try:
50
+ result = future.result()
51
+ figures.extend(result)
52
+ except Exception as exc:
53
+ print(f'图像处理时发生错误: {exc}')
54
 
55
  base64_figures = []
56
  for figure, score in figures:
 
64
 
65
 
66
  if __name__ == "__main__":
67
+ url = "https://kr4t0n--yolo-layout-detection-temp-layoutdetection-predict.modal.run"
68
  pdf = "1.pdf"
69
 
70
  output_dir = Path("output")
poster/poster.py CHANGED
@@ -6,6 +6,7 @@ import re
6
  import subprocess
7
  import time
8
  import cairosvg
 
9
 
10
  from PIL import Image
11
  from pdf2image import convert_from_path
@@ -388,8 +389,12 @@ def generate_html_v2(vendor: str, model: str, poster: BaseModel, figures: list[s
388
  / poster_total_size
389
  )
390
 
391
- max_attempts = 5
392
- attempt = 1
 
 
 
 
393
 
394
  while True:
395
  body = re.search(r"```html\n(.*?)\n```", output, re.DOTALL).group(1)
@@ -401,6 +406,13 @@ def generate_html_v2(vendor: str, model: str, poster: BaseModel, figures: list[s
401
  section_sizes = get_sizes("section", html_with_figures)
402
 
403
  proportion = calculate_blank_proportion(poster_sizes, section_sizes)
 
 
 
 
 
 
 
404
  if proportion <= 0.1:
405
  print(
406
  f"Attempted {attempt} times, remaining {proportion:.0%} blank spaces."
@@ -409,10 +421,16 @@ def generate_html_v2(vendor: str, model: str, poster: BaseModel, figures: list[s
409
 
410
  attempt += 1
411
  if attempt > max_attempts:
412
- raise ValueError(f"Invalid blank spaces: {proportion:.0%}")
 
 
 
 
 
 
 
413
 
414
  react = [
415
- # AIMessage(""),
416
  HumanMessage(
417
  content=f"""# Previous Body
418
  {body}
@@ -514,10 +532,16 @@ def generate_poster_v3(
514
  model=model,
515
  temperature=1,
516
  max_tokens=8000,
517
- # model_kwargs={
518
- # "extra_body": {"chat_template_kwargs": {"enable_thinking": False}}
519
- # },
520
  )
 
 
 
 
 
 
 
 
 
521
  loader = PyMuPDFLoader(pdf)
522
  pages = loader.load()
523
  paper_content = "\n".join([page.page_content for page in pages])
@@ -629,16 +653,38 @@ Paper content:
629
  figures_with_descriptions = f.read()
630
  else:
631
  figure_chain = figures_description_prompt | (mllm if use_claude else llm)
632
- for i, figure in enumerate(tqdm(figures, desc=f"处理图片 {pdf}")):
 
 
633
  figure_description_response = figure_chain.invoke({"image_data": figure})
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634
  figures_with_descriptions += f"""
635
  <figure_{i}>
636
- {figure_description_response.content}
637
  </figure_{i}>
638
  """
639
- figure_list.append(
640
- {"figure": figure, "description": figure_description_response.content}
641
- )
 
 
642
  if use_claude:
643
  with open(figures_description_cache, "w") as f:
644
  f.write(figures_with_descriptions)
 
6
  import subprocess
7
  import time
8
  import cairosvg
9
+ from concurrent.futures import ThreadPoolExecutor
10
 
11
  from PIL import Image
12
  from pdf2image import convert_from_path
 
389
  / poster_total_size
390
  )
391
 
392
+ max_attempts = 6
393
+ attempt = 0
394
+
395
+ min_proportion = float('inf')
396
+ min_html = None
397
+ min_html_with_figures = None
398
 
399
  while True:
400
  body = re.search(r"```html\n(.*?)\n```", output, re.DOTALL).group(1)
 
406
  section_sizes = get_sizes("section", html_with_figures)
407
 
408
  proportion = calculate_blank_proportion(poster_sizes, section_sizes)
409
+
410
+ print(f"当前比例: {proportion:.0%}")
411
+ if proportion < min_proportion:
412
+ min_proportion = proportion
413
+ min_html = html
414
+ min_html_with_figures = html_with_figures
415
+
416
  if proportion <= 0.1:
417
  print(
418
  f"Attempted {attempt} times, remaining {proportion:.0%} blank spaces."
 
421
 
422
  attempt += 1
423
  if attempt > max_attempts:
424
+ if min_proportion <= 0.2:
425
+ print(
426
+ f"Reached max attempts ({max_attempts}), returning best result with {min_proportion:.0%} blank spaces."
427
+ )
428
+ return {"html": min_html, "html_with_figures": min_html_with_figures}
429
+ else:
430
+ raise ValueError(f"Invalid blank spaces: {min_proportion:.0%}")
431
+
432
 
433
  react = [
 
434
  HumanMessage(
435
  content=f"""# Previous Body
436
  {body}
 
532
  model=model,
533
  temperature=1,
534
  max_tokens=8000,
 
 
 
535
  )
536
+ elif vendor == "azure":
537
+ llm = AzureChatOpenAI(
538
+ azure_deployment=model,
539
+ temperature=1,
540
+ max_tokens=8000,
541
+ )
542
+ else:
543
+ raise ValueError(f"Unsupported vendor: {vendor}")
544
+
545
  loader = PyMuPDFLoader(pdf)
546
  pages = loader.load()
547
  paper_content = "\n".join([page.page_content for page in pages])
 
653
  figures_with_descriptions = f.read()
654
  else:
655
  figure_chain = figures_description_prompt | (mllm if use_claude else llm)
656
+
657
+ def process_single_figure(figure_data):
658
+ figure, index = figure_data
659
  figure_description_response = figure_chain.invoke({"image_data": figure})
660
+ return {
661
+ "index": index,
662
+ "figure": figure,
663
+ "description": figure_description_response.content
664
+ }
665
+
666
+ figure_data_list = [(figure, i) for i, figure in enumerate(figures)]
667
+
668
+ with ThreadPoolExecutor(max_workers=4) as executor:
669
+ results = list(tqdm(
670
+ executor.map(process_single_figure, figure_data_list),
671
+ total=len(figure_data_list),
672
+ desc=f"处理图片 {pdf}"
673
+ ))
674
+
675
+ for result in results:
676
+ i = result["index"]
677
+ print(f"处理图片 {i} 完成")
678
  figures_with_descriptions += f"""
679
  <figure_{i}>
680
+ {result["description"]}
681
  </figure_{i}>
682
  """
683
+ figure_list.append({
684
+ "figure": result["figure"],
685
+ "description": result["description"]
686
+ })
687
+
688
  if use_claude:
689
  with open(figures_description_cache, "w") as f:
690
  f.write(figures_with_descriptions)