Venkat V committed on
Commit
c842ab7
·
1 Parent(s): e2acd29

Made changes to fix Render deployment issues

Browse files
Files changed (6) hide show
  1. api_backend.py +95 -0
  2. app.py +47 -77
  3. graph_module/__init__.py +42 -60
  4. render.yaml +4 -2
  5. streamlit_app.py +0 -70
  6. yolo_module/__init__.py +22 -9
api_backend.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File, Form
2
+ from fastapi.middleware.cors import CORSMiddleware
3
+ from fastapi.responses import JSONResponse
4
+ import uvicorn
5
+ from PIL import Image
6
+ import io
7
+ import json
8
+ import base64
9
+
10
+ # 💡 Import modules
11
+ from yolo_module import run_yolo
12
+ from ocr_module import extract_text, count_elements, validate_structure
13
+ from flowchart_builder import map_arrows, build_flowchart_json # renamed for clarity
14
+ from summarizer_module import summarize_flowchart
15
+
16
+ app = FastAPI()
17
+
18
+ # 🔓 Enable CORS for Streamlit frontend
19
# Credentials are disabled on purpose: a wildcard origin combined with
# allow_credentials=True is not a valid CORS configuration. List explicit
# origins in allow_origins before turning credentials back on.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Update with actual domain if needed
    allow_credentials=False,
    allow_methods=["*"],
    allow_headers=["*"],
)
26
+
27
@app.post("/process-image")
async def process_image(file: UploadFile = File(...), debug: str = Form("false")):
    """Run the flowchart pipeline on an uploaded image and return the result as JSON.

    Steps: decode the upload, detect boxes/arrows with ``run_yolo``, OCR each
    box with ``extract_text``, map arrows to edges, build the flowchart JSON,
    validate it, and summarize it.

    Args:
        file: Uploaded image file.
        debug: Form field; the string "true" (case-insensitive) enables the
            debug log and the base64-encoded detection image in the response.

    Returns:
        JSONResponse with keys ``flowchart``, ``summary``, ``yolo_vis``
        (base64 PNG or ``None``) and ``debug`` (newline-joined log or "").
        If the upload cannot be decoded as an image, a 400 response with an
        ``error`` key is returned instead.
    """
    debug_mode = debug.lower() == "true"
    debug_log = []

    if debug_mode:
        debug_log.append("📥 Received file upload")
        print(f"📥 File received: {file.filename}")

    # 🖼️ Load image
    contents = await file.read()
    try:
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as exc:
        # Pillow signals unreadable/unsupported uploads with OSError (including
        # UnidentifiedImageError); report a client error instead of a 500.
        return JSONResponse(
            {"error": f"Uploaded file is not a readable image: {exc}"},
            status_code=400,
        )
    if debug_mode:
        debug_log.append("✅ Image converted to RGB")
        print("✅ Image converted to RGB")

    # NOTE(review): the pipeline calls below are synchronous inside an async
    # handler; presumably they are CPU-heavy and will block the event loop —
    # consider moving them to a threadpool. TODO confirm.

    # 📦 YOLO Detection
    boxes, arrows, vis_debug = run_yolo(image)
    if debug_mode:
        debug_log.append(f"📦 Detected {len(boxes)} boxes, {len(arrows)} arrows")

    # 🔍 OCR for each box
    for box in boxes:
        box["text"] = extract_text(image, box["bbox"], debug=debug_mode)
        print(f"🔍 OCR for {box['id']}: {box['text']}")
        if debug_mode:
            debug_log.append(f"🔍 {box['id']}: {box['text']}")

    # ➡️ Build directional edges
    edges = map_arrows(boxes, arrows)
    if debug_mode:
        debug_log.append(f"➡️ Mapped {len(edges)} directional edges")

    # 🧠 Build structured flowchart
    flowchart_json = build_flowchart_json(boxes, edges)
    print("🧠 Flowchart JSON:", json.dumps(flowchart_json, indent=2))

    # ✅ Sanity checks
    structure_info = count_elements(boxes, arrows, debug=debug_mode)
    validation = validate_structure(
        flowchart_json,
        expected_boxes=structure_info["box_count"],
        expected_arrows=len(arrows),
        debug=debug_mode,
    )
    if debug_mode:
        debug_log.append(f"🧾 Validation: {validation}")

    # ✍️ Generate Summary
    summary = summarize_flowchart(flowchart_json)
    print("📝 Summary:", summary)

    # 🖼️ Encode visual debug
    yolo_vis = None
    # Explicit None check: do not rely on the truthiness of the returned object.
    if debug_mode and vis_debug is not None:
        vis_io = io.BytesIO()
        vis_debug.save(vis_io, format="PNG")
        yolo_vis = base64.b64encode(vis_io.getvalue()).decode("utf-8")

    return JSONResponse({
        "flowchart": flowchart_json,
        "summary": summary,
        "yolo_vis": yolo_vis,
        "debug": "\n".join(debug_log) if debug_mode else "",
    })
92
+
93
+
94
if __name__ == "__main__":
    import os

    # Honor a platform-assigned PORT (as app.py does), falling back to 7860.
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
app.py CHANGED
@@ -1,94 +1,64 @@
1
- from fastapi import FastAPI, UploadFile, File, Form
2
- from fastapi.middleware.cors import CORSMiddleware
3
- from fastapi.responses import JSONResponse
4
- import uvicorn
5
  from PIL import Image
6
  import io
7
- import json
8
  import base64
 
9
 
 
 
 
 
10
 
11
- # Import pipeline modules
12
- from yolo_module import run_yolo
13
- from ocr_module import extract_text, count_elements, validate_structure
14
- from graph_module import map_arrows, build_flowchart_json
15
- from summarizer_module import summarize_flowchart
16
 
17
- app = FastAPI()
18
 
19
- # Allow Streamlit access
20
- app.add_middleware(
21
- CORSMiddleware,
22
- allow_origins=["*"],
23
- allow_credentials=True,
24
- allow_methods=["*"],
25
- allow_headers=["*"],
26
- )
27
 
28
- @app.post("/process-image")
29
- async def process_image(file: UploadFile = File(...), debug: str = Form("false")):
30
- debug_mode = debug.lower() == "true"
31
- debug_log = []
 
 
32
 
33
- if debug_mode:
34
- debug_log.append("📥 Received file: file")
35
- print("📥 Received file:", file.filename)
 
36
 
37
- contents = await file.read()
38
- image = Image.open(io.BytesIO(contents)).convert("RGB")
39
- if debug_mode:
40
- debug_log.append("✅ Image loaded and converted to RGB")
41
- print("✅ Image loaded and converted to RGB")
42
 
43
- # 🔁 Run YOLO
44
- boxes, arrows, vis_debug = run_yolo(image)
45
- if debug_mode:
46
- debug_log.append(f"📦 YOLO detected {len(boxes)} boxes and {len(arrows)} arrows")
47
 
48
- # 🔍 Run OCR
49
- for box in boxes:
50
- box["text"] = extract_text(image, box["bbox"], debug=debug_mode)
51
- if debug_mode:
52
- debug_log.append(f"🔍 OCR text for box {box['id']}: {box['text']}")
53
- print(f"🔍 OCR text for box {box['id']}: {box['text']}")
54
-
55
-
56
- # 🔗 Map arrows and build graph
57
- edges = map_arrows(boxes, arrows)
58
- if debug_mode:
59
- debug_log.append(f"🧭 Mapped {len(edges)} edges from arrows to boxes")
60
-
61
- flowchart_json = build_flowchart_json(boxes, edges)
62
- print("🧠 Flowchart JSON structure:")
63
- print(json.dumps(flowchart_json, indent=2))
64
 
65
- # 🧮 Validate and count
66
- structure_info = count_elements(boxes, arrows, debug=debug_mode)
67
- validation = validate_structure(flowchart_json, expected_boxes=structure_info["box_count"], expected_arrows=len(arrows), debug=debug_mode)
68
- if debug_mode:
69
- debug_log.append(f"🧾 Validation: {validation}")
 
 
70
 
 
 
 
71
 
72
- # 📝 Summarize
73
- summary = summarize_flowchart(flowchart_json)
74
- print("📝 Generated English summary:")
75
- print(summary)
76
-
77
- # Optional: encode vis_debug for streamlit
78
- yolo_vis = None
79
- if debug_mode and vis_debug:
80
- vis_io = io.BytesIO()
81
- vis_debug.save(vis_io, format="PNG")
82
- vis_io.seek(0)
83
- yolo_vis = base64.b64encode(vis_io.read()).decode("utf-8")
84
-
85
- return JSONResponse({
86
- "flowchart": flowchart_json,
87
- "summary": summary,
88
- "yolo_vis": yolo_vis, # ✅ key must match what Streamlit expects
89
- "debug": "\n".join(debug_log) if debug_mode else ""
90
- })
91
-
92
 
 
93
  if __name__ == "__main__":
94
- uvicorn.run(app, host="0.0.0.0", port=7860)
 
 
 
 
 
1
+ import streamlit as st
 
 
 
2
  from PIL import Image
3
  import io
 
4
  import base64
5
+ import os
6
 
7
+ from yolo_module import run_yolo # Local detection module
8
+ from ocr_module import extract_text # EasyOCR wrapper
9
+ from flowchart_builder import map_arrows, build_flowchart_json
10
+ from summarizer import summarize_flowchart # Your LLM logic
11
 
12
+ st.set_page_config(page_title="Flowchart to English", layout="wide")
13
+ st.title("📄 Flowchart to Plain English")
 
 
 
14
 
15
+ debug_mode = st.toggle("🔧 Show Debug Info", value=False)
16
 
17
+ uploaded_file = st.file_uploader("Upload a flowchart image", type=["png", "jpg", "jpeg"])
 
 
 
 
 
 
 
18
 
19
if uploaded_file:
    # Normalize to RGB so detection and OCR receive the same pixel format the
    # api_backend.py pipeline uses (uploads such as PNGs may carry alpha/palette).
    image = Image.open(uploaded_file).convert("RGB")

    # Show a preview scaled to a fixed width, preserving the aspect ratio.
    max_width = 600
    ratio = max_width / float(image.size[0])
    resized_image = image.resize((max_width, int(image.size[1] * ratio)))
    st.image(resized_image, caption="📤 Uploaded Image", use_container_width=False)

    if st.button("🔍 Analyze Flowchart"):
        progress = st.progress(0, text="Detecting boxes and arrows...")
        results, arrows, vis_debug = run_yolo(image)
        progress.progress(30, text="Running OCR...")

        # Add text to results
        for node in results:
            node["text"] = extract_text(image, node["bbox"], debug=debug_mode)

        progress.progress(60, text="Building flowchart structure...")
        edges = map_arrows(results, arrows)
        flowchart = build_flowchart_json(results, edges)

        progress.progress(80, text="Generating plain English explanation...")
        summary = summarize_flowchart(flowchart)

        # Structured result on the left, narrative summary on the right.
        col1, col2 = st.columns(2)
        with col1:
            st.subheader("🧠 Flowchart JSON")
            st.json(flowchart)
        with col2:
            st.subheader("📝 English Summary")
            st.markdown(summary)

        if debug_mode:
            st.subheader("🖼️ YOLO Visual Debug")
            st.image(vis_debug, caption="Detected Boxes & Arrows", use_container_width=True)

        progress.progress(100, text="Done!")
else:
    st.info("Upload a flowchart image to begin.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
# For Render compatibility: lets `python app.py` launch the Streamlit server.
if __name__ == "__main__":
    import sys

    import streamlit.web.cli as stcli
    from streamlit import runtime

    # Streamlit executes this script with __name__ == "__main__" as well, so
    # only bootstrap the CLI when no Streamlit runtime is active. Otherwise
    # `streamlit run app.py` would re-invoke the CLI on every script run.
    if not runtime.exists():
        port = int(os.environ.get("PORT", 7860))
        sys.argv = ["streamlit", "run", "app.py", "--server.port", str(port), "--server.address=0.0.0.0"]
        sys.exit(stcli.main())
graph_module/__init__.py CHANGED
@@ -1,8 +1,14 @@
 
 
 
1
  from shapely.geometry import box, Point
2
  from collections import defaultdict, deque
3
 
4
-
5
  def map_arrows(nodes, arrows):
 
 
 
 
6
  for node in nodes:
7
  node["shape"] = box(*node["bbox"])
8
 
@@ -10,24 +16,20 @@ def map_arrows(nodes, arrows):
10
  for arrow in arrows:
11
  tail_point = Point(arrow["tail"])
12
  head_point = Point(arrow["head"])
13
-
14
- source = None
15
- target = None
16
- for node in nodes:
17
- if node["shape"].contains(tail_point):
18
- source = node["id"]
19
- if node["shape"].contains(head_point):
20
- target = node["id"]
21
-
22
  label = arrow.get("label", "")
23
 
 
 
 
24
  if source and target and source != target:
25
  edges.append((source, target, label))
26
 
27
  return edges
28
 
29
-
30
  def detect_node_type(text):
 
 
 
31
  text_lower = text.lower()
32
  if "start" in text_lower:
33
  return "start"
@@ -37,25 +39,26 @@ def detect_node_type(text):
37
  return "decision"
38
  return "process"
39
 
40
-
41
  def build_flowchart_json(nodes, edges):
 
 
 
42
  graph = {}
43
  reverse_links = defaultdict(list)
44
- edge_labels = defaultdict(list)
45
 
46
  for node in nodes:
47
- raw_text = node.get("text", "").strip()
48
- node_type = node.get("type") or detect_node_type(raw_text)
49
  graph[node["id"]] = {
50
- "text": raw_text,
51
- "type": node_type,
52
  "next": []
53
  }
54
 
55
- for source, target, label in edges:
56
- graph[source]["next"].append(target)
57
- reverse_links[target].append(source)
58
- edge_labels[(source, target)] = label.lower().strip()
59
 
60
  start_nodes = [nid for nid in graph if len(reverse_links[nid]) == 0]
61
  flowchart_json = {
@@ -67,35 +70,35 @@ def build_flowchart_json(nodes, edges):
67
  queue = deque(start_nodes)
68
 
69
  while queue:
70
- current = queue.popleft()
71
- if current in visited:
72
  continue
73
- visited.add(current)
74
 
75
- info = graph[current]
76
  step = {
77
- "id": current,
78
- "text": info["text"],
79
- "type": info["type"]
80
  }
81
- parents = reverse_links[current]
 
82
  if len(parents) == 1:
83
  step["parent"] = parents[0]
84
  elif len(parents) > 1:
85
  step["parents"] = parents
86
 
87
- next_nodes = info["next"]
88
- if info["type"] == "decision" and len(next_nodes) >= 2:
89
- branches = {}
90
- for target in next_nodes:
91
- label = edge_labels.get((current, target), "")
92
  if "yes" in label:
93
- branches["yes"] = target
94
  elif "no" in label:
95
- branches["no"] = target
96
  else:
97
- branches.setdefault("unknown", []).append(target)
98
- step["branches"] = branches
99
  queue.extend(next_nodes)
100
  elif len(next_nodes) == 1:
101
  step["next"] = next_nodes[0]
@@ -106,25 +109,4 @@ def build_flowchart_json(nodes, edges):
106
 
107
  flowchart_json["steps"].append(step)
108
 
109
- return flowchart_json
110
-
111
-
112
- if __name__ == "__main__":
113
- nodes = [
114
- {"id": "node1", "bbox": [100, 100, 200, 150], "text": "Start"},
115
- {"id": "node2", "bbox": [300, 100, 400, 150], "text": "Is valid?"},
116
- {"id": "node3", "bbox": [500, 50, 600, 100], "text": "Approve"},
117
- {"id": "node4", "bbox": [500, 150, 600, 200], "text": "Reject"}
118
- ]
119
-
120
- arrows = [
121
- {"id": "arrow1", "tail": (200, 125), "head": (300, 125), "label": ""},
122
- {"id": "arrow2", "tail": (400, 125), "head": (500, 75), "label": "Yes"},
123
- {"id": "arrow3", "tail": (400, 125), "head": (500, 175), "label": "No"}
124
- ]
125
-
126
- edges = map_arrows(nodes, arrows)
127
- flowchart_json = build_flowchart_json(nodes, edges)
128
-
129
- import json
130
- print(json.dumps(flowchart_json, indent=2))
 
1
+ # flowchart_builder.py
2
+ # Arrow and graph logic for converting detected flowchart elements to structured JSON
3
+
4
  from shapely.geometry import box, Point
5
  from collections import defaultdict, deque
6
 
 
7
  def map_arrows(nodes, arrows):
8
+ """
9
+ Matches arrows to nodes based on geometric endpoints.
10
+ Returns a list of (source_id, target_id, label) edges.
11
+ """
12
  for node in nodes:
13
  node["shape"] = box(*node["bbox"])
14
 
 
16
  for arrow in arrows:
17
  tail_point = Point(arrow["tail"])
18
  head_point = Point(arrow["head"])
 
 
 
 
 
 
 
 
 
19
  label = arrow.get("label", "")
20
 
21
+ source = next((n["id"] for n in nodes if n["shape"].contains(tail_point)), None)
22
+ target = next((n["id"] for n in nodes if n["shape"].contains(head_point)), None)
23
+
24
  if source and target and source != target:
25
  edges.append((source, target, label))
26
 
27
  return edges
28
 
 
29
  def detect_node_type(text):
30
+ """
31
+ Heuristic-based type detection from node text.
32
+ """
33
  text_lower = text.lower()
34
  if "start" in text_lower:
35
  return "start"
 
39
  return "decision"
40
  return "process"
41
 
 
42
  def build_flowchart_json(nodes, edges):
43
+ """
44
+ Constructs flowchart JSON structure with parent and branching info.
45
+ """
46
  graph = {}
47
  reverse_links = defaultdict(list)
48
+ edge_labels = {}
49
 
50
  for node in nodes:
51
+ text = node.get("text", "").strip()
 
52
  graph[node["id"]] = {
53
+ "text": text,
54
+ "type": node.get("type") or detect_node_type(text),
55
  "next": []
56
  }
57
 
58
+ for src, tgt, label in edges:
59
+ graph[src]["next"].append(tgt)
60
+ reverse_links[tgt].append(src)
61
+ edge_labels[(src, tgt)] = label.lower().strip()
62
 
63
  start_nodes = [nid for nid in graph if len(reverse_links[nid]) == 0]
64
  flowchart_json = {
 
70
  queue = deque(start_nodes)
71
 
72
  while queue:
73
+ curr = queue.popleft()
74
+ if curr in visited:
75
  continue
76
+ visited.add(curr)
77
 
78
+ node = graph[curr]
79
  step = {
80
+ "id": curr,
81
+ "text": node["text"],
82
+ "type": node["type"]
83
  }
84
+
85
+ parents = reverse_links[curr]
86
  if len(parents) == 1:
87
  step["parent"] = parents[0]
88
  elif len(parents) > 1:
89
  step["parents"] = parents
90
 
91
+ next_nodes = node["next"]
92
+ if node["type"] == "decision" and len(next_nodes) >= 2:
93
+ step["branches"] = {}
94
+ for tgt in next_nodes:
95
+ label = edge_labels.get((curr, tgt), "")
96
  if "yes" in label:
97
+ step["branches"]["yes"] = tgt
98
  elif "no" in label:
99
+ step["branches"]["no"] = tgt
100
  else:
101
+ step["branches"].setdefault("unknown", []).append(tgt)
 
102
  queue.extend(next_nodes)
103
  elif len(next_nodes) == 1:
104
  step["next"] = next_nodes[0]
 
109
 
110
  flowchart_json["steps"].append(step)
111
 
112
+ return flowchart_json
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
render.yaml CHANGED
@@ -3,11 +3,13 @@ services:
3
  name: flowchart-app
4
  env: python
5
  buildCommand: |
6
- apt-get update && apt-get install -y tesseract-ocr
7
  pip install -r requirements.txt
8
  startCommand: streamlit run app.py --server.port=$PORT --server.address=0.0.0.0
9
  plan: free
10
  envVars:
11
  - key: PORT
12
- value: 10000 # This line can actually be removed if you're using $PORT above
 
 
13
  pythonVersion: 3.10
 
3
  name: flowchart-app
4
  env: python
5
  buildCommand: |
6
+ pip install --upgrade pip
7
  pip install -r requirements.txt
8
  startCommand: streamlit run app.py --server.port=$PORT --server.address=0.0.0.0
9
  plan: free
10
  envVars:
11
  - key: PORT
12
+ value: 10000
13
+ - key: API_URL
14
+ value: https://your-fastapi-service.onrender.com/process-image
15
  pythonVersion: 3.10
streamlit_app.py DELETED
@@ -1,70 +0,0 @@
1
- # streamlit_app.py
2
- import streamlit as st
3
- import requests
4
- import json
5
- from PIL import Image
6
- import io
7
- import base64
8
-
9
- API_URL = "http://localhost:7860/process-image" # Change if hosted elsewhere
10
-
11
- st.set_page_config(page_title="Flowchart to English", layout="wide")
12
- st.title("📄 Flowchart to Plain English")
13
-
14
- # Debug mode switch
15
- debug_mode = st.toggle("🔧 Show Debug Info", value=False)
16
-
17
- uploaded_file = st.file_uploader("Upload a flowchart image", type=["png", "jpg", "jpeg"])
18
-
19
- if uploaded_file:
20
- # Resize image for smaller canvas
21
- image = Image.open(uploaded_file)
22
- max_width = 600
23
- ratio = max_width / float(image.size[0])
24
- new_height = int((float(image.size[1]) * float(ratio)))
25
- resized_image = image.resize((max_width, new_height))
26
- st.image(resized_image, caption="📤 Uploaded Image", use_container_width=False)
27
-
28
- if st.button("🔍 Analyze Flowchart"):
29
- progress = st.progress(0, text="Sending image to backend...")
30
-
31
- try:
32
- response = requests.post(
33
- API_URL,
34
- files={"file": uploaded_file.getvalue()},
35
- data={"debug": str(debug_mode).lower()}
36
- )
37
- progress.progress(50, text="Processing detection, OCR, and reasoning...")
38
-
39
- if response.status_code == 200:
40
- data = response.json()
41
- progress.progress(80, text="Generating explanation using LLM...")
42
-
43
- # Optional: Visualize bounding boxes
44
- if debug_mode and data.get("yolo_vis"):
45
- st.markdown("### 🖼️ YOLO Debug Bounding Boxes")
46
- vis_bytes = base64.b64decode(data["yolo_vis"])
47
- vis_img = Image.open(io.BytesIO(vis_bytes))
48
- st.image(vis_img, caption="YOLO Detected Boxes", use_container_width=True)
49
-
50
- # Optional: show logs
51
- if debug_mode and "debug" in data:
52
- st.markdown("### 🧪 Debug Pipeline Info")
53
- st.code(data["debug"], language="markdown")
54
-
55
- # Display results in 2 columns
56
- col1, col2 = st.columns(2)
57
- with col1:
58
- st.subheader("🧠 Flowchart JSON")
59
- st.json(data["flowchart"])
60
- with col2:
61
- st.subheader("📝 English Summary")
62
- st.markdown(data["summary"])
63
-
64
- progress.progress(100, text="Done!")
65
- else:
66
- st.error(f"Something went wrong: {response.status_code}")
67
- except Exception as e:
68
- st.error(f"An error occurred: {e}")
69
- else:
70
- st.info("Upload a flowchart image to begin.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
yolo_module/__init__.py CHANGED
@@ -1,22 +1,25 @@
1
- # yolo_module.py (updated to use .pt instead of ONNX)
2
  from ultralytics import YOLO
3
  from PIL import Image, ImageDraw
4
  import numpy as np
5
-
6
- # Define YOLO class labels (should be inferred automatically from .pt model)
7
- # CLASS_NAMES no longer needed unless doing custom filtering
8
 
9
  # Load YOLO model
10
  MODEL_PATH = "models/best.pt"
11
  model = YOLO(MODEL_PATH)
12
 
 
 
 
13
  def run_yolo(image: Image.Image):
14
- # Run YOLO prediction
15
- results = model.predict(image, conf=0.25, verbose=False)[0] # single image
16
 
17
  boxes = []
18
  arrows = []
19
 
 
 
 
20
  for i, box in enumerate(results.boxes):
21
  cls_id = int(box.cls)
22
  conf = float(box.conf)
@@ -33,15 +36,25 @@ def run_yolo(image: Image.Image):
33
  }
34
 
35
  if item["type"] == "arrow":
 
 
 
 
 
 
 
 
 
 
 
36
  arrows.append({
37
  "id": f"arrow{len(arrows)+1}",
38
  "tail": (x1, y1),
39
- "head": (x2, y2)
 
40
  })
41
  else:
42
  boxes.append(item)
43
 
44
- # Visualization
45
  vis_image = results.plot(pil=True)
46
-
47
  return boxes, arrows, vis_image
 
1
+ # yolo_module.py
2
  from ultralytics import YOLO
3
  from PIL import Image, ImageDraw
4
  import numpy as np
5
+ import easyocr
 
 
6
 
7
  # Load YOLO model
8
  MODEL_PATH = "models/best.pt"
9
  model = YOLO(MODEL_PATH)
10
 
11
+ # Optional OCR reader for arrow label detection
12
+ reader = easyocr.Reader(['en'], gpu=False)
13
+
14
  def run_yolo(image: Image.Image):
15
+ results = model.predict(image, conf=0.25, verbose=False)[0]
 
16
 
17
  boxes = []
18
  arrows = []
19
 
20
+ # Convert the PIL image to a NumPy array for EasyOCR (channel order follows the PIL image; no BGR conversion is done)
21
+ np_img = np.array(image)
22
+
23
  for i, box in enumerate(results.boxes):
24
  cls_id = int(box.cls)
25
  conf = float(box.conf)
 
36
  }
37
 
38
  if item["type"] == "arrow":
39
+ # Heuristically scan a small region near the middle of the arrow for a label
40
+ cx, cy = (x1 + x2) // 2, (y1 + y2) // 2
41
+ pad = 20
42
+ crop = np_img[max(cy - pad, 0):cy + pad, max(cx - pad, 0):cx + pad]
43
+
44
+ detected_label = ""
45
+ if crop.size > 0:
46
+ ocr_results = reader.readtext(crop)
47
+ if ocr_results:
48
+ detected_label = ocr_results[0][1] # (bbox, text, conf)
49
+
50
  arrows.append({
51
  "id": f"arrow{len(arrows)+1}",
52
  "tail": (x1, y1),
53
+ "head": (x2, y2),
54
+ "label": detected_label
55
  })
56
  else:
57
  boxes.append(item)
58
 
 
59
  vis_image = results.plot(pil=True)
 
60
  return boxes, arrows, vis_image