NexusLearnAI / visualization.py
ChaseHan's picture
Upload 19 files
415ceb3 verified
"""
Visualization Module - Generate concept knowledge graphs
"""
import matplotlib.pyplot as plt
import networkx as nx
import matplotlib
import io
import base64
import os
from typing import Dict, Any, List
# Ensure using Agg backend (no need for GUI)
matplotlib.use('Agg')

# Set up Chinese font support.
# Candidate CJK-capable font families, in order of preference.
chinese_fonts = ['SimHei', 'Microsoft YaHei', 'WenQuanYi Micro Hei', 'AR PL UMing CN', 'STSong', 'NSimSun', 'FangSong', 'KaiTi']

# NOTE: ``matplotlib.font_manager.findfont(name)`` does not raise when a family
# is missing -- by default it logs a warning and returns the default font --
# so probing it inside try/except always "succeeds" on the first candidate.
# Instead, compare against the family names matplotlib actually registered.
_installed_font_names = {entry.name for entry in matplotlib.font_manager.fontManager.ttflist}
_selected_font = next((name for name in chinese_fonts if name in _installed_font_names), None)
font_found = _selected_font is not None

if font_found:
    # Keep generic fallbacks after the CJK font for glyphs it lacks.
    matplotlib.rcParams['font.sans-serif'] = [_selected_font, 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']
    print(f"Using Chinese font: {_selected_font}")
else:
    print("Warning: No suitable Chinese font found, using default font")
    matplotlib.rcParams['font.sans-serif'] = ['DejaVu Sans', 'Arial Unicode MS', 'sans-serif']

# Render minus signs as ASCII '-' (many CJK fonts lack the Unicode minus glyph).
matplotlib.rcParams['axes.unicode_minus'] = False
matplotlib.rcParams['font.size'] = 10
# Node fill colour for each difficulty level; unknown levels use the 'basic' colour.
_DIFFICULTY_COLORS = {
    'basic': '#90CAF9',         # Light blue
    'intermediate': '#FFB74D',  # Orange
    'advanced': '#EF5350',      # Red
}
_DEFAULT_NODE_COLOR = _DIFFICULTY_COLORS['basic']

# ``nx.draw_networkx_edges`` styling per relationship type, drawn in this order.
# Opposite arc curvatures keep prerequisite and related edges between the same
# pair of nodes from overlapping.
_EDGE_STYLES = {
    'prerequisite': {
        'edge_color': 'red',
        'width': 2,
        'connectionstyle': "arc3,rad=0.2",
        'arrowsize': 20,
        'arrowstyle': '->',
        'min_source_margin': 30,
        'min_target_margin': 30,
    },
    'related': {
        'edge_color': 'blue',
        'style': 'dashed',
        'width': 1.5,
        'connectionstyle': "arc3,rad=-0.2",
        'arrowsize': 15,
        'arrowstyle': '->',
        'min_source_margin': 25,
        'min_target_margin': 25,
    },
}

# Styling for relationships whose type is missing or not listed above. Such
# edges are part of the graph (and influence the layout), so they must be drawn
# too rather than silently disappearing from the picture.
_OTHER_EDGE_STYLE = {
    'edge_color': 'gray',
    'style': 'dotted',
    'width': 1.5,
    'arrowsize': 15,
    'arrowstyle': '->',
    'min_source_margin': 25,
    'min_target_margin': 25,
}


def _build_concept_graph(concepts_data: Dict[str, Any]) -> nx.DiGraph:
    """Build a directed graph of the sub-concepts and the relationships among them.

    The main concept is intentionally not added as a node (it is shown in the
    title instead). Sub-concepts without an id or name are skipped, and
    relationships whose endpoints are not both added sub-concepts are dropped.
    """
    graph = nx.DiGraph()
    for concept in concepts_data.get("sub_concepts") or []:
        concept_id = concept.get("id")
        concept_name = concept.get("name")
        if not (concept_id and concept_name):
            continue
        difficulty = concept.get("difficulty", "basic")
        graph.add_node(
            concept_id,
            name=concept_name,
            type="sub",
            difficulty=difficulty,
            color=_DIFFICULTY_COLORS.get(difficulty, _DEFAULT_NODE_COLOR),
        )

    for relation in concepts_data.get("relationships") or []:
        source = relation.get("source")
        target = relation.get("target")
        # Only connect sub-concepts that were actually added above; this also
        # skips any edge touching the main concept. ``x in graph`` returns
        # False (instead of raising) for unhashable ids.
        if source and target and source in graph and target in graph:
            graph.add_edge(source, target, type=relation.get("type"))
    return graph


def _draw_edges(graph: nx.DiGraph, pos: Dict[Any, Any], ax: plt.Axes) -> bool:
    """Draw edges styled by relationship type.

    Returns:
        True if any edge with an unrecognized (or missing) type was drawn.
    """
    # A tuple compares with ``==`` and never hashes, so odd ``type`` values
    # (e.g. unhashable JSON lists) cannot raise here.
    known_types = tuple(_EDGE_STYLES)
    typed_edges = list(graph.edges(data='type'))

    for rel_type, style in _EDGE_STYLES.items():
        edgelist = [(u, v) for u, v, t in typed_edges if t == rel_type]
        if edgelist:
            nx.draw_networkx_edges(graph, pos, ax=ax, edgelist=edgelist, **style)

    other_edges = [(u, v) for u, v, t in typed_edges if t not in known_types]
    if other_edges:
        nx.draw_networkx_edges(graph, pos, ax=ax, edgelist=other_edges, **_OTHER_EDGE_STYLE)
    return bool(other_edges)


def _draw_labels(graph: nx.DiGraph, pos: Dict[Any, Any], ax: plt.Axes) -> None:
    """Draw each node's name just above the node on a rounded white box."""
    labels = {node: graph.nodes[node].get('name', node) for node in graph}
    # Shift labels upward so they do not cover the node marker.
    label_pos = {node: (x, y + 0.08) for node, (x, y) in pos.items()}
    nx.draw_networkx_labels(
        graph,
        label_pos,
        labels,
        ax=ax,
        font_size=12,
        font_weight='bold',
        bbox={
            'facecolor': 'white',
            'edgecolor': '#E0E0E0',
            'alpha': 0.9,
            'pad': 6,
            'boxstyle': 'round,pad=0.5',
        },
    )


def _add_legend(ax: plt.Axes, include_other: bool = False) -> None:
    """Add the edge-type and difficulty legend outside the top-right corner."""
    handles = [
        plt.Line2D([0], [0], color='red', lw=2, label='Prerequisite'),
        plt.Line2D([0], [0], color='blue', linestyle='--', lw=2, label='Related'),
    ]
    if include_other:
        handles.append(plt.Line2D([0], [0], color='gray', linestyle=':', lw=2, label='Other'))
    handles += [
        plt.Line2D([0], [0], marker='o', color='w', label=level.capitalize(),
                   markerfacecolor=color, markersize=12)
        for level, color in _DIFFICULTY_COLORS.items()
    ]
    # Anchored outside the axes; savefig(bbox_inches='tight') grows the canvas to fit it.
    ax.legend(
        handles=handles,
        loc='upper right',
        bbox_to_anchor=(1.2, 1),
        fontsize=10,
        frameon=True,
        facecolor='white',
        edgecolor='none',
        shadow=True,
    )


def _figure_to_data_url(fig: plt.Figure) -> str:
    """Render *fig* as PNG and return it as a base64 ``data:`` URL."""
    buf = io.BytesIO()
    fig.savefig(buf, format='png', bbox_inches='tight', dpi=150, pad_inches=0.5)
    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode('ascii')


def create_network_graph(concepts_data: Dict[str, Any]) -> str:
    """
    Create an enhanced network visualization of concept relationships.

    Only sub-concepts are drawn as nodes (coloured by difficulty); the main
    concept appears in the title instead of as a node.

    Args:
        concepts_data: Dictionary with optional keys ``"main_concept"`` (title
            text), ``"sub_concepts"`` (dicts with ``"id"``, ``"name"`` and
            optional ``"difficulty"``) and ``"relationships"`` (dicts with
            ``"source"``, ``"target"`` and ``"type"``).

    Returns:
        Base64 encoded PNG image as a ``data:image/png;base64,...`` URL.
    """
    concepts_data = concepts_data or {}
    graph = _build_concept_graph(concepts_data)

    # Draw on an explicit Figure/Axes instead of pyplot's global "current
    # figure" so other figures are left alone, and close it in ``finally`` so a
    # failure while drawing cannot leak the figure (pyplot keeps open figures alive).
    fig, ax = plt.subplots(figsize=(14, 10), dpi=150, facecolor='white')
    try:
        # Larger k spreads nodes apart; the fixed seed makes layouts reproducible.
        pos = nx.spring_layout(graph, k=2.0, iterations=100, seed=42)

        nx.draw_networkx_nodes(
            graph,
            pos,
            ax=ax,
            node_color=[graph.nodes[node].get('color', _DEFAULT_NODE_COLOR) for node in graph],
            node_size=1500,  # uniform size: there is no main-concept node
            alpha=0.8,
        )
        has_other_edges = _draw_edges(graph, pos, ax)
        _draw_labels(graph, pos, ax)
        _add_legend(ax, include_other=has_other_edges)

        main_concept = concepts_data.get("main_concept", "Concept Map")
        ax.set_title(f"Concept Map: {main_concept}", pad=20, fontsize=14, fontweight='bold')
        ax.margins(x=0.2, y=0.2)
        ax.axis('off')
        fig.tight_layout()
        return _figure_to_data_url(fig)
    finally:
        plt.close(fig)