File size: 3,934 Bytes
1c60c6e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
#  Copyright 2021-present, the Recognai S.L. team.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

from typing import Dict


def empty_visualization(message: str = "No data found"):
    """Build a placeholder figure that only displays a text message.

    The figure has no traces; both axes are hidden and ``message`` is added
    as a single annotation without an arrow, rendered with font size 30.

    Args:
        message: Text shown in the figure.

    Returns:
        The ``plotly.graph_objects.Figure`` carrying the annotation.
    """
    import plotly.graph_objects as go

    notice = {"text": message, "showarrow": False, "font": {"size": 30}}
    layout_options = {
        "xaxis": {"visible": False},
        "yaxis": {"visible": False},
        "annotations": [notice],
    }

    placeholder = go.Figure()
    placeholder.update_layout(**layout_options)
    return placeholder


def bar(data: dict, title: str = "Bar", x_legend: str = "", y_legend: str = ""):
    """Draw a bar chart from a mapping of category -> value.

    Args:
        data: Mapping whose keys become the x positions and whose values
            become the bar heights. String keys are re-encoded with the
            ``unicode-escape`` codec before plotting; other keys are used
            unchanged.
        title: Figure title.
        x_legend: Title of the x axis.
        y_legend: Title of the y axis.

    Returns:
        A plotly figure with one bar trace, or the figure produced by
        ``empty_visualization()`` when ``data`` is empty.
    """
    import plotly.graph_objects as go

    if not data:
        return empty_visualization()

    categories = []
    heights = []
    for category, height in data.items():
        if isinstance(category, str):
            # Escape non-ASCII / control characters so they show up literally.
            category = category.encode("unicode-escape").decode()
        categories.append(category)
        heights.append(height)

    chart = go.Figure(data=go.Bar(y=tuple(heights), x=categories))
    chart.update_layout(title=title, xaxis_title=x_legend, yaxis_title=y_legend)
    return chart


def stacked_bar(
    x: list,
    y_s: Dict[str, list],
    title: str = "Bar",
    x_legend: str = "",
    y_legend: str = "",
):
    """Draw a stacked bar chart with one bar series per entry of ``y_s``.

    Args:
        x: x positions shared by every series.
        y_s: Mapping of series name -> y values for that series.
        title: Figure title.
        x_legend: Title of the x axis.
        y_legend: Title of the y axis.

    Returns:
        A plotly figure in ``"stack"`` bar mode, or the figure produced by
        ``empty_visualization()`` when ``x`` or ``y_s`` is empty.
    """
    import plotly.graph_objects as go

    if not (x and y_s):
        return empty_visualization()

    traces = []
    for series_name, series_values in y_s.items():
        traces.append(go.Bar(name=series_name, x=x, y=series_values))

    chart = go.Figure(data=traces)
    layout_options = {
        "title": title,
        "xaxis_title": x_legend,
        "yaxis_title": y_legend,
        "barmode": "stack",
    }
    chart.update_layout(**layout_options)
    return chart


def histogram(data: dict, title: str = "Bar", x_legend: str = "", y_legend: str = ""):
    """Draw a histogram-like bar chart from a mapping of bin -> count.

    Every key is converted with ``float()`` and the resulting mapping is
    delegated to ``bar``.

    Args:
        data: Mapping whose keys can be converted to ``float``.
        title: Figure title.
        x_legend: Title of the x axis.
        y_legend: Title of the y axis.

    Returns:
        The figure returned by ``bar``, or the figure produced by
        ``empty_visualization()`` when ``data`` is empty.
    """
    if not data:
        return empty_visualization()

    numeric_bins = {}
    for bin_key, count in data.items():
        numeric_bins[float(bin_key)] = count

    return bar(numeric_bins, title=title, x_legend=x_legend, y_legend=y_legend)


def tree_map(labels, parents, values, title: str = "Tree"):
    """Draw a treemap from parallel ``labels`` / ``parents`` / ``values``.

    Args:
        labels: Passed as ``labels`` to ``go.Treemap``.
        parents: Passed as ``parents`` to ``go.Treemap``.
        values: Passed as ``values`` to ``go.Treemap``.
        title: Figure title.

    Returns:
        A plotly figure with a single treemap trace, or the figure produced
        by ``empty_visualization()`` when ``values`` is falsy.
    """
    import plotly.graph_objects as go

    # TODO: review condition
    if not values:
        return empty_visualization()

    trace = go.Treemap(labels=labels, parents=parents, values=values)
    chart = go.Figure(trace)
    # Keep room at the top for the title, no margin on the other sides.
    chart.update_layout(title=title, margin={"t": 50, "l": 0, "r": 0, "b": 0})
    return chart


def multilevel_pie(labels, parents, values, title: str = "Pie"):
    """Draw a multilevel pie (sunburst) chart.

    Args:
        labels: Passed as ``labels`` to ``go.Sunburst``.
        parents: Passed as ``parents`` to ``go.Sunburst``.
        values: Passed as ``values`` to ``go.Sunburst``.
        title: Figure title.

    Returns:
        A plotly figure with a single sunburst trace, or the figure produced
        by ``empty_visualization()`` when ``values`` is falsy.
    """
    import plotly.graph_objects as go

    # Same guard as ``tree_map``: show the "No data found" placeholder
    # instead of a blank sunburst when there is nothing to plot.
    if not values:
        return empty_visualization()

    fig = go.Figure(
        go.Sunburst(
            labels=labels,
            parents=parents,
            values=values,
        )
    )
    # Keep room at the top for the title, no margin on the other sides.
    fig.update_layout(title=title, margin=dict(t=50, l=0, r=0, b=0))

    return fig


def f1(data: Dict[str, float], title: str):
    """Draw three side-by-side bar charts from a flat metrics mapping.

    Entries of ``data`` are routed by substring match on their key:

    * keys containing ``"macro"`` go to the first subplot,
    * keys containing ``"micro"`` go to the second subplot,
    * keys containing none of ``"macro"``, ``"micro"`` or ``"support"`` go
      to the third ("per label") subplot.

    For the two averaged subplots the x value is the part of the key before
    its first underscore; the third subplot uses the full key.

    Args:
        data: Mapping of metric key -> value.
        title: Figure title.

    Returns:
        A plotly figure with one row and three columns, legend hidden.
    """
    from plotly.subplots import make_subplots

    panel_titles = ["macro average", "micro average", "per label"]
    panels = make_subplots(rows=1, cols=3, subplot_titles=panel_titles)

    # Columns 1 and 2: averaged metrics, labelled by the key prefix.
    for column, average_kind in enumerate(("macro", "micro"), start=1):
        metric_names = []
        metric_scores = []
        for key, score in data.items():
            if average_kind in key:
                metric_names.append(key.split("_")[0])
                metric_scores.append(score)
        panels.add_bar(x=metric_names, y=metric_scores, row=1, col=column)

    # Column 3: everything that is neither an average nor a support entry.
    excluded_tokens = ("macro", "micro", "support")
    label_names = []
    label_scores = []
    for key, score in data.items():
        if not any(token in key for token in excluded_tokens):
            label_names.append(key)
            label_scores.append(score)
    panels.add_bar(x=label_names, y=label_scores, row=1, col=3)

    panels.update_layout(showlegend=False, title_text=title)
    return panels