Spaces:
Runtime error
Runtime error
Create datastats.py
Browse files- datastats.py +301 -0
datastats.py
ADDED
@@ -0,0 +1,301 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
""" DataStats metric. """
|
2 |
+
|
3 |
+
import functools
|
4 |
+
from collections import Counter
|
5 |
+
from multiprocessing import Pool
|
6 |
+
from contextlib import contextmanager
|
7 |
+
from typing import List, Any, Dict, Optional
|
8 |
+
from collections import namedtuple as _namedtuple
|
9 |
+
|
10 |
+
import spacy
|
11 |
+
import datasets
|
12 |
+
import evaluate
|
13 |
+
from packaging import version
|
14 |
+
|
15 |
+
# Load the small English spaCy pipeline once at import time; download it on
# first use so the metric works in a fresh environment.
try:
    _en = spacy.load('en_core_web_sm')
except OSError:
    # Model not installed yet: fetch it, then retry the load.
    # (Fix: the exception was previously bound to the misleading, unused
    # name ``stderr``.)
    spacy.cli.download('en_core_web_sm')
    _en = spacy.load('en_core_web_sm')
|
20 |
+
|
21 |
+
@contextmanager
def filter_logging_context():
    """Temporarily suppress the "This is expected if you are initialising"
    notice logged while model weights are being loaded.

    The filter is installed on entry and always removed on exit, even if the
    wrapped code raises.
    """

    def _suppress_init_notice(record):
        # Drop only the specific initialisation notice; pass everything else.
        return "This is expected if you are initialising" not in record.msg

    logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    logger.addFilter(_suppress_init_notice)
    try:
        yield
    finally:
        logger.removeFilter(_suppress_init_notice)
|
34 |
+
|
35 |
+
|
36 |
+
_CITATION = """\
|
37 |
+
@article{grusky2018newsroom,
|
38 |
+
title={Newsroom: A dataset of 1.3 million summaries with diverse extractive strategies},
|
39 |
+
author={Grusky, Max and Naaman, Mor and Artzi, Yoav},
|
40 |
+
journal={arXiv preprint arXiv:1804.11283},
|
41 |
+
year={2018}
|
42 |
+
}
|
43 |
+
"""
|
44 |
+
|
45 |
+
_DESCRIPTION = """\
|
46 |
+
DataStats examines summarization strategies using three measures that capture the degree of text overlap between the summary and article, and the rate of compression of the information conveyed.
|
47 |
+
"""
|
48 |
+
|
49 |
+
_KWARGS_DESCRIPTION = """
|
50 |
+
BERTScore Metrics with the hashcode from a source against one or more references.
|
51 |
+
Args:
|
52 |
+
predictions (list of str): Prediction/candidate sentences.
|
53 |
+
references (list of str or list of list of str): Reference sentences.
|
54 |
+
|
55 |
+
Returns:
|
56 |
+
coverage: Percentage of words in the summary that are from the source article, measuring the extent to which a summary is a derivative of a text.
|
57 |
+
density: It is defined as the average length of the extractive fragment to which each summary word belongs.
|
58 |
+
compression: It is defined as the word ratio between the articles and its summaries.
|
59 |
+
|
60 |
+
Examples:
|
61 |
+
>>> predictions = ["hello there", "general kenobi"]
|
62 |
+
>>> references = ["hello there", "general kenobi"]
|
63 |
+
>>> bertscore = evaluate.load("datastats")
|
64 |
+
>>> results = bertscore.compute(predictions=predictions, references=references)
|
65 |
+
"""
|
66 |
+
|
67 |
+
|
68 |
+
def find_ngrams(input_list: List[Any], n: int):
    """Yield every contiguous n-gram of *input_list* as a tuple."""
    # n shifted views of the list, zipped in lockstep, enumerate the n-grams.
    shifted_views = (input_list[offset:] for offset in range(n))
    return zip(*shifted_views)
|
70 |
+
|
71 |
+
|
72 |
+
def normalize(tokens: List[str], lowercase: bool = False):
    """Stringify each token, lower-casing when *lowercase* is True.

    Args:
        tokens: Sequence of tokens (any objects; converted via ``str``).
        lowercase: If True, lower-case every token; if False, preserve case.

    Returns:
        List of token strings.
    """
    # Bug fix: the condition was inverted (``if not lowercase``), so
    # lowercase=True preserved case and lowercase=False lower-cased —
    # the opposite of the parameter's documented meaning.
    return [str(t).lower() if lowercase else str(t) for t in tokens]
|
77 |
+
|
78 |
+
|
79 |
+
class Fragments:
    """Greedy extractive-fragment matcher between a summary and its source text.

    Implements the fragment-matching procedure of Grusky et al. (2018,
    "Newsroom") and exposes the derived COVERAGE, DENSITY and COMPRESSION
    statistics.
    """

    # One matched fragment: start index in the summary, start index in the
    # text, and fragment length in tokens.
    Match = _namedtuple("Match", ("summary", "text", "length"))

    def __init__(self, summary, text, lowercase: bool = False):
        """
        Args:
            summary: Summary as a string (whitespace-split) or a token list.
            text: Source text as a string (whitespace-split) or a token list.
            lowercase: Normalization flag forwarded to ``normalize``.
        """
        if isinstance(summary, str):
            self.summary = summary.split()
        else:
            self.summary = summary
        if isinstance(text, str):
            self.text = text.split()
        else:
            self.text = text
        # Matching is done on the normalized token streams.
        self._norm_summary = normalize(self.summary, lowercase)
        self._norm_text = normalize(self.text, lowercase)
        self._match(self._norm_summary, self._norm_text)

    def overlaps(self):
        """
        Return a list of Fragments.Match objects between summary and text.
        This is a list of named tuples of the form (summary, text, length):
        """
        return self._matches

    def strings(self, min_length=0, summary_base=True):
        """Return the matched fragments as token sub-lists.

        Args:
            min_length: Keep only fragments strictly longer than this.
            summary_base: Slice fragments out of the summary (True) or out
                of the source text (False).
        """
        # Compute the strings against the summary or the text?
        base = self.summary if summary_base else self.text
        # Bug fix: the original always sliced with the summary index ``i``,
        # which produced wrong slices when summary_base=False; the text
        # index ``j`` must be used in that case.
        strings = []
        for i, j, length in self.overlaps():
            if length > min_length:
                start = i if summary_base else j
                strings.append(base[start : start + length])
        return strings

    def coverage(self, summary_base=True):
        """
        Return the COVERAGE score of the summary and text.
        """
        # Fraction of tokens covered by some extractive fragment.
        numerator = sum(o.length for o in self.overlaps())
        if summary_base:
            denominator = len(self.summary)
        else:
            denominator = len(self.text)
        if denominator == 0:
            return 0
        else:
            return numerator / denominator

    def density(self, summary_base=True):
        """
        Return the DENSITY score of summary and text.
        """
        # Squared fragment lengths reward longer contiguous extractions.
        numerator = sum(o.length ** 2 for o in self.overlaps())
        if summary_base:
            denominator = len(self.summary)
        else:
            denominator = len(self.text)
        if denominator == 0:
            return 0
        else:
            return numerator / denominator

    def compression(self, text_to_summary=True):
        """
        Return compression ratio between summary and text.
        """
        ratio = [len(self.text), len(self.summary)]
        try:
            if text_to_summary:
                return ratio[0] / ratio[1]
            else:
                return ratio[1] / ratio[0]
        except ZeroDivisionError:
            # Empty summary (or text): define the ratio as 0.
            return 0

    def _match(self, a, b):
        """
        Raw procedure for matching summary in text, described in paper.

        Greedy: for each position in the summary ``a``, find the longest
        common run starting anywhere in the text ``b``, record it, and skip
        past it; unmatched summary tokens advance by one.
        """
        self._matches = []
        a_start = b_start = 0
        while a_start < len(a):
            best_match = None
            best_match_length = 0
            while b_start < len(b):
                if a[a_start] == b[b_start]:
                    # Extend the match as far as both sequences agree.
                    a_end = a_start
                    b_end = b_start
                    while a_end < len(a) and b_end < len(b) \
                            and b[b_end] == a[a_end]:
                        b_end += 1
                        a_end += 1
                    length = a_end - a_start
                    if length > best_match_length:
                        best_match = Fragments.Match(a_start, b_start, length)
                        best_match_length = length
                    # Resume scanning the text after the run just examined.
                    b_start = b_end
                else:
                    b_start += 1
            b_start = 0
            if best_match:
                if best_match_length > 0:
                    self._matches.append(best_match)
                a_start += best_match_length
            else:
                a_start += 1
183 |
+
|
184 |
+
class DataStatsMetric(object):
    """Computes data statistics (coverage, density, compression, plus novel
    and repeated n-gram percentages) of summaries against source texts."""

    def __init__(
        self,
        n_gram: int = 3,
        n_workers: int = 24,
        lowercase: bool = False,
        tokenize: bool = True
    ):
        """
        Data Statistics metric

        Args:
            n_gram (int): Compute statistics for n-grams up to and including this length.
            n_workers (int): Number of processes to use if using multiprocessing.
            lowercase (bool): Whether to lowercase input before calculating statistics.
            tokenize (bool): Whether to tokenize the input.
        """
        # Fix: the docstring previously documented a nonexistent ``case``
        # parameter; the actual parameter is ``lowercase``.
        self.n_gram = n_gram
        self.n_workers = n_workers
        self.lowercase = lowercase
        self.tokenize = tokenize

    def evaluate_example(self, summary, input_text):
        """Compute statistics for one (summary, source text) pair.

        Returns:
            dict with ``coverage``/``density``/``compression``,
            ``summary_length`` and, for each n up to ``self.n_gram``, the
            fraction of novel n-grams and of repeated n-grams in the summary.
        """
        if self.tokenize:
            # Tokenize with spaCy, disabling pipeline components we don't need.
            input_text = _en(input_text, disable=["tagger", "parser", "ner", "textcat"])
            input_text = [tok.text for tok in input_text]
            summary = _en(summary, disable=["tagger", "parser", "ner", "textcat"])
            summary = [tok.text for tok in summary]
        fragments = Fragments(summary, input_text, lowercase=self.lowercase)
        score_dict = {
            "coverage": fragments.coverage(),
            "density": fragments.density(),
            "compression": fragments.compression(),
        }
        tokenized_summary = fragments._norm_summary
        tokenized_text = fragments._norm_text
        score_dict["summary_length"] = len(tokenized_summary)
        for i in range(1, self.n_gram + 1):
            input_ngrams = list(find_ngrams(tokenized_text, i))
            summ_ngrams = list(find_ngrams(tokenized_summary, i))
            input_ngrams_set = set(input_ngrams)
            summ_ngrams_set = set(summ_ngrams)
            if not summ_ngrams_set:
                # Summary shorter than i tokens: no i-gram statistics
                # (replaces the original try/except ZeroDivisionError).
                continue
            intersect = summ_ngrams_set.intersection(input_ngrams_set)
            score_dict[f"percentage_novel_{i}-gram"] = (
                len(summ_ngrams_set) - len(intersect)
            ) / float(len(summ_ngrams_set))
            ngram_counter = Counter(summ_ngrams)
            repeated = [key for key, val in ngram_counter.items() if val > 1]
            score_dict[f"percentage_repeated_{i}-gram_in_summ"] = (
                len(repeated) / float(len(summ_ngrams_set))
            )
        return score_dict

    def evaluate_batch(self, summaries, input_texts, aggregate=True):
        """Compute statistics for parallel lists of summaries and texts.

        Args:
            summaries: List of summaries.
            input_texts: List of source texts, aligned with ``summaries``.
            aggregate: If True, return a single dict of corpus-averaged
                scores; otherwise return the per-example dicts.
        """
        # Fix: use the pool as a context manager so worker processes are
        # always cleaned up, even if a worker raises.
        with Pool(processes=self.n_workers) as p:
            results = p.starmap(self.evaluate_example, zip(summaries, input_texts))
        if not aggregate:
            return results
        corpus_score_dict = Counter()
        for example_scores in results:
            corpus_score_dict.update(example_scores)
        # Average over the corpus. Keys missing from some examples (short
        # summaries) are averaged over the full count, as before.
        for key in corpus_score_dict.keys():
            corpus_score_dict[key] /= float(len(input_texts))
        return corpus_score_dict

    @property
    def supports_multi_ref(self):
        # Each summary is compared against exactly one source text.
        return False
|
254 |
+
|
255 |
+
|
256 |
+
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class DataStats(evaluate.Metric):
    """`evaluate.Metric` wrapper around :class:`DataStatsMetric`."""

    def _info(self):
        """Describe the metric: citation, accepted input features, links."""
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="",
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            # Fix: previously pointed at the BERTScore repository
            # (Tiiiger/bert_score), a copy-paste error; this implementation
            # comes from SummEval.
            codebase_urls=["https://github.com/Yale-LILY/SummEval"],
            reference_urls=[
                "https://github.com/lil-lab/newsroom",
                "https://arxiv.org/pdf/2007.12626",
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        n_gram: int = 3,
        n_workers: int = 24,
        lowercase: bool = False,
        tokenize: bool = True
    ):
        """Compute coverage, density and compression of predictions against
        references (the references are treated as the source texts)."""
        datastats = DataStatsMetric(n_gram, n_workers, lowercase, tokenize)
        results = datastats.evaluate_batch(predictions, references)
        return {
            "coverage": results["coverage"],
            "density": results["density"],
            "compression": results["compression"],
        }