Venkat V committed on
Commit
152df72
·
1 Parent(s): 6ea5d07

updated with fixes to all modules

api_backend.py CHANGED
@@ -1,3 +1,14 @@
 
 
1
  from fastapi import FastAPI, UploadFile, File, Form
2
  from fastapi.middleware.cors import CORSMiddleware
3
  from fastapi.responses import JSONResponse
@@ -7,25 +18,41 @@ import io
7
  import json
8
  import base64
9
 
10
- # πŸ’‘ Import modules
11
  from yolo_module import run_yolo
12
  from ocr_module import extract_text, count_elements, validate_structure
13
- from graph_module import map_arrows, build_flowchart_json
14
  from summarizer_module import summarize_flowchart
15
 
 
16
  app = FastAPI()
17
 
18
- # πŸ”“ Enable CORS for Streamlit frontend
19
  app.add_middleware(
20
  CORSMiddleware,
21
- allow_origins=["*"], # Update with actual domain if needed
22
  allow_credentials=True,
23
  allow_methods=["*"],
24
  allow_headers=["*"],
25
  )
26
 
 
27
  @app.post("/process-image")
28
- async def process_image(file: UploadFile = File(...), debug: str = Form("false")):
 
 
29
  debug_mode = debug.lower() == "true"
30
  debug_log = []
31
 
@@ -33,35 +60,31 @@ async def process_image(file: UploadFile = File(...), debug: str = Form("false")
33
  debug_log.append("πŸ“₯ Received file upload")
34
  print(f"πŸ“₯ File received: {file.filename}")
35
 
36
- # πŸ–ΌοΈ Load image
37
  contents = await file.read()
38
  image = Image.open(io.BytesIO(contents)).convert("RGB")
39
  if debug_mode:
40
  debug_log.append("βœ… Image converted to RGB")
41
  print("βœ… Image converted to RGB")
42
 
43
- # πŸ“¦ YOLO Detection
44
  boxes, arrows, vis_debug = run_yolo(image)
45
  if debug_mode:
46
  debug_log.append(f"πŸ“¦ Detected {len(boxes)} boxes, {len(arrows)} arrows")
47
 
48
- # πŸ” OCR for each box
49
  for box in boxes:
50
  box["text"] = extract_text(image, box["bbox"], debug=debug_mode)
51
  print(f"πŸ” OCR for {box['id']}: {box['text']}")
52
  if debug_mode:
53
  debug_log.append(f"πŸ” {box['id']}: {box['text']}")
54
 
55
- # ➑️ Build directional edges
56
- edges = map_arrows(boxes, arrows)
57
- if debug_mode:
58
- debug_log.append(f"➑️ Mapped {len(edges)} directional edges")
59
 
60
- # 🧠 Build structured flowchart
61
- flowchart_json = build_flowchart_json(boxes, edges)
62
  print("🧠 Flowchart JSON:", json.dumps(flowchart_json, indent=2))
63
 
64
- # βœ… Sanity checks
65
  structure_info = count_elements(boxes, arrows, debug=debug_mode)
66
  validation = validate_structure(
67
  flowchart_json,
@@ -72,17 +95,18 @@ async def process_image(file: UploadFile = File(...), debug: str = Form("false")
72
  if debug_mode:
73
  debug_log.append(f"🧾 Validation: {validation}")
74
 
75
- # ✍️ Generate Summary
76
  summary = summarize_flowchart(flowchart_json)
77
  print("πŸ“ Summary:", summary)
78
 
79
- # πŸ–ΌοΈ Encode visual debug
80
  yolo_vis = None
81
  if debug_mode and vis_debug:
82
  vis_io = io.BytesIO()
83
  vis_debug.save(vis_io, format="PNG")
84
  yolo_vis = base64.b64encode(vis_io.getvalue()).decode("utf-8")
85
 
 
86
  return JSONResponse({
87
  "flowchart": flowchart_json,
88
  "summary": summary,
@@ -92,4 +116,5 @@ async def process_image(file: UploadFile = File(...), debug: str = Form("false")
92
 
93
 
94
  if __name__ == "__main__":
 
95
  uvicorn.run(app, host="0.0.0.0", port=7860)
 
1
+ """
2
+ api_backend.py
3
+
4
+ FastAPI backend for flowchart-to-English processing. This API supports receiving
5
+ an image file, running YOLO-based detection to identify boxes and arrows, performing
6
+ OCR, and generating structured JSON + English summary of the flowchart.
7
+
8
+ Endpoints:
9
+ - POST /process-image: Accepts image input and returns structured flowchart data.
10
+ """
11
+
12
  from fastapi import FastAPI, UploadFile, File, Form
13
  from fastapi.middleware.cors import CORSMiddleware
14
  from fastapi.responses import JSONResponse
 
18
  import json
19
  import base64
20
 
21
+ # 🔧 Import local processing modules
22
  from yolo_module import run_yolo
23
  from ocr_module import extract_text, count_elements, validate_structure
24
+ from graph_module import map_arrows, build_flowchart_json
25
  from summarizer_module import summarize_flowchart
26
 
27
+ # 🔥 Initialize FastAPI app
28
  app = FastAPI()
29
 
30
+ # 🔓 Enable CORS to allow frontend (e.g., Streamlit on localhost) to connect
31
  app.add_middleware(
32
  CORSMiddleware,
33
+ allow_origins=["*"], # In production, replace with allowed frontend domain
34
  allow_credentials=True,
35
  allow_methods=["*"],
36
  allow_headers=["*"],
37
  )
38
 
39
+
40
  @app.post("/process-image")
41
+ async def process_image(
42
+ file: UploadFile = File(...),
43
+ debug: str = Form("false")
44
+ ):
45
+ """
46
+ Receives an uploaded flowchart image, performs object detection and OCR,
47
+ constructs a structured flowchart JSON, and generates a plain-English summary.
48
+
49
+ Args:
50
+ file (UploadFile): Flowchart image file (.png, .jpg, .jpeg).
51
+ debug (str): "true" to enable debug mode (includes OCR logs and YOLO preview).
52
+
53
+ Returns:
54
+ JSONResponse: Contains flowchart structure, summary, debug output, and optional YOLO overlay.
55
+ """
56
  debug_mode = debug.lower() == "true"
57
  debug_log = []
58
 
 
60
  debug_log.append("πŸ“₯ Received file upload")
61
  print(f"πŸ“₯ File received: {file.filename}")
62
 
63
+ # πŸ–ΌοΈ Convert file bytes to RGB image
64
  contents = await file.read()
65
  image = Image.open(io.BytesIO(contents)).convert("RGB")
66
  if debug_mode:
67
  debug_log.append("βœ… Image converted to RGB")
68
  print("βœ… Image converted to RGB")
69
 
70
+ # 📦 YOLO detection for boxes and arrows
71
  boxes, arrows, vis_debug = run_yolo(image)
72
  if debug_mode:
73
  debug_log.append(f"πŸ“¦ Detected {len(boxes)} boxes, {len(arrows)} arrows")
74
 
75
+ # πŸ” Run OCR on each detected box
76
  for box in boxes:
77
  box["text"] = extract_text(image, box["bbox"], debug=debug_mode)
78
  print(f"πŸ” OCR for {box['id']}: {box['text']}")
79
  if debug_mode:
80
  debug_log.append(f"πŸ” {box['id']}: {box['text']}")
81
 
 
 
 
 
82
 
83
+ # 🧠 Build structured JSON from nodes and edges
84
+ flowchart_json = build_flowchart_json(boxes, arrows)
85
  print("🧠 Flowchart JSON:", json.dumps(flowchart_json, indent=2))
86
 
87
+ # ✅ Validate structure
88
  structure_info = count_elements(boxes, arrows, debug=debug_mode)
89
  validation = validate_structure(
90
  flowchart_json,
 
95
  if debug_mode:
96
  debug_log.append(f"🧾 Validation: {validation}")
97
 
98
+ # ✍️ Generate plain-English summary
99
  summary = summarize_flowchart(flowchart_json)
100
  print("πŸ“ Summary:", summary)
101
 
102
+ # πŸ–ΌοΈ Encode YOLO debug image (if debug enabled)
103
  yolo_vis = None
104
  if debug_mode and vis_debug:
105
  vis_io = io.BytesIO()
106
  vis_debug.save(vis_io, format="PNG")
107
  yolo_vis = base64.b64encode(vis_io.getvalue()).decode("utf-8")
108
 
109
+ # 📤 Return full response
110
  return JSONResponse({
111
  "flowchart": flowchart_json,
112
  "summary": summary,
 
116
 
117
 
118
  if __name__ == "__main__":
119
+ # Run the FastAPI app using Uvicorn
120
  uvicorn.run(app, host="0.0.0.0", port=7860)
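For reference, a minimal client call against this endpoint might look like the sketch below. It assumes the backend is running locally on port 7860, and "flowchart.png" is a placeholder path; the form fields and response keys mirror the code above.

import requests

with open("flowchart.png", "rb") as f:  # hypothetical sample image
    resp = requests.post(
        "http://localhost:7860/process-image",
        files={"file": f.read()},
        data={"debug": "true"},
    )
resp.raise_for_status()
payload = resp.json()
print(payload["summary"])      # plain-English explanation
print(payload["flowchart"])    # structured JSON with "start" and "steps"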
app.py CHANGED
@@ -1,76 +1,89 @@
1
  # app.py
 
 
 
 
 
2
  import streamlit as st
3
  from PIL import Image
4
- import io
5
  import base64
 
6
  import os
7
 
8
- # Local modules
9
- from yolo_module import run_yolo
10
- from ocr_module import extract_text
11
- from graph_module import map_arrows, build_flowchart_json
12
- from summarizer_module import summarize_flowchart
13
-
14
  st.set_page_config(page_title="Flowchart to English", layout="wide")
15
  st.title("πŸ“„ Flowchart to Plain English")
16
 
17
- # Enable debug mode
18
  debug_mode = st.toggle("πŸ”§ Show Debug Info", value=False)
19
 
20
- # Upload image
 
 
 
 
21
  uploaded_file = st.file_uploader("Upload a flowchart image", type=["png", "jpg", "jpeg"])
22
 
 
 
 
23
  if uploaded_file:
 
24
  image = Image.open(uploaded_file).convert("RGB")
25
-
26
- # Show resized preview
27
  max_width = 600
28
  ratio = max_width / float(image.size[0])
29
  resized_image = image.resize((max_width, int(image.size[1] * ratio)))
30
  st.image(resized_image, caption="πŸ“€ Uploaded Image", use_container_width=False)
31
 
32
  if st.button("πŸ” Analyze Flowchart"):
33
- progress = st.progress(0, text="Detecting boxes and arrows...")
34
- results, arrows, vis_debug = run_yolo(image)
35
- progress.progress(25, text="Running OCR...")
36
-
37
- debug_log = []
38
- debug_log.append(f"πŸ“¦ Detected {len(results)} boxes")
39
- debug_log.append(f"➑️ Detected {len(arrows)} arrows")
40
-
41
- for node in results:
42
- node["text"] = extract_text(image, node["bbox"], debug=debug_mode)
43
- label = node.get("label", "box")
44
- text = node["text"]
45
- debug_log.append(f"πŸ”– {node['id']} | Label: {label} | Text: {text}")
46
-
47
- progress.progress(50, text="Mapping arrows to nodes...")
48
- edges = map_arrows(results, arrows)
49
-
50
- progress.progress(75, text="Building graph structure...")
51
- flowchart = build_flowchart_json(results, edges)
52
-
53
- progress.progress(90, text="Generating explanation...")
54
- summary = summarize_flowchart(flowchart)
55
-
56
- # Show Debug Info first
57
- if debug_mode:
58
- st.markdown("### πŸ§ͺ Debug Info")
59
- st.code("\n".join(debug_log), language="markdown")
60
-
61
- st.markdown("### πŸ–ΌοΈ YOLO Detected Bounding Boxes")
62
- st.image(vis_debug, caption="YOLO Detected Boxes", use_container_width=True)
63
-
64
- # Show results: JSON (left), Summary (right)
65
- col1, col2 = st.columns(2)
66
- with col1:
67
- st.subheader("🧠 Flowchart JSON")
68
- st.json(flowchart)
69
- with col2:
70
- st.subheader("πŸ“ English Summary")
71
- st.markdown(summary)
72
-
73
- progress.progress(100, text="Done!")
74
-
 
 
 
 
 
 
75
  else:
76
  st.info("Upload a flowchart image to begin.")
 
1
  # app.py
2
+ """
3
+ Streamlit Frontend App: Uploads a flowchart image, sends it to FastAPI backend,
4
+ and displays the structured JSON and English summary. Supports multiple OCR engines.
5
+ """
6
+
7
  import streamlit as st
8
  from PIL import Image
9
+ import requests
10
  import base64
11
+ import io
12
  import os
13
 
14
+ # Set up Streamlit UI layout
 
 
 
 
 
15
  st.set_page_config(page_title="Flowchart to English", layout="wide")
16
  st.title("πŸ“„ Flowchart to Plain English")
17
 
18
+ # Enable debug mode toggle
19
  debug_mode = st.toggle("πŸ”§ Show Debug Info", value=False)
20
 
21
+ # OCR engine selection dropdown
22
+ ocr_engine = st.selectbox("Select OCR Engine", ["easyocr", "doctr"], index=0,
23
+ help="Choose between EasyOCR (lightweight) and Doctr (transformer-based)")
24
+
25
+ # Flowchart image uploader
26
  uploaded_file = st.file_uploader("Upload a flowchart image", type=["png", "jpg", "jpeg"])
27
 
28
+ # Backend API URL (defaults to localhost for dev)
29
+ API_URL = os.getenv("API_URL", "http://localhost:7860/process-image")
30
+
31
  if uploaded_file:
32
+ # Load and resize uploaded image for preview
33
  image = Image.open(uploaded_file).convert("RGB")
 
 
34
  max_width = 600
35
  ratio = max_width / float(image.size[0])
36
  resized_image = image.resize((max_width, int(image.size[1] * ratio)))
37
  st.image(resized_image, caption="πŸ“€ Uploaded Image", use_container_width=False)
38
 
39
  if st.button("πŸ” Analyze Flowchart"):
40
+ progress = st.progress(0, text="Sending image to backend...")
41
+
42
+ try:
43
+ # Send request to FastAPI backend
44
+ response = requests.post(
45
+ API_URL,
46
+ files={"file": uploaded_file.getvalue()},
47
+ data={
48
+ "debug": str(debug_mode).lower(),
49
+ "ocr_engine": ocr_engine
50
+ }
51
+ )
52
+
53
+ progress.progress(40, text="Processing detection and OCR...")
54
+
55
+ if response.status_code == 200:
56
+ result = response.json()
57
+
58
+ # Show debug info if enabled
59
+ if debug_mode:
60
+ st.markdown("### πŸ§ͺ Debug Info")
61
+ st.code(result.get("debug", ""), language="markdown")
62
+
63
+ # Show YOLO visual if available
64
+ if debug_mode and result.get("yolo_vis"):
65
+ st.markdown("### πŸ–ΌοΈ YOLO Detected Bounding Boxes")
66
+ yolo_bytes = base64.b64decode(result["yolo_vis"])
67
+ yolo_img = Image.open(io.BytesIO(yolo_bytes))
68
+ st.image(yolo_img, caption="YOLO Boxes", use_container_width=True)
69
+
70
+ progress.progress(80, text="Finalizing output...")
71
+
72
+ # Show flowchart JSON and generated English summary
73
+ col1, col2 = st.columns(2)
74
+ with col1:
75
+ st.subheader("🧠 Flowchart JSON")
76
+ st.json(result["flowchart"])
77
+ with col2:
78
+ st.subheader("πŸ“ English Summary")
79
+ st.markdown(result["summary"])
80
+
81
+ progress.progress(100, text="Done!")
82
+
83
+ else:
84
+ st.error(f"❌ Backend Error: {response.status_code} - {response.text}")
85
+
86
+ except Exception as e:
87
+ st.error(f"⚠️ Request Failed: {e}")
88
  else:
89
  st.info("Upload a flowchart image to begin.")
graph_module/__init__.py CHANGED
@@ -1,34 +1,83 @@
1
- # flowchart_builder.py
2
- # Arrow and graph logic for converting detected flowchart elements to structured JSON
3
-
4
- from shapely.geometry import box, Point
5
  from collections import defaultdict, deque
 
 
 
 
6
 
7
  def map_arrows(nodes, arrows):
8
  """
9
- Matches arrows to nodes based on geometric endpoints.
10
- Returns a list of (source_id, target_id, label) edges.
 
 
 
 
 
 
 
11
  """
12
  for node in nodes:
13
- node["shape"] = box(*node["bbox"])
 
 
 
 
14
 
15
  edges = []
 
 
16
  for arrow in arrows:
17
- tail_point = Point(arrow["tail"])
18
- head_point = Point(arrow["head"])
19
- label = arrow.get("label", "")
 
 
 
 
 
 
 
20
 
21
- source = next((n["id"] for n in nodes if n["shape"].contains(tail_point)), None)
22
- target = next((n["id"] for n in nodes if n["shape"].contains(head_point)), None)
 
 
 
 
 
 
 
23
 
24
  if source and target and source != target:
 
25
  edges.append((source, target, label))
 
 
26
 
27
  return edges
28
 
29
- def detect_node_type(text):
 
30
  """
31
- Heuristic-based type detection from node text.
 
 
 
 
 
 
 
32
  """
33
  text_lower = text.lower()
34
  if "start" in text_lower:
@@ -37,37 +86,43 @@ def detect_node_type(text):
37
  return "end"
38
  if "?" in text or "yes" in text_lower or "no" in text_lower:
39
  return "decision"
40
- return "process"
41
 
42
- def build_flowchart_json(nodes, edges):
 
43
  """
44
- Constructs flowchart JSON structure with parent and branching info.
 
 
 
 
 
 
 
45
  """
46
- graph = {}
 
 
 
47
  reverse_links = defaultdict(list)
48
  edge_labels = {}
49
 
50
- for node in nodes:
51
- text = node.get("text", "").strip()
52
- graph[node["id"]] = {
53
- "text": text,
54
- "type": node.get("type") or detect_node_type(text),
55
- "next": []
56
- }
57
-
58
  for src, tgt, label in edges:
59
- graph[src]["next"].append(tgt)
60
  reverse_links[tgt].append(src)
61
- edge_labels[(src, tgt)] = label.lower().strip()
 
 
 
62
 
63
- start_nodes = [nid for nid in graph if len(reverse_links[nid]) == 0]
64
- flowchart_json = {
65
- "start": start_nodes[0] if start_nodes else None,
66
  "steps": []
67
  }
68
 
69
  visited = set()
70
- queue = deque(start_nodes)
 
71
 
72
  while queue:
73
  curr = queue.popleft()
@@ -75,21 +130,24 @@ def build_flowchart_json(nodes, edges):
75
  continue
76
  visited.add(curr)
77
 
78
- node = graph[curr]
 
 
 
79
  step = {
80
  "id": curr,
81
- "text": node["text"],
82
- "type": node["type"]
83
  }
84
 
85
- parents = reverse_links[curr]
86
  if len(parents) == 1:
87
  step["parent"] = parents[0]
88
  elif len(parents) > 1:
89
  step["parents"] = parents
90
 
91
- next_nodes = node["next"]
92
- if node["type"] == "decision" and len(next_nodes) >= 2:
93
  step["branches"] = {}
94
  for tgt in next_nodes:
95
  label = edge_labels.get((curr, tgt), "")
@@ -99,14 +157,18 @@ def build_flowchart_json(nodes, edges):
99
  step["branches"]["no"] = tgt
100
  else:
101
  step["branches"].setdefault("unknown", []).append(tgt)
102
- queue.extend(next_nodes)
 
103
  elif len(next_nodes) == 1:
104
  step["next"] = next_nodes[0]
105
- queue.append(next_nodes[0])
 
106
  elif len(next_nodes) > 1:
107
  step["next"] = next_nodes
108
- queue.extend(next_nodes)
 
 
109
 
110
- flowchart_json["steps"].append(step)
111
 
112
- return flowchart_json
 
1
+ from shapely.geometry import box as shapely_box, Point
 
 
 
2
  from collections import defaultdict, deque
3
+ import math
4
+
5
+
6
+ MAX_FALLBACK_DIST = 150 # pixels
7
 
8
  def map_arrows(nodes, arrows):
9
  """
10
+ Map arrows to node boxes using geometric containment with fallback to nearest box.
11
+ Returns directional edges (source_id, target_id, label).
12
+
13
+ Args:
14
+ nodes (list): List of node dicts with 'bbox' field.
15
+ arrows (list): List of arrow dicts with 'tail' and 'head' coordinates.
16
+
17
+ Returns:
18
+ list: List of (source, target, label) tuples.
19
  """
20
  for node in nodes:
21
+ node["shape"] = shapely_box(*node["bbox"])
22
+ node["center"] = (
23
+ (node["bbox"][0] + node["bbox"][2]) // 2,
24
+ (node["bbox"][1] + node["bbox"][3]) // 2
25
+ )
26
 
27
  edges = []
28
+
29
+ def find_nearest_node(pt):
30
+ min_dist = float("inf")
31
+ nearest_id = None
32
+ for n in nodes:
33
+ cx, cy = n["center"]
34
+ dist = math.dist(pt, (cx, cy))
35
+ if dist < min_dist:
36
+ min_dist = dist
37
+ nearest_id = n["id"]
38
+ return nearest_id, min_dist
39
+
40
  for arrow in arrows:
41
+ if not isinstance(arrow, dict) or not isinstance(arrow.get("tail"), (tuple, list)) or not isinstance(arrow.get("head"), (tuple, list)):
42
+ print(f"⚠️ Skipping malformed arrow: {arrow}")
43
+ continue
44
+
45
+ tail_pt = Point(arrow["tail"])
46
+ head_pt = Point(arrow["head"])
47
+ label = arrow.get("label", "").strip()
48
+
49
+ source = next((n["id"] for n in nodes if n["shape"].buffer(10).contains(tail_pt)), None)
50
+ target = next((n["id"] for n in nodes if n["shape"].buffer(10).contains(head_pt)), None)
51
 
52
+ # Fallback to nearest node if not found
53
+ if not source:
54
+ source, dist = find_nearest_node(arrow["tail"])
55
+ if dist > MAX_FALLBACK_DIST:
56
+ source = None
57
+ if not target:
58
+ target, dist = find_nearest_node(arrow["head"])
59
+ if dist > MAX_FALLBACK_DIST:
60
+ target = None
61
 
62
  if source and target and source != target:
63
+ print(f"βœ… Mapped arrow from {source} β†’ {target} [{label}]")
64
  edges.append((source, target, label))
65
+ else:
66
+ print(f"⚠️ Could not map arrow endpoints to nodes: tail={arrow.get('tail')} head={arrow.get('head')}")
67
 
68
  return edges
69
 
70
+
71
+ def detect_node_type(text, default_type="process"):
72
  """
73
+ Heuristically infer the node type from its text.
74
+
75
+ Args:
76
+ text (str): Node label.
77
+ default_type (str): Fallback type.
78
+
79
+ Returns:
80
+ str: Inferred node type.
81
  """
82
  text_lower = text.lower()
83
  if "start" in text_lower:
 
86
  return "end"
87
  if "?" in text or "yes" in text_lower or "no" in text_lower:
88
  return "decision"
89
+ return default_type
90
 
91
+
92
+ def build_flowchart_json(nodes, arrows):
93
  """
94
+ Construct a structured flowchart JSON using basic graph traversal.
95
+
96
+ Args:
97
+ nodes (list): Detected node dicts.
98
+ arrows (list): Detected arrow dicts.
99
+
100
+ Returns:
101
+ dict: JSON with 'start' and 'steps'.
102
  """
103
+ edges = map_arrows(nodes, arrows)
104
+
105
+ # Build adjacency and reverse mappings
106
+ graph = defaultdict(list)
107
  reverse_links = defaultdict(list)
108
  edge_labels = {}
109
 
 
 
 
 
 
 
 
 
110
  for src, tgt, label in edges:
111
+ graph[src].append(tgt)
112
  reverse_links[tgt].append(src)
113
+ edge_labels[(src, tgt)] = label.lower()
114
+
115
+ all_node_ids = {n["id"] for n in nodes}
116
+ start_candidates = [nid for nid in all_node_ids if nid not in reverse_links]
117
 
118
+ flowchart = {
119
+ "start": start_candidates[0] if start_candidates else None,
 
120
  "steps": []
121
  }
122
 
123
  visited = set()
124
+ queue = deque(start_candidates)
125
+ id_to_node = {n["id"]: n for n in nodes}
126
 
127
  while queue:
128
  curr = queue.popleft()
 
130
  continue
131
  visited.add(curr)
132
 
133
+ node = id_to_node.get(curr, {})
134
+ text = node.get("text", "").strip()
135
+ node_type = node.get("type") or detect_node_type(text)
136
+
137
  step = {
138
  "id": curr,
139
+ "text": text,
140
+ "type": node_type
141
  }
142
 
143
+ parents = list(set(reverse_links[curr]))
144
  if len(parents) == 1:
145
  step["parent"] = parents[0]
146
  elif len(parents) > 1:
147
  step["parents"] = parents
148
 
149
+ next_nodes = list(set(graph[curr]))
150
+ if node_type == "decision" and next_nodes:
151
  step["branches"] = {}
152
  for tgt in next_nodes:
153
  label = edge_labels.get((curr, tgt), "")
 
157
  step["branches"]["no"] = tgt
158
  else:
159
  step["branches"].setdefault("unknown", []).append(tgt)
160
+ if tgt not in visited:
161
+ queue.append(tgt)
162
  elif len(next_nodes) == 1:
163
  step["next"] = next_nodes[0]
164
+ if next_nodes[0] not in visited:
165
+ queue.append(next_nodes[0])
166
  elif len(next_nodes) > 1:
167
  step["next"] = next_nodes
168
+ for n in next_nodes:
169
+ if n not in visited:
170
+ queue.append(n)
171
 
172
+ flowchart["steps"].append(step)
173
 
174
+ return flowchart
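As a quick illustration of the shapes this module works with, the sketch below uses two made-up nodes and one arrow (field names follow the code above); it is not part of the committed module:

nodes = [
    {"id": "node1", "bbox": [10, 10, 110, 60], "text": "Start"},
    {"id": "node2", "bbox": [10, 120, 110, 170], "text": "Do work"},
]
arrows = [{"tail": (60, 55), "head": (60, 125), "label": ""}]
flowchart = build_flowchart_json(nodes, arrows)
# Roughly: {"start": "node1", "steps": [
#   {"id": "node1", "text": "Start", "type": "start", "next": "node2"},
#   {"id": "node2", "text": "Do work", "type": "process", "parent": "node1"}]}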
ocr_module/__init__.py CHANGED
@@ -1,19 +1,40 @@
1
- import easyocr
2
- from PIL import Image
 
 
 
 
3
  import numpy as np
 
4
  import cv2
5
- import torch
6
  from textblob import TextBlob
7
-
8
  from device_config import get_device
9
- device = get_device()
10
-
11
- # Enable GPU if available
12
- reader = easyocr.Reader(['en'], gpu=(device == "cuda"))
13
- print(f"βœ… EasyOCR reader initialized on: {device}")
14
 
 
 
15
 
16
  def expand_bbox(bbox, image_size, pad=10):
 
17
  x1, y1, x2, y2 = bbox
18
  x1 = max(0, x1 - pad)
19
  y1 = max(0, y1 - pad)
@@ -22,104 +43,115 @@ def expand_bbox(bbox, image_size, pad=10):
22
  return [x1, y1, x2, y2]
23
 
24
  def clean_text(text):
 
25
  blob = TextBlob(text)
26
  return str(blob.correct())
27
 
28
- def extract_text(image, bbox, debug=False, use_adaptive_threshold=False):
29
  """
30
- Run OCR on a cropped region of the image using EasyOCR with preprocessing.
31
 
32
  Parameters:
33
- image (PIL.Image): The full image.
34
- bbox (list): [x1, y1, x2, y2] coordinates of the region to crop.
35
- debug (bool): If True, show intermediate debug output.
36
- use_adaptive_threshold (bool): Use adaptive thresholding instead of Otsu's.
37
 
38
  Returns:
39
- str: Extracted and cleaned text.
40
  """
41
- # Expand bbox slightly
42
  bbox = expand_bbox(bbox, image.size, pad=10)
43
  x1, y1, x2, y2 = bbox
44
  cropped = image.crop((x1, y1, x2, y2))
45
 
46
- # Convert to OpenCV format (numpy array)
47
  cv_img = np.array(cropped)
48
-
49
- # Convert to grayscale
50
  gray = cv2.cvtColor(cv_img, cv2.COLOR_RGB2GRAY)
51
 
52
- # Apply Gaussian blur to reduce noise
53
- blurred = cv2.GaussianBlur(gray, (3, 3), 0)
 
54
 
55
- # Resize (upscale) image for better OCR accuracy
56
- scale_factor = 2.5
57
- resized = cv2.resize(blurred, (0, 0), fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR)
58
 
59
- # Convert to RGB as EasyOCR expects color image
60
- resized_rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
61
 
62
- # Optional: debug save
63
- if debug:
64
- debug_image = Image.fromarray(resized_rgb)
65
- debug_image.save(f"debug_ocr_crop_{x1}_{y1}.png")
66
 
67
- # Run OCR using EasyOCR
68
- try:
69
- results = reader.readtext(resized_rgb, paragraph=False, min_size=5)
70
- except Exception as e:
 
 
 
71
  if debug:
72
- print(f"⚠️ EasyOCR failed: {e}")
73
  return ""
74
 
75
- if debug:
76
- for res in results:
77
- print(f"OCR: {res[1]} (conf: {res[2]:.2f})")
78
-
79
- # Sort boxes top to bottom, then left to right
80
- results.sort(key=lambda r: (r[0][0][1], r[0][0][0]))
81
-
82
- # Filter by confidence
83
- filtered = [r for r in results if r[2] > 0.4]
84
- if not filtered and results:
85
- filtered = sorted(results, key=lambda r: -r[2])[:2] # fallback to top-2
86
-
87
- lines = []
88
- for res in filtered:
89
- lines.append(res[1])
90
-
91
- joined_text = " ".join(lines).strip()
92
-
93
- # Apply correction
94
- if joined_text:
95
- joined_text = clean_text(joined_text)
96
-
97
- return joined_text
98
-
99
-
100
  def count_elements(boxes, arrows, debug=False):
 
101
  box_count = len(boxes)
102
  arrow_count = len(arrows)
103
  if debug:
104
- print(f"πŸ“¦ Detected {box_count} boxes")
105
- print(f"➑️ Detected {arrow_count} arrows")
106
- return {
107
- "box_count": box_count,
108
- "arrow_count": arrow_count
109
- }
110
-
111
 
112
  def validate_structure(flowchart_json, expected_boxes=None, expected_arrows=None, debug=False):
 
113
  actual_boxes = len(flowchart_json.get("steps", []))
114
  actual_arrows = len(flowchart_json.get("edges", [])) if "edges" in flowchart_json else None
115
 
116
  if debug:
117
- print(f"πŸ” Flowchart JSON has {actual_boxes} steps")
118
- if actual_arrows is not None:
119
- print(f"πŸ” Flowchart JSON has {actual_arrows} edges")
120
 
121
- result = {
122
  "boxes_valid": (expected_boxes is None or expected_boxes == actual_boxes),
123
  "arrows_valid": (expected_arrows is None or expected_arrows == actual_arrows)
124
- }
125
- return result
 
1
+ """
2
+ OCR module with support for EasyOCR and Doctr.
3
+ Provides the `extract_text` function that accepts a cropped bounding box and image,
4
+ and runs OCR based on the selected engine ("easyocr" or "doctr").
5
+ """
6
+
7
  import numpy as np
8
+ from PIL import Image
9
  import cv2
 
10
  from textblob import TextBlob
 
11
  from device_config import get_device
 
 
 
 
 
12
 
13
+ # OCR engine flags
14
+ USE_EASYOCR = False  # flipped to True below only if the import succeeds
15
+ USE_DOCTR = False
16
+
17
+ # Import EasyOCR if available
18
+ try:
19
+ import easyocr
20
+ reader = easyocr.Reader(['en'], gpu=(get_device() == "cuda"))
21
+ print(f"βœ… EasyOCR reader initialized on: {get_device()}")
22
+ USE_EASYOCR = True
23
+ except ImportError:
24
+ print("⚠️ EasyOCR not installed. Falling back if Doctr is available.")
25
+
26
+ # Import Doctr if available
27
+ try:
28
+ from doctr.io import DocumentFile
29
+ from doctr.models import ocr_predictor
30
+ doctr_model = ocr_predictor(pretrained=True)
31
+ print("βœ… Doctr model loaded.")
32
+ USE_DOCTR = True
33
+ except ImportError:
34
+ print("⚠️ Doctr not installed.")
35
 
36
  def expand_bbox(bbox, image_size, pad=10):
37
+ """Expand a bounding box by padding within image bounds."""
38
  x1, y1, x2, y2 = bbox
39
  x1 = max(0, x1 - pad)
40
  y1 = max(0, y1 - pad)
 
43
  return [x1, y1, x2, y2]
44
 
45
  def clean_text(text):
46
+ """Use TextBlob to autocorrect basic OCR errors."""
47
  blob = TextBlob(text)
48
  return str(blob.correct())
49
 
50
+ def extract_text(image, bbox, debug=False, engine="easyocr"):
51
  """
52
+ Run OCR on a cropped region using EasyOCR or Doctr.
53
 
54
  Parameters:
55
+ image (PIL.Image): Full input image.
56
+ bbox (list): [x1, y1, x2, y2] bounding box.
57
+ debug (bool): Enable debug output.
58
+ engine (str): 'easyocr' or 'doctr'.
59
 
60
  Returns:
61
+ str: Cleaned OCR output.
62
  """
63
+ # Expand and crop image region
64
  bbox = expand_bbox(bbox, image.size, pad=10)
65
  x1, y1, x2, y2 = bbox
66
  cropped = image.crop((x1, y1, x2, y2))
67
 
68
+ # Convert to OpenCV grayscale
69
  cv_img = np.array(cropped)
 
 
70
  gray = cv2.cvtColor(cv_img, cv2.COLOR_RGB2GRAY)
71
 
72
+ # Enhance contrast using CLAHE (Contrast Limited Adaptive Histogram Equalization)
73
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
74
+ enhanced = clahe.apply(gray)
75
 
76
+ # Apply adaptive threshold for better text separation
77
+ thresh = cv2.adaptiveThreshold(enhanced, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 4)
 
78
 
79
+ # Resize for better OCR resolution
80
+ resized = cv2.resize(thresh, (0, 0), fx=2.5, fy=2.5, interpolation=cv2.INTER_LINEAR)
81
 
82
+ # Convert to RGB (some OCR engines expect 3-channel images)
83
+ preprocessed = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
 
 
84
 
85
+ if debug:
86
+ Image.fromarray(preprocessed).save(f"debug_ocr_crop_{x1}_{y1}.png")
87
+
88
+ if engine == "doctr" and USE_DOCTR:
89
+ try:
90
+ # Doctr predictors accept a list of numpy arrays (one per page)
+ result = doctr_model([preprocessed])
+ out_text = " ".join(w.value for blk in result.pages[0].blocks for line in blk.lines for w in line.words)
93
+ if debug:
94
+ print(f"πŸ“˜ Doctr OCR: {out_text}")
95
+ return clean_text(out_text)
96
+ except Exception as e:
97
+ if debug:
98
+ print(f"❌ Doctr failed: {e}")
99
+ return ""
100
+
101
+ elif engine == "easyocr" and USE_EASYOCR:
102
+ try:
103
+ results = reader.readtext(preprocessed, paragraph=False, min_size=10)
104
+ filtered = []
105
+ for r in results:
106
+ text = r[1].strip()
107
+ conf = r[2]
108
+ if conf > 0.5 and len(text) > 2 and any(c.isalnum() for c in text):
109
+ filtered.append(r)
110
+
111
+ # Remove duplicate detections that have the same text (case-insensitive)
112
+ final = []
113
+ seen = set()
114
+ for r in filtered:
115
+ t = r[1].strip()
116
+ if t.lower() not in seen:
117
+ seen.add(t.lower())
118
+ final.append(r)
119
+
120
+ final.sort(key=lambda r: (r[0][0][1], r[0][0][0]))
121
+ text = " ".join([r[1] for r in final]).strip()
122
+
123
+ if debug:
124
+ for r in final:
125
+ print(f"πŸ“± EasyOCR: {r[1]} (conf: {r[2]:.2f})")
126
+
127
+ return clean_text(text) if text else ""
128
+ except Exception as e:
129
+ if debug:
130
+ print(f"❌ EasyOCR failed: {e}")
131
+ return ""
132
+
133
+ else:
134
  if debug:
135
+ print(f"⚠️ Unsupported OCR engine: {engine} or not available.")
136
  return ""
137
 
 
138
  def count_elements(boxes, arrows, debug=False):
139
+ """Return count of boxes and arrows detected."""
140
  box_count = len(boxes)
141
  arrow_count = len(arrows)
142
  if debug:
143
+ print(f"πŸ“¦ Boxes: {box_count} | ➑️ Arrows: {arrow_count}")
144
+ return {"box_count": box_count, "arrow_count": arrow_count}
 
 
 
 
 
145
 
146
  def validate_structure(flowchart_json, expected_boxes=None, expected_arrows=None, debug=False):
147
+ """Validate flowchart structure consistency based on expected counts."""
148
  actual_boxes = len(flowchart_json.get("steps", []))
149
  actual_arrows = len(flowchart_json.get("edges", [])) if "edges" in flowchart_json else None
150
 
151
  if debug:
152
+ print(f"πŸ” JSON boxes: {actual_boxes}, edges: {actual_arrows}")
 
 
153
 
154
+ return {
155
  "boxes_valid": (expected_boxes is None or expected_boxes == actual_boxes),
156
  "arrows_valid": (expected_arrows is None or expected_arrows == actual_arrows)
157
+ }
 
ocr_module/__init__pyt.py DELETED
@@ -1,135 +0,0 @@
1
- import easyocr
2
- from PIL import Image
3
- import numpy as np
4
- import cv2
5
- import torch
6
- from textblob import TextBlob
7
- import re
8
-
9
- # Enable GPU if available
10
- use_gpu = torch.cuda.is_available()
11
- reader = easyocr.Reader(['en'], gpu=use_gpu)
12
-
13
- def expand_bbox(bbox, image_size, pad=10):
14
- x1, y1, x2, y2 = bbox
15
- x1 = max(0, x1 - pad)
16
- y1 = max(0, y1 - pad)
17
- x2 = min(image_size[0], x2 + pad)
18
- y2 = min(image_size[1], y2 + pad)
19
- return [x1, y1, x2, y2]
20
-
21
- def clean_text(text):
22
- # Basic cleanup
23
- text = re.sub(r'[^A-Za-z0-9?,.:;()\'"\s-]', '', text) # remove noise characters
24
- text = re.sub(r'\s+', ' ', text).strip()
25
-
26
- # De-duplicate repeated words
27
- words = text.split()
28
- deduped = [words[0]] + [w for i, w in enumerate(words[1:], 1) if w.lower() != words[i - 1].lower()] if words else []
29
- joined = " ".join(deduped)
30
-
31
- # Run correction only if needed (long word or all caps)
32
- if len(joined) > 3 and any(len(w) > 10 or w.isupper() for w in deduped):
33
- blob = TextBlob(joined)
34
- joined = str(blob.correct())
35
-
36
- return joined
37
-
38
- def extract_text(image, bbox, debug=False, use_adaptive_threshold=False):
39
- """
40
- Run OCR on a cropped region of the image using EasyOCR with preprocessing.
41
-
42
- Parameters:
43
- image (PIL.Image): The full image.
44
- bbox (list): [x1, y1, x2, y2] coordinates of the region to crop.
45
- debug (bool): If True, show intermediate debug output.
46
- use_adaptive_threshold (bool): Use adaptive thresholding instead of Otsu's.
47
-
48
- Returns:
49
- str: Extracted and cleaned text.
50
- """
51
- # Expand bbox slightly
52
- bbox = expand_bbox(bbox, image.size, pad=10)
53
- x1, y1, x2, y2 = bbox
54
- cropped = image.crop((x1, y1, x2, y2))
55
-
56
- # Convert to OpenCV format (numpy array)
57
- cv_img = np.array(cropped)
58
-
59
- # Convert to grayscale
60
- gray = cv2.cvtColor(cv_img, cv2.COLOR_RGB2GRAY)
61
-
62
- # Apply Gaussian blur to reduce noise
63
- blurred = cv2.GaussianBlur(gray, (3, 3), 0)
64
-
65
- # Resize (upscale) image for better OCR accuracy
66
- scale_factor = 2.5
67
- resized = cv2.resize(blurred, (0, 0), fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR)
68
-
69
- # Convert to RGB as EasyOCR expects color image
70
- resized_rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)
71
-
72
- # Optional: debug save
73
- if debug:
74
- debug_image = Image.fromarray(resized_rgb)
75
- debug_image.save(f"debug_ocr_crop_{x1}_{y1}.png")
76
-
77
- # Run OCR using EasyOCR
78
- try:
79
- results = reader.readtext(resized_rgb, paragraph=False, min_size=5)
80
- except Exception as e:
81
- if debug:
82
- print(f"⚠️ EasyOCR failed: {e}")
83
- return ""
84
-
85
- if debug:
86
- for res in results:
87
- print(f"OCR: {res[1]} (conf: {res[2]:.2f})")
88
-
89
- # Sort boxes top to bottom, then left to right
90
- results.sort(key=lambda r: (r[0][0][1], r[0][0][0]))
91
-
92
- # Filter by confidence
93
- filtered = [r for r in results if r[2] > 0.4]
94
- if not filtered and results:
95
- filtered = sorted(results, key=lambda r: -r[2])[:2] # fallback to top-2
96
-
97
- lines = []
98
- for res in filtered:
99
- lines.append(res[1])
100
-
101
- joined_text = " ".join(lines).strip()
102
-
103
- # Apply correction
104
- if joined_text:
105
- joined_text = clean_text(joined_text)
106
- if debug:
107
- print(f"🧹 Cleaned OCR text: {joined_text}")
108
-
109
- return joined_text
110
-
111
- def count_elements(boxes, arrows, debug=False):
112
- box_count = len(boxes)
113
- arrow_count = len(arrows)
114
- if debug:
115
- print(f"πŸ“¦ Detected {box_count} boxes")
116
- print(f"➑️ Detected {arrow_count} arrows")
117
- return {
118
- "box_count": box_count,
119
- "arrow_count": arrow_count
120
- }
121
-
122
- def validate_structure(flowchart_json, expected_boxes=None, expected_arrows=None, debug=False):
123
- actual_boxes = len(flowchart_json.get("steps", []))
124
- actual_arrows = len(flowchart_json.get("edges", [])) if "edges" in flowchart_json else None
125
-
126
- if debug:
127
- print(f"πŸ” Flowchart JSON has {actual_boxes} steps")
128
- if actual_arrows is not None:
129
- print(f"πŸ” Flowchart JSON has {actual_arrows} edges")
130
-
131
- result = {
132
- "boxes_valid": (expected_boxes is None or expected_boxes == actual_boxes),
133
- "arrows_valid": (expected_arrows is None or expected_arrows == actual_arrows)
134
- }
135
- return result
 
 
requirements.txt CHANGED
@@ -20,6 +20,10 @@ numpy # Core image array operations
20
  easyocr # GPU-capable OCR engine
21
  textblob # Optional: lightweight text post-processing (optional)
22
 
 
 
 
 
23
  # πŸ€– Object Detection and Language Models
24
  ultralytics # YOLOv8/v9 detection (loads .pt models)
25
  torch # Backend for YOLO and EasyOCR
 
20
  easyocr # GPU-capable OCR engine
21
  textblob # Optional: lightweight text post-processing (optional)
22
 
23
+ # --- Doctr dependencies (torch-based) ---
24
+ python-doctr[torch]
25
+ onnxruntime # Optional ONNX backend for Doctr inference
26
+
27
  # πŸ€– Object Detection and Language Models
28
  ultralytics # YOLOv8/v9 detection (loads .pt models)
29
  torch # Backend for YOLO and EasyOCR
summarizer_module/__init__.py CHANGED
@@ -1,38 +1,48 @@
1
- # summarizer_module/__init__.py
2
-
3
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
4
  from device_config import get_device
5
  import torch
 
6
 
 
7
  device = get_device()
8
 
9
- # Use a small local model (e.g., Phi-2)
10
- MODEL_ID = "microsoft/phi-2" # Ensure it's downloaded and cached locally
11
 
12
- # Load model and tokenizer
13
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
15
  summarizer = pipeline("text-generation", model=model, tokenizer=tokenizer)
16
 
17
  def summarize_flowchart(flowchart_json):
18
  """
19
- Given a flowchart JSON with 'start' and 'steps', returns a plain English explanation
20
- formatted as bullets and sub-bullets.
21
 
22
  Args:
23
- flowchart_json (dict): Structured representation of flowchart
24
 
25
  Returns:
26
- str: Bullet-style natural language summary of the logic
27
  """
 
28
  prompt = (
29
- "Turn the following flowchart into a bullet-point explanation in plain English.\n"
30
- "Use bullets for steps and sub-bullets for branches.\n"
31
- "\n"
32
- f"Flowchart JSON:\n{flowchart_json}\n"
33
- "\nExplanation:"
34
- )
35
-
36
- result = summarizer(prompt, max_new_tokens=300, do_sample=False)[0]["generated_text"]
37
- explanation = result.split("Explanation:")[-1].strip()
 
 
 
 
 
 
 
 
 
 
38
  return explanation
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
2
  from device_config import get_device
3
  import torch
4
+ import json
5
 
6
+ # Automatically choose device (CUDA, MPS, CPU)
7
  device = get_device()
8
 
9
+ # βš™οΈ Model config: Use phi-2-mini (replace with phi-4-mini when available)
10
+ MODEL_ID = "microsoft/Phi-4-mini-instruct"
11
 
12
+ # Load tokenizer and model
13
  model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
14
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
15
  summarizer = pipeline("text-generation", model=model, tokenizer=tokenizer)
16
 
17
  def summarize_flowchart(flowchart_json):
18
  """
19
+ Generates a human-friendly explanation from flowchart JSON.
 
20
 
21
  Args:
22
+ flowchart_json (dict): Contains "start" node and a list of "steps".
23
 
24
  Returns:
25
+ str: Bullet-style explanation with proper nesting and flow.
26
  """
27
+ # 📄 Prompt optimized for flow comprehension
28
  prompt = (
29
+ "You are an expert in visual reasoning and instruction generation.\n"
30
+ "Convert the following flowchart JSON into a clear, step-by-step summary using bullets.\n"
31
+ "- Each bullet represents a process step.\n"
32
+ "- Use indented sub-bullets to explain decision branches (Yes/No).\n"
33
+ "- Maintain order based on dependencies and parent-child links.\n"
34
+ "- Avoid repeating the same step more than once.\n"
35
+ "- Do not include JSON in the output, only human-readable text.\n"
36
+ "\nFlowchart:\n{flowchart}\n\nBullet Explanation:"
37
+ ).format(flowchart=json.dumps(flowchart_json, indent=2))
38
+
39
+ # 🧠 Run the model inference
40
+ result = summarizer(prompt, max_new_tokens=400, do_sample=False)[0]["generated_text"]
41
+
42
+ # Extract the portion after the final prompt marker
43
+ if "Bullet Explanation:" in result:
44
+ explanation = result.split("Bullet Explanation:")[-1].strip()
45
+ else:
46
+ explanation = result.strip()
47
+
48
  return explanation
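A minimal call for reference (a sketch; the dict mirrors the structure produced by graph_module, with made-up node ids):

flowchart = {
    "start": "node1",
    "steps": [
        {"id": "node1", "text": "Start", "type": "start", "next": "node2"},
        {"id": "node2", "text": "Is input valid?", "type": "decision",
         "parent": "node1", "branches": {"yes": "node3", "no": "node4"}},
    ],
}
print(summarize_flowchart(flowchart))   # bullet-style English summary from the local model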
yolo_module/__init__.py CHANGED
@@ -1,55 +1,110 @@
1
- # yolo_module.py
 
 
 
 
2
  from ultralytics import YOLO
3
  from device_config import get_device
4
- from PIL import Image, ImageDraw
5
  import numpy as np
6
  import easyocr
 
 
7
 
8
- # Load YOLO model
9
  MODEL_PATH = "models/best.pt"
10
  device = get_device()
11
  model = YOLO(MODEL_PATH).to(device)
12
  print(f"βœ… YOLO model loaded on: {device}")
13
 
 
14
 
15
- # Optional OCR reader for arrow label detection
16
- reader = easyocr.Reader(['en'], gpu=False)
 
17
 
 
18
  def run_yolo(image: Image.Image):
 
 
19
  results = model.predict(image, conf=0.25, verbose=False)[0]
20
 
21
  boxes = []
22
  arrows = []
23
-
24
- # Convert image to OpenCV format for EasyOCR
25
- np_img = np.array(image)
26
 
27
  for i, box in enumerate(results.boxes):
28
  cls_id = int(box.cls)
29
- conf = float(box.conf)
30
  label = model.names[cls_id]
31
 
32
  x1, y1, x2, y2 = map(int, box.xyxy[0])
33
  bbox = [x1, y1, x2, y2]
34
 
 
35
  item = {
36
  "id": f"node{i+1}",
37
  "bbox": bbox,
38
- "type": "arrow" if label in ["arrow", "control_flow"] else "box",
39
  "label": label
40
  }
41
 
42
- if item["type"] == "arrow":
43
- # Heuristically scan a small region near the middle of the arrow for a label
44
  cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
45
  pad = 20
46
  crop = np_img[max(cy - pad, 0):cy + pad, max(cx - pad, 0):cx + pad]
47
-
48
  detected_label = ""
49
  if crop.size > 0:
50
- ocr_results = reader.readtext(crop)
51
- if ocr_results:
52
- detected_label = ocr_results[0][1] # (bbox, text, conf)
 
 
 
53
 
54
  arrows.append({
55
  "id": f"arrow{len(arrows)+1}",
@@ -60,5 +115,10 @@ def run_yolo(image: Image.Image):
60
  else:
61
  boxes.append(item)
62
 
 
 
 
 
63
  vis_image = results.plot(pil=True)
64
- return boxes, arrows, vis_image
 
 
1
+ """
2
+ YOLO module for detecting flowchart elements (boxes and arrows).
3
+ Includes optional OCR for labeling arrows and deduplication to eliminate overlapping detections.
4
+ """
5
+
6
  from ultralytics import YOLO
7
  from device_config import get_device
8
+ from PIL import Image
9
  import numpy as np
10
  import easyocr
11
+ from shapely.geometry import box as shapely_box
12
+ import torch
13
 
14
+ # Load YOLO model and move to appropriate device
15
  MODEL_PATH = "models/best.pt"
16
  device = get_device()
17
  model = YOLO(MODEL_PATH).to(device)
18
  print(f"βœ… YOLO model loaded on: {device}")
19
 
20
+ # EasyOCR reader used for detecting optional labels near arrows
21
+ reader = easyocr.Reader(['en'], gpu=(device == "cuda"))
22
+
23
+
24
+ def iou(box1, box2):
25
+ """Compute Intersection over Union (IoU) between two bounding boxes."""
26
+ b1 = shapely_box(*box1)
27
+ b2 = shapely_box(*box2)
28
+ return b1.intersection(b2).area / b1.union(b2).area
29
+
30
 
31
+ def deduplicate_boxes(boxes, iou_threshold=0.6):
32
+ """
33
+ Eliminate overlapping or duplicate boxes based on IoU threshold.
34
 
35
+ Args:
36
+ boxes (list): List of box dictionaries with 'bbox' key.
37
+ iou_threshold (float): Threshold above which boxes are considered duplicates.
38
+
39
+ Returns:
40
+ list: Filtered list of unique boxes.
41
+ """
42
+ filtered = []
43
+ for box in boxes:
44
+ if all(iou(box['bbox'], other['bbox']) < iou_threshold for other in filtered):
45
+ filtered.append(box)
46
+ return filtered
47
+
48
+
49
+ @torch.no_grad()
50
  def run_yolo(image: Image.Image):
51
+ """
52
+ Run YOLO model on input image and return detected boxes, arrows, and annotated image.
53
+
54
+ Args:
55
+ image (PIL.Image): Input RGB image of a flowchart.
56
+
57
+ Returns:
58
+ tuple:
59
+ boxes (list of dict): Each box has id, bbox, type, label.
60
+ arrows (list of dict): Each arrow has id, tail, head, label.
61
+ vis_image (PIL.Image): Annotated image with detections drawn.
62
+ """
63
  results = model.predict(image, conf=0.25, verbose=False)[0]
64
 
65
  boxes = []
66
  arrows = []
67
+ np_img = np.array(image) # Convert image to numpy array for OCR crops
 
 
68
 
69
  for i, box in enumerate(results.boxes):
70
  cls_id = int(box.cls)
 
71
  label = model.names[cls_id]
72
 
73
  x1, y1, x2, y2 = map(int, box.xyxy[0])
74
  bbox = [x1, y1, x2, y2]
75
 
76
+ width = x2 - x1
77
+ height = y2 - y1
78
+ aspect_ratio = width / height if height > 0 else 0  # guard against zero-height boxes
79
+
80
+ # Default type assignment
81
+ item_type = "arrow" if label in ["arrow", "control_flow"] else "box"
82
+
83
+ # Adjust to 'decision' if it's nearly square (likely diamond shape)
84
+ if item_type == "box" and 0.8 < aspect_ratio < 1.2:
85
+ item_type = "decision"
86
+
87
+ # Create basic detection item
88
  item = {
89
  "id": f"node{i+1}",
90
  "bbox": bbox,
91
+ "type": item_type,
92
  "label": label
93
  }
94
 
95
+ if item_type == "arrow":
96
+ # Extract small patch at arrow center for OCR label
97
  cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
98
  pad = 20
99
  crop = np_img[max(cy - pad, 0):cy + pad, max(cx - pad, 0):cx + pad]
 
100
  detected_label = ""
101
  if crop.size > 0:
102
+ try:
103
+ ocr_results = reader.readtext(crop)
104
+ if ocr_results:
105
+ detected_label = ocr_results[0][1].strip().lower()
106
+ except Exception as e:
107
+ print(f"⚠️ Arrow OCR failed: {e}")
108
 
109
  arrows.append({
110
  "id": f"arrow{len(arrows)+1}",
 
115
  else:
116
  boxes.append(item)
117
 
118
+ # Remove overlapping duplicate boxes
119
+ boxes = deduplicate_boxes(boxes)
120
+
121
+ # Create annotated image with bounding boxes
122
  vis_image = results.plot(pil=True)
123
+
124
+ return boxes, arrows, vis_image
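A quick check of the IoU-based deduplication (a sketch with made-up boxes; note that importing yolo_module also loads the YOLO model and EasyOCR reader):

from yolo_module import deduplicate_boxes

boxes = [
    {"id": "node1", "bbox": [0, 0, 100, 50], "type": "box", "label": "process"},
    {"id": "node2", "bbox": [2, 2, 100, 50], "type": "box", "label": "process"},   # near-duplicate of node1
    {"id": "node3", "bbox": [0, 200, 100, 250], "type": "box", "label": "process"},
]
print([b["id"] for b in deduplicate_boxes(boxes, iou_threshold=0.6)])
# ['node1', 'node3']  (node2 overlaps node1 with IoU around 0.94, above the threshold)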