Venkat V committed
Commit · 152df72
Parent(s): 6ea5d07
updated with fixes to all modules

Files changed:
- api_backend.py +42 -17
- app.py +66 -53
- graph_module/__init__.py +105 -43
- ocr_module/__init__.py +106 -74
- ocr_module/__init__pyt.py +0 -135
- requirements.txt +4 -0
- summarizer_module/__init__.py +28 -18
- yolo_module/__init__.py +77 -17
api_backend.py
CHANGED
@@ -1,3 +1,14 @@
"""
api_backend.py

FastAPI backend for flowchart-to-English processing. This API supports receiving
an image file, running YOLO-based detection to identify boxes and arrows, performing
OCR, and generating structured JSON + English summary of the flowchart.

Endpoints:
- POST /process-image: Accepts image input and returns structured flowchart data.
"""

from fastapi import FastAPI, UploadFile, File, Form
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
@@ -7,25 +18,41 @@ import io
import json
import base64

# Import local processing modules
from yolo_module import run_yolo
from ocr_module import extract_text, count_elements, validate_structure
from graph_module import map_arrows, build_flowchart_json
from summarizer_module import summarize_flowchart

# Initialize FastAPI app
app = FastAPI()

# Enable CORS to allow frontend (e.g., Streamlit on localhost) to connect
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, replace with allowed frontend domain
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/process-image")
async def process_image(
    file: UploadFile = File(...),
    debug: str = Form("false")
):
    """
    Receives an uploaded flowchart image, performs object detection and OCR,
    constructs a structured flowchart JSON, and generates a plain-English summary.

    Args:
        file (UploadFile): Flowchart image file (.png, .jpg, .jpeg).
        debug (str): "true" to enable debug mode (includes OCR logs and YOLO preview).

    Returns:
        JSONResponse: Contains flowchart structure, summary, debug output, and optional YOLO overlay.
    """
    debug_mode = debug.lower() == "true"
    debug_log = []

@@ -33,35 +60,31 @@ async def process_image(file: UploadFile = File(...), debug: str = Form("false"))
        debug_log.append("Received file upload")
    print(f"File received: {file.filename}")

    # Convert file bytes to RGB image
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    if debug_mode:
        debug_log.append("Image converted to RGB")
        print("Image converted to RGB")

    # YOLO Detection for boxes and arrows
    boxes, arrows, vis_debug = run_yolo(image)
    if debug_mode:
        debug_log.append(f"Detected {len(boxes)} boxes, {len(arrows)} arrows")

    # Run OCR on each detected box
    for box in boxes:
        box["text"] = extract_text(image, box["bbox"], debug=debug_mode)
        print(f"OCR for {box['id']}: {box['text']}")
        if debug_mode:
            debug_log.append(f"{box['id']}: {box['text']}")

    # Build structured JSON from nodes and edges
    flowchart_json = build_flowchart_json(boxes, arrows)
    print("Flowchart JSON:", json.dumps(flowchart_json, indent=2))

    # Validate structure
    structure_info = count_elements(boxes, arrows, debug=debug_mode)
    validation = validate_structure(
        flowchart_json,
@@ -72,17 +95,18 @@ async def process_image(file: UploadFile = File(...), debug: str = Form("false"))
    if debug_mode:
        debug_log.append(f"Validation: {validation}")

    # Generate plain-English summary
    summary = summarize_flowchart(flowchart_json)
    print("Summary:", summary)

    # Encode YOLO debug image (if debug enabled)
    yolo_vis = None
    if debug_mode and vis_debug:
        vis_io = io.BytesIO()
        vis_debug.save(vis_io, format="PNG")
        yolo_vis = base64.b64encode(vis_io.getvalue()).decode("utf-8")

    # Return full response
    return JSONResponse({
        "flowchart": flowchart_json,
        "summary": summary,
@@ -92,4 +116,5 @@ async def process_image(file: UploadFile = File(...), debug: str = Form("false"))


if __name__ == "__main__":
    # Run the FastAPI app using Uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
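For quick manual testing of the updated endpoint, a client call might look like the sketch below. Assumptions: the backend is running locally on port 7860 as in the __main__ block, flowchart.png is a placeholder path, and the form field names file and debug match the handler signature above.

# Illustrative client for POST /process-image; not part of the commit.
import requests

with open("flowchart.png", "rb") as f:  # placeholder input image
    response = requests.post(
        "http://localhost:7860/process-image",
        files={"file": ("flowchart.png", f, "image/png")},
        data={"debug": "true"},
    )

payload = response.json()
print(payload["summary"])    # plain-English summary
print(payload["flowchart"])  # structured flowchart JSON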
app.py
CHANGED
@@ -1,76 +1,89 @@
# app.py
"""
Streamlit Frontend App: Uploads a flowchart image, sends it to FastAPI backend,
and displays the structured JSON and English summary. Supports multiple OCR engines.
"""

import streamlit as st
from PIL import Image
import requests
import base64
import io
import os

# Set up Streamlit UI layout
st.set_page_config(page_title="Flowchart to English", layout="wide")
st.title("Flowchart to Plain English")

# Enable debug mode toggle
debug_mode = st.toggle("Show Debug Info", value=False)

# OCR engine selection dropdown
ocr_engine = st.selectbox("Select OCR Engine", ["easyocr", "doctr"], index=0,
                          help="Choose between EasyOCR (lightweight) and Doctr (transformer-based)")

# Flowchart image uploader
uploaded_file = st.file_uploader("Upload a flowchart image", type=["png", "jpg", "jpeg"])

# Backend API URL (defaults to localhost for dev)
API_URL = os.getenv("API_URL", "http://localhost:7860/process-image")

if uploaded_file:
    # Load and resize uploaded image for preview
    image = Image.open(uploaded_file).convert("RGB")
    max_width = 600
    ratio = max_width / float(image.size[0])
    resized_image = image.resize((max_width, int(image.size[1] * ratio)))
    st.image(resized_image, caption="Uploaded Image", use_container_width=False)

    if st.button("Analyze Flowchart"):
        progress = st.progress(0, text="Sending image to backend...")

        try:
            # Send request to FastAPI backend
            response = requests.post(
                API_URL,
                files={"file": uploaded_file.getvalue()},
                data={
                    "debug": str(debug_mode).lower(),
                    "ocr_engine": ocr_engine
                }
            )

            progress.progress(40, text="Processing detection and OCR...")

            if response.status_code == 200:
                result = response.json()

                # Show debug info if enabled
                if debug_mode:
                    st.markdown("### Debug Info")
                    st.code(result.get("debug", ""), language="markdown")

                # Show YOLO visual if available
                if debug_mode and result.get("yolo_vis"):
                    st.markdown("### YOLO Detected Bounding Boxes")
                    yolo_bytes = base64.b64decode(result["yolo_vis"])
                    yolo_img = Image.open(io.BytesIO(yolo_bytes))
                    st.image(yolo_img, caption="YOLO Boxes", use_container_width=True)

                progress.progress(80, text="Finalizing output...")

                # Show flowchart JSON and generated English summary
                col1, col2 = st.columns(2)
                with col1:
                    st.subheader("Flowchart JSON")
                    st.json(result["flowchart"])
                with col2:
                    st.subheader("English Summary")
                    st.markdown(result["summary"])

                progress.progress(100, text="Done!")

            else:
                st.error(f"Backend Error: {response.status_code} - {response.text}")

        except Exception as e:
            st.error(f"Request Failed: {e}")
else:
    st.info("Upload a flowchart image to begin.")
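For reference, the response object this frontend unpacks has roughly the following shape. The field names come from the code above and from api_backend.py; the values here are invented for illustration, and the exact contents of the debug field are elided in this diff.

# Illustrative response payload consumed by app.py; values are made up.
result = {
    "flowchart": {"start": "node1", "steps": []},  # from build_flowchart_json
    "summary": "- Start\n- Validate input\n  - Yes: ...\n  - No: ...",  # from summarize_flowchart
    "debug": "Received file upload ...",           # debug log text (debug mode only)
    "yolo_vis": "<base64-encoded PNG>",            # annotated YOLO overlay (debug mode only)
}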
graph_module/__init__.py
CHANGED
@@ -1,34 +1,83 @@
from shapely.geometry import box as shapely_box, Point
from collections import defaultdict, deque
import math


MAX_FALLBACK_DIST = 150  # pixels

def map_arrows(nodes, arrows):
    """
    Map arrows to node boxes using geometric containment with fallback to nearest box.
    Returns directional edges (source_id, target_id, label).

    Args:
        nodes (list): List of node dicts with 'bbox' field.
        arrows (list): List of arrow dicts with 'tail' and 'head' coordinates.

    Returns:
        list: List of (source, target, label) tuples.
    """
    for node in nodes:
        node["shape"] = shapely_box(*node["bbox"])
        node["center"] = (
            (node["bbox"][0] + node["bbox"][2]) // 2,
            (node["bbox"][1] + node["bbox"][3]) // 2
        )

    edges = []

    def find_nearest_node(pt):
        min_dist = float("inf")
        nearest_id = None
        for n in nodes:
            cx, cy = n["center"]
            dist = math.dist(pt, (cx, cy))
            if dist < min_dist:
                min_dist = dist
                nearest_id = n["id"]
        return nearest_id, min_dist

    for arrow in arrows:
        if not isinstance(arrow, dict) or not isinstance(arrow.get("tail"), (tuple, list)) or not isinstance(arrow.get("head"), (tuple, list)):
            print(f"Skipping malformed arrow: {arrow}")
            continue

        tail_pt = Point(arrow["tail"])
        head_pt = Point(arrow["head"])
        label = arrow.get("label", "").strip()

        source = next((n["id"] for n in nodes if n["shape"].buffer(10).contains(tail_pt)), None)
        target = next((n["id"] for n in nodes if n["shape"].buffer(10).contains(head_pt)), None)

        # Fallback to nearest node if not found
        if not source:
            source, dist = find_nearest_node(arrow["tail"])
            if dist > MAX_FALLBACK_DIST:
                source = None
        if not target:
            target, dist = find_nearest_node(arrow["head"])
            if dist > MAX_FALLBACK_DIST:
                target = None

        if source and target and source != target:
            print(f"Mapped arrow from {source} -> {target} [{label}]")
            edges.append((source, target, label))
        else:
            print(f"Could not map arrow endpoints to nodes: tail={arrow.get('tail')} head={arrow.get('head')}")

    return edges


def detect_node_type(text, default_type="process"):
    """
    Heuristically infer the node type from its text.

    Args:
        text (str): Node label.
        default_type (str): Fallback type.

    Returns:
        str: Inferred node type.
    """
    text_lower = text.lower()
    if "start" in text_lower:
@@ -37,37 +86,43 @@ def detect_node_type(text):
        return "end"
    if "?" in text or "yes" in text_lower or "no" in text_lower:
        return "decision"
    return default_type


def build_flowchart_json(nodes, arrows):
    """
    Construct a structured flowchart JSON using basic graph traversal.

    Args:
        nodes (list): Detected node dicts.
        arrows (list): Detected arrow dicts.

    Returns:
        dict: JSON with 'start' and 'steps'.
    """
    edges = map_arrows(nodes, arrows)

    # Build adjacency and reverse mappings
    graph = defaultdict(list)
    reverse_links = defaultdict(list)
    edge_labels = {}

    for src, tgt, label in edges:
        graph[src].append(tgt)
        reverse_links[tgt].append(src)
        edge_labels[(src, tgt)] = label.lower()

    all_node_ids = {n["id"] for n in nodes}
    start_candidates = [nid for nid in all_node_ids if nid not in reverse_links]

    flowchart = {
        "start": start_candidates[0] if start_candidates else None,
        "steps": []
    }

    visited = set()
    queue = deque(start_candidates)
    id_to_node = {n["id"]: n for n in nodes}

    while queue:
        curr = queue.popleft()
@@ -75,21 +130,24 @@ def build_flowchart_json(nodes, edges):
            continue
        visited.add(curr)

        node = id_to_node.get(curr, {})
        text = node.get("text", "").strip()
        node_type = node.get("type") or detect_node_type(text)

        step = {
            "id": curr,
            "text": text,
            "type": node_type
        }

        parents = list(set(reverse_links[curr]))
        if len(parents) == 1:
            step["parent"] = parents[0]
        elif len(parents) > 1:
            step["parents"] = parents

        next_nodes = list(set(graph[curr]))
        if node_type == "decision" and next_nodes:
            step["branches"] = {}
            for tgt in next_nodes:
                label = edge_labels.get((curr, tgt), "")
@@ -99,14 +157,18 @@ def build_flowchart_json(nodes, edges):
                    step["branches"]["no"] = tgt
                else:
                    step["branches"].setdefault("unknown", []).append(tgt)
                if tgt not in visited:
                    queue.append(tgt)
        elif len(next_nodes) == 1:
            step["next"] = next_nodes[0]
            if next_nodes[0] not in visited:
                queue.append(next_nodes[0])
        elif len(next_nodes) > 1:
            step["next"] = next_nodes
            for n in next_nodes:
                if n not in visited:
                    queue.append(n)

        flowchart["steps"].append(step)

    return flowchart
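As an illustration of the structure build_flowchart_json now produces, a small chart with one decision would come out roughly as below; node ids, texts, and branch targets are invented.

# Illustrative output shape of build_flowchart_json (not taken from a real run).
example = {
    "start": "node1",
    "steps": [
        {"id": "node1", "text": "Start", "type": "start", "next": "node2"},
        {"id": "node2", "text": "Is input valid?", "type": "decision", "parent": "node1",
         "branches": {"yes": "node3", "no": "node4"}},
        {"id": "node3", "text": "Process input", "type": "process", "parent": "node2"},
        {"id": "node4", "text": "Show error", "type": "process", "parent": "node2"},
    ],
}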
ocr_module/__init__.py
CHANGED
@@ -1,19 +1,40 @@
"""
OCR module with support for EasyOCR and Doctr.
Provides the `extract_text` function that accepts a cropped bounding box and image,
and runs OCR based on the selected engine ("easyocr" or "doctr").
"""

import numpy as np
from PIL import Image
import cv2
from textblob import TextBlob
from device_config import get_device

# OCR engine flags
USE_EASYOCR = True
USE_DOCTR = False

# Import EasyOCR if available
try:
    import easyocr
    reader = easyocr.Reader(['en'], gpu=(get_device() == "cuda"))
    print(f"EasyOCR reader initialized on: {get_device()}")
    USE_EASYOCR = True
except ImportError:
    print("EasyOCR not installed. Falling back if Doctr is available.")

# Import Doctr if available
try:
    from doctr.io import DocumentFile
    from doctr.models import ocr_predictor
    doctr_model = ocr_predictor(pretrained=True)
    print("Doctr model loaded.")
    USE_DOCTR = True
except ImportError:
    print("Doctr not installed.")

def expand_bbox(bbox, image_size, pad=10):
    """Expand a bounding box by padding within image bounds."""
    x1, y1, x2, y2 = bbox
    x1 = max(0, x1 - pad)
    y1 = max(0, y1 - pad)
@@ -22,104 +43,115 @@ def expand_bbox(bbox, image_size, pad=10):
    return [x1, y1, x2, y2]

def clean_text(text):
    """Use TextBlob to autocorrect basic OCR errors."""
    blob = TextBlob(text)
    return str(blob.correct())

def extract_text(image, bbox, debug=False, engine="easyocr"):
    """
    Run OCR on a cropped region using EasyOCR or Doctr.

    Parameters:
        image (PIL.Image): Full input image.
        bbox (list): [x1, y1, x2, y2] bounding box.
        debug (bool): Enable debug output.
        engine (str): 'easyocr' or 'doctr'.

    Returns:
        str: Cleaned OCR output.
    """
    # Expand and crop image region
    bbox = expand_bbox(bbox, image.size, pad=10)
    x1, y1, x2, y2 = bbox
    cropped = image.crop((x1, y1, x2, y2))

    # Convert to OpenCV grayscale
    cv_img = np.array(cropped)
    gray = cv2.cvtColor(cv_img, cv2.COLOR_RGB2GRAY)

    # Enhance contrast using CLAHE (Contrast Limited Adaptive Histogram Equalization)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    enhanced = clahe.apply(gray)

    # Apply adaptive threshold for better text separation
    thresh = cv2.adaptiveThreshold(enhanced, 255, cv2.ADAPTIVE_THRESH_MEAN_C, cv2.THRESH_BINARY, 11, 4)

    # Resize for better OCR resolution
    resized = cv2.resize(thresh, (0, 0), fx=2.5, fy=2.5, interpolation=cv2.INTER_LINEAR)

    # Convert to RGB (some OCR engines expect 3-channel images)
    preprocessed = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)

    if debug:
        Image.fromarray(preprocessed).save(f"debug_ocr_crop_{x1}_{y1}.png")

    if engine == "doctr" and USE_DOCTR:
        try:
            doc = DocumentFile.from_images([Image.fromarray(preprocessed)])
            result = doctr_model(doc)
            out_text = " ".join([b.value for b in result.pages[0].blocks])
            if debug:
                print(f"Doctr OCR: {out_text}")
            return clean_text(out_text)
        except Exception as e:
            if debug:
                print(f"Doctr failed: {e}")
            return ""

    elif engine == "easyocr" and USE_EASYOCR:
        try:
            results = reader.readtext(preprocessed, paragraph=False, min_size=10)
            filtered = []
            for r in results:
                text = r[1].strip()
                conf = r[2]
                if conf > 0.5 and len(text) > 2 and any(c.isalnum() for c in text):
                    filtered.append(r)

            # Remove duplicates by bounding box IoU overlap
            final = []
            seen = set()
            for r in filtered:
                t = r[1].strip()
                if t.lower() not in seen:
                    seen.add(t.lower())
                    final.append(r)

            final.sort(key=lambda r: (r[0][0][1], r[0][0][0]))
            text = " ".join([r[1] for r in final]).strip()

            if debug:
                for r in final:
                    print(f"EasyOCR: {r[1]} (conf: {r[2]:.2f})")

            return clean_text(text) if text else ""
        except Exception as e:
            if debug:
                print(f"EasyOCR failed: {e}")
            return ""

    else:
        if debug:
            print(f"Unsupported OCR engine: {engine} or not available.")
        return ""

def count_elements(boxes, arrows, debug=False):
    """Return count of boxes and arrows detected."""
    box_count = len(boxes)
    arrow_count = len(arrows)
    if debug:
        print(f"Boxes: {box_count} | Arrows: {arrow_count}")
    return {"box_count": box_count, "arrow_count": arrow_count}

def validate_structure(flowchart_json, expected_boxes=None, expected_arrows=None, debug=False):
    """Validate flowchart structure consistency based on expected counts."""
    actual_boxes = len(flowchart_json.get("steps", []))
    actual_arrows = len(flowchart_json.get("edges", [])) if "edges" in flowchart_json else None

    if debug:
        print(f"JSON boxes: {actual_boxes}, edges: {actual_arrows}")

    return {
        "boxes_valid": (expected_boxes is None or expected_boxes == actual_boxes),
        "arrows_valid": (expected_arrows is None or expected_arrows == actual_arrows)
    }
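A small usage sketch of the new extract_text signature; the image path and bounding box are placeholders, and each engine silently returns an empty string when its backend is not installed, as in the code above.

# Illustrative call into ocr_module.extract_text with both engines; not part of the commit.
from PIL import Image
from ocr_module import extract_text

img = Image.open("flowchart.png").convert("RGB")  # placeholder path
bbox = [40, 60, 220, 120]                         # [x1, y1, x2, y2] of one detected box

text = extract_text(img, bbox, debug=True, engine="easyocr")
if not text:
    text = extract_text(img, bbox, debug=True, engine="doctr")
print(text)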
ocr_module/__init__pyt.py
DELETED
@@ -1,135 +0,0 @@
import easyocr
from PIL import Image
import numpy as np
import cv2
import torch
from textblob import TextBlob
import re

# Enable GPU if available
use_gpu = torch.cuda.is_available()
reader = easyocr.Reader(['en'], gpu=use_gpu)

def expand_bbox(bbox, image_size, pad=10):
    x1, y1, x2, y2 = bbox
    x1 = max(0, x1 - pad)
    y1 = max(0, y1 - pad)
    x2 = min(image_size[0], x2 + pad)
    y2 = min(image_size[1], y2 + pad)
    return [x1, y1, x2, y2]

def clean_text(text):
    # Basic cleanup
    text = re.sub(r'[^A-Za-z0-9?,.:;()\'"\s-]', '', text)  # remove noise characters
    text = re.sub(r'\s+', ' ', text).strip()

    # De-duplicate repeated words
    words = text.split()
    deduped = [words[0]] + [w for i, w in enumerate(words[1:], 1) if w.lower() != words[i - 1].lower()] if words else []
    joined = " ".join(deduped)

    # Run correction only if needed (long word or all caps)
    if len(joined) > 3 and any(len(w) > 10 or w.isupper() for w in deduped):
        blob = TextBlob(joined)
        joined = str(blob.correct())

    return joined

def extract_text(image, bbox, debug=False, use_adaptive_threshold=False):
    """
    Run OCR on a cropped region of the image using EasyOCR with preprocessing.

    Parameters:
        image (PIL.Image): The full image.
        bbox (list): [x1, y1, x2, y2] coordinates of the region to crop.
        debug (bool): If True, show intermediate debug output.
        use_adaptive_threshold (bool): Use adaptive thresholding instead of Otsu's.

    Returns:
        str: Extracted and cleaned text.
    """
    # Expand bbox slightly
    bbox = expand_bbox(bbox, image.size, pad=10)
    x1, y1, x2, y2 = bbox
    cropped = image.crop((x1, y1, x2, y2))

    # Convert to OpenCV format (numpy array)
    cv_img = np.array(cropped)

    # Convert to grayscale
    gray = cv2.cvtColor(cv_img, cv2.COLOR_RGB2GRAY)

    # Apply Gaussian blur to reduce noise
    blurred = cv2.GaussianBlur(gray, (3, 3), 0)

    # Resize (upscale) image for better OCR accuracy
    scale_factor = 2.5
    resized = cv2.resize(blurred, (0, 0), fx=scale_factor, fy=scale_factor, interpolation=cv2.INTER_LINEAR)

    # Convert to RGB as EasyOCR expects color image
    resized_rgb = cv2.cvtColor(resized, cv2.COLOR_GRAY2RGB)

    # Optional: debug save
    if debug:
        debug_image = Image.fromarray(resized_rgb)
        debug_image.save(f"debug_ocr_crop_{x1}_{y1}.png")

    # Run OCR using EasyOCR
    try:
        results = reader.readtext(resized_rgb, paragraph=False, min_size=5)
    except Exception as e:
        if debug:
            print(f"EasyOCR failed: {e}")
        return ""

    if debug:
        for res in results:
            print(f"OCR: {res[1]} (conf: {res[2]:.2f})")

    # Sort boxes top to bottom, then left to right
    results.sort(key=lambda r: (r[0][0][1], r[0][0][0]))

    # Filter by confidence
    filtered = [r for r in results if r[2] > 0.4]
    if not filtered and results:
        filtered = sorted(results, key=lambda r: -r[2])[:2]  # fallback to top-2

    lines = []
    for res in filtered:
        lines.append(res[1])

    joined_text = " ".join(lines).strip()

    # Apply correction
    if joined_text:
        joined_text = clean_text(joined_text)
        if debug:
            print(f"Cleaned OCR text: {joined_text}")

    return joined_text

def count_elements(boxes, arrows, debug=False):
    box_count = len(boxes)
    arrow_count = len(arrows)
    if debug:
        print(f"Detected {box_count} boxes")
        print(f"Detected {arrow_count} arrows")
    return {
        "box_count": box_count,
        "arrow_count": arrow_count
    }

def validate_structure(flowchart_json, expected_boxes=None, expected_arrows=None, debug=False):
    actual_boxes = len(flowchart_json.get("steps", []))
    actual_arrows = len(flowchart_json.get("edges", [])) if "edges" in flowchart_json else None

    if debug:
        print(f"Flowchart JSON has {actual_boxes} steps")
        if actual_arrows is not None:
            print(f"Flowchart JSON has {actual_arrows} edges")

    result = {
        "boxes_valid": (expected_boxes is None or expected_boxes == actual_boxes),
        "arrows_valid": (expected_arrows is None or expected_arrows == actual_arrows)
    }
    return result
requirements.txt
CHANGED
@@ -20,6 +20,10 @@ numpy # Core image array operations
easyocr              # GPU-capable OCR engine
textblob             # Optional: lightweight text post-processing (optional)

# --- Doctr dependencies (torch-based) ---
python-doctr[torch]
onnxruntime          # Required backend for Doctr inference

# Object Detection and Language Models
ultralytics          # YOLOv8/v9 detection (loads .pt models)
torch                # Backend for YOLO and EasyOCR
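With these additions, installing the requirements pulls in python-doctr with its torch extra plus onnxruntime alongside the existing EasyOCR/TextBlob stack; the try/except imports in ocr_module above then decide at runtime which OCR engine is actually available.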
summarizer_module/__init__.py
CHANGED
@@ -1,38 +1,48 @@
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from device_config import get_device
import torch
import json

# Automatically choose device (CUDA, MPS, CPU)
device = get_device()

# Model config: Phi-4-mini instruct model
MODEL_ID = "microsoft/Phi-4-mini-instruct"

# Load tokenizer and model
model = AutoModelForCausalLM.from_pretrained(MODEL_ID).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
summarizer = pipeline("text-generation", model=model, tokenizer=tokenizer)

def summarize_flowchart(flowchart_json):
    """
    Generates a human-friendly explanation from flowchart JSON.

    Args:
        flowchart_json (dict): Contains "start" node and a list of "steps".

    Returns:
        str: Bullet-style explanation with proper nesting and flow.
    """
    # Prompt optimized for flow comprehension
    prompt = (
        "You are an expert in visual reasoning and instruction generation.\n"
        "Convert the following flowchart JSON into a clear, step-by-step summary using bullets.\n"
        "- Each bullet represents a process step.\n"
        "- Use indented sub-bullets to explain decision branches (Yes/No).\n"
        "- Maintain order based on dependencies and parent-child links.\n"
        "- Avoid repeating the same step more than once.\n"
        "- Do not include JSON in the output, only human-readable text.\n"
        "\nFlowchart:\n{flowchart}\n\nBullet Explanation:"
    ).format(flowchart=json.dumps(flowchart_json, indent=2))

    # Run the model inference
    result = summarizer(prompt, max_new_tokens=400, do_sample=False)[0]["generated_text"]

    # Extract the portion after the final prompt marker
    if "Bullet Explanation:" in result:
        explanation = result.split("Bullet Explanation:")[-1].strip()
    else:
        explanation = result.strip()

    return explanation
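A usage sketch, assuming a flowchart dict in the shape produced by graph_module; note that the Phi-4-mini-instruct weights are loaded at import time, so importing this module already requires the model to be available.

# Illustrative call; the input mirrors build_flowchart_json output (invented values).
from summarizer_module import summarize_flowchart

flowchart = {
    "start": "node1",
    "steps": [
        {"id": "node1", "text": "Start", "type": "start", "next": "node2"},
        {"id": "node2", "text": "Save record", "type": "process", "parent": "node1"},
    ],
}
print(summarize_flowchart(flowchart))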
yolo_module/__init__.py
CHANGED
@@ -1,55 +1,110 @@
"""
YOLO module for detecting flowchart elements (boxes and arrows).
Includes optional OCR for labeling arrows and deduplication to eliminate overlapping detections.
"""

from ultralytics import YOLO
from device_config import get_device
from PIL import Image
import numpy as np
import easyocr
from shapely.geometry import box as shapely_box
import torch

# Load YOLO model and move to appropriate device
MODEL_PATH = "models/best.pt"
device = get_device()
model = YOLO(MODEL_PATH).to(device)
print(f"YOLO model loaded on: {device}")

# EasyOCR reader used for detecting optional labels near arrows
reader = easyocr.Reader(['en'], gpu=(device == "cuda"))


def iou(box1, box2):
    """Compute Intersection over Union (IoU) between two bounding boxes."""
    b1 = shapely_box(*box1)
    b2 = shapely_box(*box2)
    return b1.intersection(b2).area / b1.union(b2).area


def deduplicate_boxes(boxes, iou_threshold=0.6):
    """
    Eliminate overlapping or duplicate boxes based on IoU threshold.

    Args:
        boxes (list): List of box dictionaries with 'bbox' key.
        iou_threshold (float): Threshold above which boxes are considered duplicates.

    Returns:
        list: Filtered list of unique boxes.
    """
    filtered = []
    for box in boxes:
        if all(iou(box['bbox'], other['bbox']) < iou_threshold for other in filtered):
            filtered.append(box)
    return filtered


@torch.no_grad()
def run_yolo(image: Image.Image):
    """
    Run YOLO model on input image and return detected boxes, arrows, and annotated image.

    Args:
        image (PIL.Image): Input RGB image of a flowchart.

    Returns:
        tuple:
            boxes (list of dict): Each box has id, bbox, type, label.
            arrows (list of dict): Each arrow has id, tail, head, label.
            vis_image (PIL.Image): Annotated image with detections drawn.
    """
    results = model.predict(image, conf=0.25, verbose=False)[0]

    boxes = []
    arrows = []
    np_img = np.array(image)  # Convert image to numpy array for OCR crops

    for i, box in enumerate(results.boxes):
        cls_id = int(box.cls)
        label = model.names[cls_id]

        x1, y1, x2, y2 = map(int, box.xyxy[0])
        bbox = [x1, y1, x2, y2]

        width = x2 - x1
        height = y2 - y1
        aspect_ratio = width / height

        # Default type assignment
        item_type = "arrow" if label in ["arrow", "control_flow"] else "box"

        # Adjust to 'decision' if it's nearly square (likely diamond shape)
        if item_type == "box" and 0.8 < aspect_ratio < 1.2:
            item_type = "decision"

        # Create basic detection item
        item = {
            "id": f"node{i+1}",
            "bbox": bbox,
            "type": item_type,
            "label": label
        }

        if item_type == "arrow":
            # Extract small patch at arrow center for OCR label
            cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
            pad = 20
            crop = np_img[max(cy - pad, 0):cy + pad, max(cx - pad, 0):cx + pad]
            detected_label = ""
            if crop.size > 0:
                try:
                    ocr_results = reader.readtext(crop)
                    if ocr_results:
                        detected_label = ocr_results[0][1].strip().lower()
                except Exception as e:
                    print(f"Arrow OCR failed: {e}")

            arrows.append({
                "id": f"arrow{len(arrows)+1}",
@@ -60,5 +115,10 @@ def run_yolo(image: Image.Image):
        else:
            boxes.append(item)

    # Remove overlapping duplicate boxes
    boxes = deduplicate_boxes(boxes)

    # Create annotated image with bounding boxes
    vis_image = results.plot(pil=True)

    return boxes, arrows, vis_image
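A quick sketch of calling run_yolo directly; the image path is a placeholder, and models/best.pt must exist because the module loads it at import time.

# Illustrative standalone use of yolo_module.run_yolo; not part of the commit.
from PIL import Image
from yolo_module import run_yolo

image = Image.open("flowchart.png").convert("RGB")  # placeholder path
boxes, arrows, vis = run_yolo(image)

for b in boxes:
    print(b["id"], b["type"], b["bbox"], b["label"])
print(f"{len(arrows)} arrows detected")
vis.save("yolo_annotated.png")  # annotated preview from results.plot(pil=True)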