File size: 3,313 Bytes
2795186
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import sys
import csv
import json
import networkx as nx
from collections import defaultdict

import matplotlib
import matplotlib.pyplot as plt
import warnings

# Silence matplotlib's own deprecation warnings (presumably emitted through the
# networkx drawing helpers used below). The warning class moved between
# releases: newer matplotlib exports it at the package top level and no longer
# has ``matplotlib.cbook.deprecation``, while older releases only expose it
# there, so resolve whichever location exists.
_MPL_DEPRECATION_WARNING = getattr(matplotlib, "MatplotlibDeprecationWarning", None)
if _MPL_DEPRECATION_WARNING is None:
    _MPL_DEPRECATION_WARNING = matplotlib.cbook.deprecation.MatplotlibDeprecationWarning
warnings.filterwarnings('ignore', category=_MPL_DEPRECATION_WARNING)


def _column_indices(columns, data_types, table):
    """Return the column index of each ``dataType`` in ``data_types``.

    :param columns: Schema column list (dicts that may carry a ``dataType`` key).
    :param data_types: The ``dataType`` values to look up, in the desired order.
    :param table: Schema table name, used only in the error message.
    :return: List of indices, one per entry of ``data_types``. If several
        columns share a ``dataType`` the last one wins.
    :raises ValueError: If any requested ``dataType`` is absent.
    """
    positions = {col.get("dataType"): i for i, col in enumerate(columns)}
    missing = [t for t in data_types if t not in positions]
    if missing:
        raise ValueError("schema table %r has no column with dataType: %s"
                         % (table, ", ".join(missing)))
    return [positions[t] for t in data_types]


def _read_csv_rows(path):
    """Yield the non-empty data rows of the CSV file at ``path``, skipping its header."""
    # newline="" is what the csv module expects, so quoted fields containing
    # line breaks are parsed correctly.
    with open(path, "r", newline="") as rf:
        reader = csv.reader(rf)
        next(reader, None)  # skip the header row; tolerate a completely empty file
        for row in reader:
            if row:  # csv.reader yields [] for blank lines
                yield row


def load_alerts(_conf_json):
    """Build a directed graph of alert accounts and alert transactions.

    Reads the configuration JSON at ``_conf_json``. The alert-member and
    alert-transaction CSV files are resolved under
    ``<output.directory>/<general.simulation_name>``. Their columns are located
    via the ``dataType`` fields of the schema JSON named by ``input.schema``,
    relative to ``input.directory``.

    :param _conf_json: Path to the configuration JSON file.
    :return: ``(graph, bank_accts)``.

        - ``graph`` is a ``networkx.DiGraph``. Its nodes are account IDs with a
          ``bank_id`` attribute. Its edges carry ``amount``, ``date`` and a
          two-line ``label`` (amount, then date).
        - ``bank_accts`` is a ``defaultdict(list)`` mapping each bank ID to the
          account IDs listed for it.

    :raises ValueError: If the schema lacks a required column.
    """
    with open(_conf_json, "r") as rf:
        conf = json.load(rf)

    data_dir = os.path.join(conf["output"]["directory"], conf["general"]["simulation_name"])
    acct_csv = os.path.join(data_dir, conf["output"]["alert_members"])
    tx_csv = os.path.join(data_dir, conf["output"]["alert_transactions"])

    schema_json = os.path.join(conf["input"]["directory"], conf["input"]["schema"])
    with open(schema_json, "r") as rf:
        schema = json.load(rf)

    # Resolve every required column before reading any data. A schema mismatch
    # then fails with a clear message instead of a TypeError from row[None].
    acct_idx, bank_idx = _column_indices(
        schema["alert_member"], ("account_id", "bank_id"), "alert_member")
    orig_idx, bene_idx, amt_idx, date_idx = _column_indices(
        schema["alert_tx"], ("orig_id", "dest_id", "amount", "timestamp"), "alert_tx")

    graph = nx.DiGraph()
    bank_accts = defaultdict(list)

    for row in _read_csv_rows(acct_csv):
        acct_id = row[acct_idx]
        bank_id = row[bank_idx]
        graph.add_node(acct_id, bank_id=bank_id)
        bank_accts[bank_id].append(acct_id)

    for row in _read_csv_rows(tx_csv):
        orig_id = row[orig_idx]
        bene_id = row[bene_idx]
        amount = row[amt_idx]
        # Keep only the part before "T" (assumes ISO-8601-style timestamps).
        date = row[date_idx].split("T")[0]
        label = amount + "\n" + date
        # NOTE(review): a DiGraph keeps one edge per (orig, bene) pair, so a later
        # transaction between the same accounts overwrites the earlier one's
        # attributes.
        graph.add_edge(orig_id, bene_id, amount=amount, date=date, label=label)

    return graph, bank_accts


def plot_alerts(_g, _bank_accts, _output_png):
    """Draw the alert graph with one colour per bank and save it as an image.

    :param _g: Graph whose edges may carry a ``label`` attribute (drawn as edge text).
    :param _bank_accts: Mapping of bank ID to the account IDs that belong to it.
        Each bank's accounts are drawn in their own colour and legend entry.
    :param _output_png: Destination path passed to ``savefig``.
    """
    cmap = plt.get_cmap("tab10")
    # Node positions come from Graphviz through networkx's pygraphviz interface.
    pos = nx.nx_agraph.graphviz_layout(_g)

    fig = plt.figure(figsize=(12.0, 8.0))
    try:
        plt.axis('off')

        for i, (bank_id, members) in enumerate(_bank_accts.items()):
            # Cycle through the palette. An integer index past the colormap's
            # last entry maps to its single "over" colour, which would draw
            # every bank from the 11th onward identically.
            color = cmap(i % cmap.N)
            # Wrap the RGBA tuple in a list. A bare 3/4-tuple can be taken as
            # per-node values to colour-map when a bank has 3 or 4 members.
            nx.draw_networkx_nodes(_g, pos, members, node_size=300,
                                   node_color=[color], label=bank_id)
            nx.draw_networkx_labels(_g, pos, {n: n for n in members}, font_size=10)

        edge_labels = nx.get_edge_attributes(_g, "label")
        nx.draw_networkx_edges(_g, pos)
        nx.draw_networkx_edge_labels(_g, pos, edge_labels, font_size=6)

        plt.legend(numpoints=1)
        plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
        plt.savefig(_output_png, dpi=120)
    finally:
        # Release the figure so repeated calls do not accumulate open figures.
        plt.close(fig)


if __name__ == "__main__":
    # Command-line entry point: <script> ConfJSON OutputPNG
    argv = sys.argv

    if len(argv) < 3:
        # Report usage on stderr. Use sys.exit: the bare exit() builtin is
        # injected by the site module and is missing under `python -S`.
        print("Usage: python3 %s [ConfJSON] [OutputPNG]" % argv[0], file=sys.stderr)
        sys.exit(1)

    conf_json = argv[1]
    output_png = argv[2]
    g, bank_accts = load_alerts(conf_json)
    plot_alerts(g, bank_accts, output_png)