Spaces:
Running
on
Zero
Running
on
Zero
from typing import List, Tuple | |
import numpy as np | |
from numpy import ndarray | |
def get_mAP(
    preds: ndarray,
    gt_file: str,
    taglist: List[str]
) -> Tuple[float, ndarray]:
    """Compute mean average precision of per-tag scores against a label file.

    Args:
        preds: 2-D score array; column ``k`` holds the scores for
            ``taglist[k]`` and row ``i`` corresponds to line ``i`` of
            ``gt_file``.
        gt_file: Path to a UTF-8 text file with one comma-separated record
            per line. The first field of each record is skipped (presumably
            a sample identifier -- confirm against the data producer); the
            remaining fields are the ground-truth tags of that sample.
        taglist: Tag name for every column of ``preds``. May contain
            duplicates.

    Returns:
        ``(mAP, APs)`` where ``APs[k]`` is the average precision of column
        ``k`` and ``mAP`` is the mean of ``APs``.

    Raises:
        AssertionError: If the number of columns of ``preds`` differs from
            ``len(taglist)``, or the number of lines in ``gt_file`` differs
            from the number of rows of ``preds``.
        KeyError: If ``gt_file`` contains a tag that is not in ``taglist``.
    """
    assert preds.shape[1] == len(taglist)

    # When mapping categories from test datasets to our system, there might be
    # multiple vs one situation due to different semantic definitions of tags.
    # So there can be duplicate tags in `taglist`. This special case is taken
    # into account: each tag maps to *all* column indices that carry it.
    tag2idxs = {}
    for idx, tag in enumerate(taglist):
        tag2idxs.setdefault(tag, []).append(idx)

    # build targets
    targets = np.zeros_like(preds)
    # Explicit UTF-8 so tag names decode identically on every platform and
    # consistently with how `get_PR` reads ground-truth files.
    with open(gt_file, "r", encoding="utf-8") as f:
        lines = [line.strip("\n").split(",") for line in f]
    assert len(lines) == targets.shape[0]
    for i, line in enumerate(lines):
        for tag in line[1:]:
            # A list index marks every duplicate column of the tag at once.
            targets[i, tag2idxs[tag]] = 1.0

    # compute average precision for each class
    APs = np.zeros(preds.shape[1])
    for k in range(preds.shape[1]):
        APs[k] = _average_precision(preds[:, k], targets[:, k])
    return APs.mean(), APs
def _average_precision(output: ndarray, target: ndarray) -> float:
    """Return the average precision of one class.

    Samples are ranked by descending ``output``; the precision at each rank
    where ``target == 1`` is summed and divided by the number of positives
    (plus a small epsilon so a class without positives yields 0 instead of
    a division by zero).

    Args:
        output: 1-D array of scores, one per sample.
        target: 1-D array of labels aligned with ``output``; entries equal
            to 1 are treated as positives.

    Returns:
        The average precision as a floating-point scalar.
    """
    eps = 1e-8

    # Rank samples from highest to lowest score.
    order = output.argsort()[::-1]
    hits = target[order] == 1

    # 1-based rank of every position in the sorted list.
    ranks = np.arange(1, len(output) + 1, dtype=float)

    # Running number of positives seen up to and including each rank.
    cum_hits = np.cumsum(hits)
    num_pos = cum_hits[-1]

    # Precision@rank, kept only at ranks that are themselves positives.
    precision_at_hits = np.where(hits, cum_hits, 0) / ranks
    return np.sum(precision_at_hits) / (num_pos + eps)
def get_PR(
    pred_file: str,
    gt_file: str,
    taglist: List[str]
) -> Tuple[float, float, ndarray, ndarray]:
    """Compute per-tag and mean precision/recall from two tag files.

    Both files are UTF-8 text with one comma-separated record per line. The
    first field of each record is skipped (presumably a sample identifier --
    confirm against the data producer); the remaining fields are tags. Line
    ``i`` of ``pred_file`` is compared with line ``i`` of ``gt_file``.

    Args:
        pred_file: Path to the file with predicted tags.
        gt_file: Path to the file with ground-truth tags.
        taglist: Tag name for every output column. May contain duplicates;
            a duplicated tag sets all of its columns.

    Returns:
        ``(P, R, Ps, Rs)`` where ``Ps`` / ``Rs`` are arrays of length
        ``len(taglist)`` with per-column precision / recall, and ``P`` /
        ``R`` are their means.

    Raises:
        AssertionError: If the two files have a different number of lines.
        KeyError: If a file contains a tag that is not in ``taglist``.
    """
    # When mapping categories from test datasets to our system, there might be
    # multiple vs one situation due to different semantic definitions of tags.
    # So there can be duplicate tags in `taglist`. This special case is taken
    # into account: each tag maps to *all* column indices that carry it.
    tag2idxs = {}
    for idx, tag in enumerate(taglist):
        tag2idxs.setdefault(tag, []).append(idx)

    def _load_labels(path: str) -> ndarray:
        # Parse one tag file into a (num_lines, len(taglist)) boolean matrix.
        with open(path, "r", encoding="utf-8") as f:
            rows = [line.strip().split(",") for line in f]
        # Column count must be len(taglist), not len(tag2idxs): the indices
        # stored in tag2idxs are positions in taglist, so with duplicate
        # tags a matrix sized by the number of unique tags is too narrow.
        labels = np.zeros((len(rows), len(taglist)), dtype=bool)
        for i, row in enumerate(rows):
            for tag in row[1:]:
                labels[i, tag2idxs[tag]] = True
        return labels

    preds = _load_labels(pred_file)
    targets = _load_labels(gt_file)
    assert preds.shape == targets.shape

    # calculate P and R
    TPs = (preds & targets).sum(axis=0)
    FPs = (preds & ~targets).sum(axis=0)
    FNs = (~preds & targets).sum(axis=0)
    # eps keeps columns with no predictions / no positives at 0 instead of
    # dividing by zero.
    eps = 1.e-9
    Ps = TPs / (TPs + FPs + eps)
    Rs = TPs / (TPs + FNs + eps)
    return Ps.mean(), Rs.mean(), Ps, Rs