yangwang825 committed
Commit c7a9171 · verified · 1 Parent(s): db0ef1e

Create datastats.py

Files changed (1)
  1. datastats.py +301 -0
datastats.py ADDED
@@ -0,0 +1,301 @@
""" DataStats metric. """

import functools
from collections import Counter
from multiprocessing import Pool
from contextlib import contextmanager
from typing import List, Any, Dict, Optional
from collections import namedtuple as _namedtuple

import spacy
import datasets
import evaluate
from packaging import version

try:
    _en = spacy.load('en_core_web_sm')
except OSError:
    spacy.cli.download('en_core_web_sm')
    _en = spacy.load('en_core_web_sm')

@contextmanager
def filter_logging_context():

    def filter_log(record):
        return False if "This is expected if you are initialising" in record.msg else True

    logger = datasets.utils.logging.get_logger("transformers.modeling_utils")
    logger.addFilter(filter_log)

    try:
        yield
    finally:
        logger.removeFilter(filter_log)


_CITATION = """\
@article{grusky2018newsroom,
  title={Newsroom: A dataset of 1.3 million summaries with diverse extractive strategies},
  author={Grusky, Max and Naaman, Mor and Artzi, Yoav},
  journal={arXiv preprint arXiv:1804.11283},
  year={2018}
}
"""

_DESCRIPTION = """\
DataStats examines summarization strategies using three measures that capture the degree of text overlap between the summary and article, and the rate of compression of the information conveyed.
"""

_KWARGS_DESCRIPTION = """
DataStats metrics of candidate summaries (predictions) against their source articles (references).
Args:
    predictions (list of str): Prediction/candidate sentences (summaries).
    references (list of str or list of list of str): Reference sentences (source articles).

Returns:
    coverage: Percentage of summary words that appear in the source article, measuring the extent to which a summary is derivative of the article.
    density: Average length of the extractive fragment to which each summary word belongs.
    compression: Word ratio between the article and its summary.

Examples:
    >>> predictions = ["hello there", "general kenobi"]
    >>> references = ["hello there", "general kenobi"]
    >>> datastats = evaluate.load("datastats")
    >>> results = datastats.compute(predictions=predictions, references=references)
"""


def find_ngrams(input_list: List[Any], n: int):
    return zip(*[input_list[i:] for i in range(n)])
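    # Example: list(find_ngrams(["to", "be", "or", "not"], 2))
    # -> [("to", "be"), ("be", "or"), ("or", "not")]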


def normalize(tokens: List[str], lowercase: bool = False):
    """
    Turns tokens into plain strings, lowercasing them when ``lowercase`` is True.
    """
    return [str(t).lower() if lowercase else str(t) for t in tokens]


class Fragments:

    Match = _namedtuple("Match", ("summary", "text", "length"))

    def __init__(self, summary, text, lowercase: bool = False):
        if isinstance(summary, str):
            self.summary = summary.split()
        else:
            self.summary = summary
        if isinstance(text, str):
            self.text = text.split()
        else:
            self.text = text
        self._norm_summary = normalize(self.summary, lowercase)
        self._norm_text = normalize(self.text, lowercase)
        self._match(self._norm_summary, self._norm_text)

    def overlaps(self):
        """
        Return a list of Fragments.Match objects between summary and text.
        This is a list of named tuples of the form (summary, text, length).
        """
        return self._matches

    def strings(self, min_length=0, summary_base=True):
        # Compute the strings against the summary or the text?
        base = self.summary if summary_base else self.text
        # Generate strings, filtering out strings below the minimum length.
        strings = [base[i : i + length] for i, j, length in self.overlaps() if length > min_length]
        return strings

    def coverage(self, summary_base=True):
        """
        Return the COVERAGE score of the summary and text.
        """
        numerator = sum(o.length for o in self.overlaps())
        if summary_base:
            denominator = len(self.summary)
        else:
            denominator = len(self.text)
        if denominator == 0:
            return 0
        else:
            return numerator / denominator
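    # Per Grusky et al. (2018), COVERAGE(A, S) = (1 / |S|) * sum of |f| over the extractive
    # fragments f shared by article A and summary S, i.e. the fraction of summary words that
    # belong to some copied fragment.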

    def density(self, summary_base=True):
        """
        Return the DENSITY score of summary and text.
        """
        numerator = sum(o.length ** 2 for o in self.overlaps())
        if summary_base:
            denominator = len(self.summary)
        else:
            denominator = len(self.text)
        if denominator == 0:
            return 0
        else:
            return numerator / denominator
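    # DENSITY(A, S) = (1 / |S|) * sum of |f|^2 over the shared fragments: the average length
    # of the fragment each summary word belongs to, so longer copied spans are rewarded.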

    def compression(self, text_to_summary=True):
        """
        Return compression ratio between summary and text.
        """
        ratio = [len(self.text), len(self.summary)]
        try:
            if text_to_summary:
                return ratio[0] / ratio[1]
            else:
                return ratio[1] / ratio[0]
        except ZeroDivisionError:
            return 0
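    # COMPRESSION(A, S) = |A| / |S|, the word-count ratio between the article and its summary.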

    def _match(self, a, b):
        """
        Raw procedure for matching summary in text, described in paper.
        """
        self._matches = []
        a_start = b_start = 0
        while a_start < len(a):
            best_match = None
            best_match_length = 0
            while b_start < len(b):
                if a[a_start] == b[b_start]:
                    a_end = a_start
                    b_end = b_start
                    while a_end < len(a) and b_end < len(b) \
                            and b[b_end] == a[a_end]:
                        b_end += 1
                        a_end += 1
                    length = a_end - a_start
                    if length > best_match_length:
                        best_match = Fragments.Match(a_start, b_start, length)
                        best_match_length = length
                    b_start = b_end
                else:
                    b_start += 1
            b_start = 0
            if best_match:
                if best_match_length > 0:
                    self._matches.append(best_match)
                a_start += best_match_length
            else:
                a_start += 1
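    # Worked example (values derived from the procedure above): with text
    # "the cat sat on the mat" and summary "the cat sat", a single fragment of length 3 is
    # matched, giving coverage = 3 / 3 = 1.0, density = 3 ** 2 / 3 = 3.0 and
    # compression = 6 / 3 = 2.0.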


class DataStatsMetric(object):

    def __init__(
        self,
        n_gram: int = 3,
        n_workers: int = 24,
        lowercase: bool = False,
        tokenize: bool = True
    ):
        """
        Data Statistics metric

        Args:
            n_gram (int): Compute statistics for n-grams up to and including this length.
            n_workers (int): Number of processes to use if using multiprocessing.
            lowercase (bool): Whether to lowercase input before calculating statistics.
            tokenize (bool): Whether to tokenize the input.
        """
        self.n_gram = n_gram
        self.n_workers = n_workers
        self.lowercase = lowercase
        self.tokenize = tokenize

    def evaluate_example(self, summary, input_text):
        if self.tokenize:
            input_text = _en(input_text, disable=["tagger", "parser", "ner", "textcat"])
            input_text = [tok.text for tok in input_text]
            summary = _en(summary, disable=["tagger", "parser", "ner", "textcat"])
            summary = [tok.text for tok in summary]
        fragments = Fragments(summary, input_text, lowercase=self.lowercase)
        coverage = fragments.coverage()
        density = fragments.density()
        compression = fragments.compression()
        score_dict = {"coverage": coverage, "density": density, "compression": compression}
        tokenized_summary = fragments._norm_summary
        tokenized_text = fragments._norm_text
        score_dict["summary_length"] = len(tokenized_summary)
        for i in range(1, self.n_gram + 1):
            input_ngrams = list(find_ngrams(tokenized_text, i))
            summ_ngrams = list(find_ngrams(tokenized_summary, i))
            input_ngrams_set = set(input_ngrams)
            summ_ngrams_set = set(summ_ngrams)
            intersect = summ_ngrams_set.intersection(input_ngrams_set)
            try:
                score_dict[f"percentage_novel_{i}-gram"] = (len(summ_ngrams_set) - len(intersect)) / float(len(summ_ngrams_set))
                ngramCounter = Counter()
                ngramCounter.update(summ_ngrams)
                repeated = [key for key, val in ngramCounter.items() if val > 1]
                score_dict[f"percentage_repeated_{i}-gram_in_summ"] = len(repeated) / float(len(summ_ngrams_set))
            except ZeroDivisionError:
                continue
        return score_dict
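    # Besides coverage/density/compression, each example also gets, for every n up to
    # self.n_gram, "percentage_novel_{n}-gram" (share of distinct summary n-grams absent from
    # the article) and "percentage_repeated_{n}-gram_in_summ" (share of distinct summary
    # n-grams that occur more than once in the summary).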

    def evaluate_batch(self, summaries, input_texts, aggregate=True):
        corpus_score_dict = Counter()
        p = Pool(processes=self.n_workers)
        results = p.starmap(self.evaluate_example, zip(summaries, input_texts))
        p.close()
        if aggregate:
            for x in results:
                corpus_score_dict.update(x)
            for key in corpus_score_dict.keys():
                corpus_score_dict[key] /= float(len(input_texts))
            return corpus_score_dict
        else:
            return results
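    # With aggregate=True the per-example dicts are summed into a Counter and divided by the
    # number of examples, i.e. a simple mean over the batch; aggregate=False returns the raw
    # per-example dicts instead.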

    @property
    def supports_multi_ref(self):
        return False


@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class DataStats(evaluate.Metric):

    def _info(self):
        return evaluate.MetricInfo(
            description=_DESCRIPTION,
            citation=_CITATION,
            homepage="",
            inputs_description=_KWARGS_DESCRIPTION,
            features=[
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Sequence(datasets.Value("string", id="sequence"), id="references"),
                    }
                ),
                datasets.Features(
                    {
                        "predictions": datasets.Value("string", id="sequence"),
                        "references": datasets.Value("string", id="sequence"),
                    }
                ),
            ],
            codebase_urls=["https://github.com/Yale-LILY/SummEval"],
            reference_urls=[
                "https://github.com/lil-lab/newsroom",
                "https://arxiv.org/pdf/2007.12626",
            ],
        )

    def _compute(
        self,
        predictions,
        references,
        n_gram: int = 3,
        n_workers: int = 24,
        lowercase: bool = False,
        tokenize: bool = True
    ):
        datastats = DataStatsMetric(n_gram, n_workers, lowercase, tokenize)
        results = datastats.evaluate_batch(predictions, references)
        return {
            "coverage": results['coverage'],
            "density": results['density'],
            "compression": results['compression']
        }
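

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only, not part of the metric interface): score a
    # single summary/article pair directly with DataStatsMetric, bypassing evaluate.load()
    # and the multiprocessing pool used by evaluate_batch().
    metric = DataStatsMetric(n_gram=2, n_workers=1)
    scores = metric.evaluate_example(
        summary="the cat sat",
        input_text="the cat sat on the mat",
    )
    # Expected (derived from the definitions above): coverage 1.0, density 3.0, compression 2.0.
    print(scores)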