import os import sys import csv import json import networkx as nx from collections import defaultdict import matplotlib import matplotlib.pyplot as plt import warnings warnings.filterwarnings('ignore', category=matplotlib.cbook.deprecation.MatplotlibDeprecationWarning) def load_alerts(_conf_json): _g = nx.DiGraph() _bank_accts = defaultdict(list) with open(_conf_json, "r") as rf: conf = json.load(rf) data_dir = os.path.join(conf["output"]["directory"], conf["general"]["simulation_name"]) acct_csv = os.path.join(data_dir, conf["output"]["alert_members"]) tx_csv = os.path.join(data_dir, conf["output"]["alert_transactions"]) input_dir = conf["input"]["directory"] schema_json = os.path.join(input_dir, conf["input"]["schema"]) with open(schema_json, "r") as rf: schema = json.load(rf) acct_idx = None bank_idx = None orig_idx = None bene_idx = None amt_idx = None date_idx = None for i, col in enumerate(schema["alert_member"]): if col.get("dataType") == "account_id": acct_idx = i elif col.get("dataType") == "bank_id": bank_idx = i for i, col in enumerate(schema["alert_tx"]): if col.get("dataType") == "orig_id": orig_idx = i elif col.get("dataType") == "dest_id": bene_idx = i elif col.get("dataType") == "amount": amt_idx = i elif col.get("dataType") == "timestamp": date_idx = i with open(acct_csv, "r") as rf: reader = csv.reader(rf) next(reader) for row in reader: acct_id = row[acct_idx] bank_id = row[bank_idx] _g.add_node(acct_id, bank_id=bank_id) _bank_accts[bank_id].append(acct_id) with open(tx_csv, "r") as rf: reader = csv.reader(rf) next(reader) for row in reader: orig_id = row[orig_idx] bene_id = row[bene_idx] amount = row[amt_idx] date = row[date_idx].split("T")[0] # Extract only the date label = amount + "\n" + date _g.add_edge(orig_id, bene_id, amount=amount, date=date, label=label) return _g, _bank_accts def plot_alerts(_g, _bank_accts, _output_png): bank_ids = _bank_accts.keys() cmap = plt.get_cmap("tab10") pos = nx.nx_agraph.graphviz_layout(_g) plt.figure(figsize=(12.0, 8.0)) plt.axis('off') for i, bank_id in enumerate(bank_ids): color = cmap(i) members = _bank_accts[bank_id] nx.draw_networkx_nodes(_g, pos, members, node_size=300, node_color=color, label=bank_id) nx.draw_networkx_labels(_g, pos, {n: n for n in members}, font_size=10) edge_labels = nx.get_edge_attributes(_g, "label") nx.draw_networkx_edges(_g, pos) nx.draw_networkx_edge_labels(_g, pos, edge_labels, font_size=6) plt.legend(numpoints=1) plt.subplots_adjust(left=0, right=1, bottom=0, top=1) plt.savefig(_output_png, dpi=120) if __name__ == "__main__": argv = sys.argv if len(argv) < 3: print("Usage: python3 %s [ConfJSON] [OutputPNG]" % argv[0]) exit(1) conf_json = argv[1] output_png = argv[2] g, bank_accts = load_alerts(conf_json) plot_alerts(g, bank_accts, output_png)