Spaces:
Running
Running
File size: 6,674 Bytes
833dac3 415ceb3 833dac3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 |
"""
Visualization Module - Generate concept knowledge graphs
"""
import matplotlib.pyplot as plt
import networkx as nx
import matplotlib
import io
import base64
import os
from typing import Dict, Any, List
# Ensure using Agg backend (no need for GUI)
matplotlib.use('Agg')

# Set up Chinese font support: put the first *installed* CJK font at the
# front of the sans-serif family list so Chinese labels can render.
#
# NOTE: ``matplotlib.font_manager.findfont(name)`` does not raise for a
# missing family -- by default it logs a warning and returns the fallback
# font -- so it cannot be used as an "is this font installed?" probe (a
# try/findfont loop would always pick the first candidate). Instead, look the
# candidates up in the font families matplotlib has actually registered.
chinese_fonts = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei', 'AR PL UMing CN', 'STSong', 'NSimSun', 'FangSong', 'KaiTi']

# Lower-cased family name -> family name as registered by matplotlib, so the
# lookup below is case-insensitive but rcParams receives the exact name.
_installed_font_families = {
    entry.name.lower(): entry.name
    for entry in matplotlib.font_manager.fontManager.ttflist
}

font_found = False
for font in chinese_fonts:
    _installed_name = _installed_font_families.get(font.lower())
    if _installed_name is None:
        continue
    matplotlib.rcParams['font.sans-serif'] = [_installed_name, 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']
    print(f"Using Chinese font: {_installed_name}")
    font_found = True
    break

if not font_found:
    print("Warning: No suitable Chinese font found, using default font")
    matplotlib.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial Unicode MS', 'sans-serif']

# Use ASCII '-' for minus signs instead of U+2212 (commonly disabled with CJK
# fonts, which may lack that glyph).
matplotlib.rcParams['axes.unicode_minus'] = False
matplotlib.rcParams['font.size'] = 10
# Node fill colour per concept difficulty; unknown difficulties use 'basic'.
_DIFFICULTY_COLORS = {
    'basic': '#90CAF9',  # Light blue
    'intermediate': '#FFB74D',  # Orange
    'advanced': '#EF5350',  # Red
}
_DEFAULT_NODE_COLOR = _DIFFICULTY_COLORS['basic']

# Keyword arguments for networkx.draw_networkx_edges per normalised
# relationship type. Opposite arc curvatures keep the two kinds of edges
# between the same pair of nodes from overlapping.
_EDGE_STYLES = {
    'prerequisite': {
        'edge_color': 'red',
        'width': 2,
        'connectionstyle': "arc3,rad=0.2",  # Curve
        'arrowsize': 20,
        'arrowstyle': '->',
        'min_source_margin': 30,
        'min_target_margin': 30,
    },
    'related': {
        'edge_color': 'blue',
        'style': 'dashed',
        'width': 1.5,
        'connectionstyle': "arc3,rad=-0.2",  # Reverse curve
        'arrowsize': 15,
        'arrowstyle': '->',
        'min_source_margin': 25,
        'min_target_margin': 25,
    },
}

# Relationships whose type is missing or not listed in _EDGE_STYLES are still
# drawn (previously they were silently omitted from the picture).
_OTHER_EDGE_TYPE = 'other'
_OTHER_EDGE_STYLE = {
    'edge_color': 'gray',
    'style': 'dotted',
    'width': 1.2,
    'connectionstyle': "arc3,rad=0.1",
    'arrowsize': 12,
    'arrowstyle': '->',
    'min_source_margin': 25,
    'min_target_margin': 25,
}


def _normalize_relation_type(rel_type: Any) -> str:
    """Map a raw relationship ``type`` to a key of ``_EDGE_STYLES`` or ``'other'``.

    Matching is case-insensitive and ignores surrounding whitespace.
    """
    if isinstance(rel_type, str):
        key = rel_type.strip().lower()
        if key in _EDGE_STYLES:
            return key
    return _OTHER_EDGE_TYPE


def _build_concept_graph(concepts_data: Dict[str, Any]) -> nx.DiGraph:
    """Return a directed graph of the sub-concepts and the relationships among them."""
    graph = nx.DiGraph()

    # Only sub-concepts become nodes; the main concept is shown in the title.
    for concept in concepts_data.get("sub_concepts") or []:
        if not isinstance(concept, dict):
            continue
        concept_id = concept.get("id")
        concept_name = concept.get("name")
        if not (concept_id and concept_name):
            continue
        difficulty = concept.get("difficulty", "basic")
        graph.add_node(
            concept_id,
            name=concept_name,
            type="sub",
            difficulty=difficulty,
            color=_DIFFICULTY_COLORS.get(difficulty, _DEFAULT_NODE_COLOR),
        )

    # Keep only edges whose endpoints are both nodes added above; this skips
    # relationships that reference an unknown id (e.g. the main concept).
    for relation in concepts_data.get("relationships") or []:
        if not isinstance(relation, dict):
            continue
        source = relation.get("source")
        target = relation.get("target")
        if source and target and source in graph and target in graph:
            graph.add_edge(
                source,
                target,
                type=_normalize_relation_type(relation.get("type")),
            )
    return graph


def _draw_concept_graph(G: nx.DiGraph, ax: "matplotlib.axes.Axes") -> None:
    """Lay out *G* and draw its nodes, styled edges, labels and legend on *ax*."""
    # Larger k spreads nodes apart; the fixed seed gives a reproducible layout.
    pos = nx.spring_layout(G, k=2.0, iterations=100, seed=42)

    nx.draw_networkx_nodes(
        G, pos, ax=ax,
        node_color=[G.nodes[n].get('color', _DEFAULT_NODE_COLOR) for n in G.nodes()],
        node_size=1500,  # Uniform size: there is no main-concept node
        alpha=0.8,
    )

    # Group edges by relationship type so each group gets its own style.
    edges_by_type: Dict[str, List[Any]] = {}
    for u, v, rel_type in G.edges(data='type'):
        edges_by_type.setdefault(rel_type, []).append((u, v))
    styles = list(_EDGE_STYLES.items()) + [(_OTHER_EDGE_TYPE, _OTHER_EDGE_STYLE)]
    for rel_type, style in styles:
        edgelist = edges_by_type.get(rel_type)
        if edgelist:  # Nothing to draw for types that do not occur
            nx.draw_networkx_edges(G, pos, ax=ax, edgelist=edgelist, **style)

    labels = {n: G.nodes[n].get('name', n) for n in G.nodes()}
    # Shift labels upward so they sit above the node markers.
    label_pos = {n: (x, y + 0.08) for n, (x, y) in pos.items()}
    nx.draw_networkx_labels(
        G, label_pos, labels, ax=ax,
        font_size=12,
        font_weight='bold',
        bbox={  # White rounded background behind each label
            'facecolor': 'white',
            'edgecolor': '#E0E0E0',
            'alpha': 0.9,
            'pad': 6,
            'boxstyle': 'round,pad=0.5',
        },
    )

    legend_handles = [
        plt.Line2D([0], [0], color='red', lw=2, label='Prerequisite'),
        plt.Line2D([0], [0], color='blue', linestyle='--', lw=2, label='Related'),
    ]
    if _OTHER_EDGE_TYPE in edges_by_type:
        legend_handles.append(
            plt.Line2D([0], [0], color='gray', linestyle=':', lw=2, label='Other')
        )
    legend_handles += [
        plt.Line2D([0], [0], marker='o', color='w', label=difficulty.capitalize(),
                   markerfacecolor=color, markersize=12)
        for difficulty, color in _DIFFICULTY_COLORS.items()
    ]
    # Anchored past the right edge of the axes so it does not cover nodes.
    ax.legend(
        handles=legend_handles,
        loc='upper right',
        bbox_to_anchor=(1.2, 1),
        fontsize=10,
        frameon=True,
        facecolor='white',
        edgecolor='none',
        shadow=True,
    )


def create_network_graph(concepts_data: Dict[str, Any]) -> str:
    """
    Create an enhanced network visualization of concept relationships.

    Only sub-concepts become nodes; the main concept is shown in the title.

    Args:
        concepts_data: Dictionary with the keys
            - "main_concept": title text (defaults to "Concept Map"),
            - "sub_concepts": list of dicts with "id", "name" and an optional
              "difficulty" ("basic" | "intermediate" | "advanced"),
            - "relationships": list of dicts with "source", "target" and
              "type" ("prerequisite" or "related", case-insensitive; any
              other or missing type is drawn in a gray "Other" style).
            Missing or null keys are treated as empty.

    Returns:
        Base64 encoded PNG image as a ``data:image/png;base64,...`` URL.
    """
    concepts_data = concepts_data or {}
    G = _build_concept_graph(concepts_data)

    # Draw on an explicit Figure/Axes rather than pyplot's implicit "current
    # figure", and close only this figure (plt.close('all') would also close
    # figures created by any other caller).
    fig, ax = plt.subplots(figsize=(14, 10), dpi=150, facecolor='white')
    buf = io.BytesIO()
    try:
        _draw_concept_graph(G, ax)

        main_concept = concepts_data.get("main_concept") or "Concept Map"
        ax.set_title(f"Concept Map: {main_concept}", pad=20, fontsize=14, fontweight='bold')

        ax.margins(x=0.2, y=0.2)  # Increase graph margins around the layout
        ax.axis('off')
        fig.tight_layout()
        fig.savefig(
            buf,
            format='png',
            bbox_inches='tight',
            dpi=150,
            pad_inches=0.5,
        )
    finally:
        # Release the figure even when drawing fails so repeated calls don't leak it.
        plt.close(fig)

    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode('utf-8')