xu1998hz committed
Commit 9bab39e
0 Parent(s):

Duplicate from xu1998hz/sescore

Files changed (8)
  1. .gitattributes +33 -0
  2. README.md +46 -0
  3. app.py +73 -0
  4. description.md +59 -0
  5. img/logo_sescore.png +0 -0
  6. requirements.txt +3 -0
  7. sescore.py +139 -0
  8. tests.py +17 -0
.gitattributes ADDED
@@ -0,0 +1,33 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,46 @@
+ ---
+ title: SEScore
+ datasets:
+ - null
+ tags:
+ - evaluate
+ - metric
+ description: 'SEScore: a text generation evaluation metric'
+ sdk: gradio
+ sdk_version: 3.0.2
+ app_file: app.py
+ pinned: false
+ duplicated_from: xu1998hz/sescore
+ ---
+
+ # Metric Card for SEScore
+ ![SEScore logo](https://huggingface.co/spaces/xu1998hz/sescore/resolve/main/img/logo_sescore.png)
+
+ ## Metric Description
+ *SEScore is an unsupervised, learned evaluation metric trained on a synthesized dataset.*
+
+ ## How to Use
+
+ *Load the metric with `evaluate.load("xu1998hz/sescore")` and call `compute` with lists of predictions and references; see the Examples section below.*
+
+ ### Inputs
+ *SEScore takes as input `predictions` (a list of candidate translations) and `references` (a list of reference translations).*
+
+ ### Output Values
+
+ *Output values range from 0 to -25; better predictions receive higher (less negative) scores.*
+
+ #### Values from Popular Papers
+
+
+ ### Examples
+ *A minimal usage example is shown below; SEScore is most reliable for relative ranking of candidate outputs.*
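+
+ The following sketch mirrors the minimal example in `description.md`; the printed values are illustrative only and depend on the downloaded checkpoint and hardware.
+
+ ```python
+ import evaluate
+
+ sescore = evaluate.load("xu1998hz/sescore")
+ results = sescore.compute(
+     references=['sescore is a simple but effective next-generation text evaluation metric'],
+     predictions=['sescore is simple effective text evaluation metric for next generation']
+ )
+ # `results` is a dict with a per-pair list of scores and their mean;
+ # higher (closer to 0) means the prediction is judged closer to the reference.
+ print(results["scores"], results["mean_score"])
+ ```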
+
+ ## Limitations and Bias
+ *SEScore is trained on synthetic, in-domain data; a model trained on one domain (e.g., machine translation) may transfer imperfectly to others. See `description.md` for details.*
+
+ ## Citation
+ *Xu et al., "Not All Errors are Equal: Learning Text Generation Metrics using Stratified Error Synthesis", EMNLP 2022 (https://arxiv.org/abs/2210.05035); full BibTeX in `description.md`.*
+
+ ## Further References
+ *Add any useful further references.*
app.py ADDED
@@ -0,0 +1,73 @@
+ import evaluate
+ import sys
+ import logging
+ from pathlib import Path
+ from evaluate.utils import infer_gradio_input_types, json_to_string_type, parse_readme, parse_gradio_data, parse_test_cases
+
+ logger = logging.getLogger(__name__)
+
+
+ def launch_gradio_widget(metric):
+     """Launches `metric` widget with Gradio."""
+
+     try:
+         import gradio as gr
+     except ImportError as error:
+         logger.error("To create a metric widget with Gradio make sure gradio is installed.")
+         raise error
+
+     local_path = Path(sys.path[0])
+     # if there are several input types, use first as default.
+     if isinstance(metric.features, list):
+         (feature_names, feature_types) = zip(*metric.features[0].items())
+     else:
+         (feature_names, feature_types) = zip(*metric.features.items())
+     gradio_input_types = infer_gradio_input_types(feature_types)
+
+     def compute(data):
+         return metric.compute(**parse_gradio_data(data, gradio_input_types))
+
+     header_html = '''<div style="max-width:800px; margin:auto; float:center; margin-top:0; margin-bottom:0; padding:0;">
+     <img src="https://huggingface.co/spaces/xu1998hz/sescore/resolve/main/img/logo_sescore.png" style="margin:0; padding:0; margin-top:-10px; margin-bottom:-50px;">
+     </div>
+     <h2 style='margin-top: 5pt; padding-top:10pt;'>About <i>SEScore</i></h2>
+
+     <p><b>SEScore</b> is a reference-based text-generation evaluation metric that requires no pre-human-annotated error data,
+     described in our paper <a href="https://arxiv.org/abs/2210.05035"><b>"Not All Errors are Equal: Learning Text Generation Metrics using
+     Stratified Error Synthesis"</b></a> from EMNLP 2022.</p>
+
+     <p>Its effectiveness over prior methods like BLEU, BERTScore, BARTScore, PRISM, COMET and BLEURT has been demonstrated on a diverse set of language generation tasks, including
+     translation, captioning, and web text generation. <a href="https://twitter.com/LChoshen/status/1580136005654700033">Readers have even described SEScore as "one unsupervised evaluation to rule them all"</a>
+     and we are very excited to share it with you!</p>
+
+     <h2 style='margin-top: 10pt; padding-top:0;'>Try it yourself!</h2>
+     <p>Provide sample (gold) reference text and (model output) predicted text below and see how SEScore rates them! It is most performant
+     in a relative ranking setting, so in general <b>it will rank better predictions higher than worse ones.</b> Providing useful
+     absolute numbers based on SEScore is an ongoing direction of investigation.</p>
+     '''.replace('\n', ' ')
+
+     tail_markdown = parse_readme(local_path / "description.md")
+
+     iface = gr.Interface(
+         fn=compute,
+         inputs=gr.inputs.Dataframe(
+             headers=feature_names,
+             col_count=len(feature_names),
+             row_count=2,
+             datatype=json_to_string_type(gradio_input_types),
+         ),
+         outputs=gr.outputs.Textbox(label=metric.name),
+         description=header_html,
+         # title=f"SEScore Metric Usage Example",
+         article=tail_markdown,
+         # TODO: load test cases and use them to populate examples
+         # examples=[parse_test_cases(test_cases, feature_names, gradio_input_types)]
+     )
+
+     print(dir(iface))
+
+     iface.launch()
+
+
+ module = evaluate.load("xu1998hz/sescore")
+ launch_gradio_widget(module)
description.md ADDED
@@ -0,0 +1,59 @@
+ ## Installation and usage
+
+ ```bash
+ pip install -r requirements.txt
+ ```
+
+ Minimal example (evaluating English text generation):
+ ```python
+ import evaluate
+ sescore = evaluate.load("xu1998hz/sescore")
+ score = sescore.compute(
+     references=['sescore is a simple but effective next-generation text evaluation metric'],
+     predictions=['sescore is simple effective text evaluation metric for next generation']
+ )
+ ```
+
+ *SEScore* compares a list of references (gold translation/generated output examples) with a same-length list of candidate generated samples. Currently, the output range is learned, and scores are most useful in relative ranking scenarios rather than for absolute comparisons. We are producing a series of rescaling options to make absolute SEScore-based scoring more effective.
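+
+ A quick way to see the relative-ranking behaviour is to score two candidates against the same reference and compare them. This is only a sketch; the exact numbers depend on the downloaded checkpoint and hardware:
+
+ ```python
+ import evaluate
+
+ sescore = evaluate.load("xu1998hz/sescore")
+
+ reference = 'sescore is a simple but effective next-generation text evaluation metric'
+ close_candidate = 'sescore is a simple and effective next generation text evaluation metric'
+ poor_candidate = 'effective text metric'
+
+ out = sescore.compute(
+     references=[reference, reference],
+     predictions=[close_candidate, poor_candidate],
+ )
+ # out["scores"] holds one score per (reference, prediction) pair;
+ # the closer candidate should receive the higher (less negative) score.
+ print(out["scores"], out["mean_score"])
+ ```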
+
+
+ ### Available pre-trained models
+
+ Currently, the following language/model pairs are available:
+
+ | Language | Pretraining data | Pretrained model link |
+ |----------|------------------|-----------------------|
+ | English | MT | [xu1998hz/sescore_english_mt](https://huggingface.co/xu1998hz/sescore_english_mt) |
+ | German | MT | [xu1998hz/sescore_german_mt](https://huggingface.co/xu1998hz/sescore_german_mt) |
+ | English | webNLG17 | [xu1998hz/sescore_english_webnlg17](https://huggingface.co/xu1998hz/sescore_english_webnlg17) |
+ | English | COCO captions | [xu1998hz/sescore_english_coco](https://huggingface.co/xu1998hz/sescore_english_coco) |
+
+ At present, `evaluate.load("xu1998hz/sescore")` downloads and uses the default English MT checkpoint (see `sescore.py`).
+
+ Please contact repo maintainer Wenda Xu to add your models!
+
+ ## Limitations
+
+ *SEScore* is trained on synthetic, in-domain data.
+ Although this data is generated to simulate user-relevant errors like deletion and spurious insertion, it may be limited in its ability to simulate humanlike errors.
+ Model applicability is domain-specific (e.g., a COCO-caption-trained model will be better suited to captioning than an MT-trained one).
+
+ We are in the process of producing and benchmarking general, language-level *SEScore* variants.
+
+ ## Citation
+
+ If you find our work useful, please cite the following:
+
+ ```bibtex
+ @inproceedings{xu-etal-2022-not,
+     title={Not All Errors are Equal: Learning Text Generation Metrics using Stratified Error Synthesis},
+     author={Xu, Wenda and Tuan, Yi-lin and Lu, Yujie and Saxon, Michael and Li, Lei and Wang, William Yang},
+     booktitle={Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing},
+     month={dec},
+     year={2022},
+     url={https://arxiv.org/abs/2210.05035}
+ }
+ ```
+
+ ## Acknowledgements
+
+ The work of the [COMET](https://github.com/Unbabel/COMET) maintainers at [Unbabel](https://duckduckgo.com/?t=ffab&q=unbabel&ia=web) has been instrumental in producing SEScore.
img/logo_sescore.png ADDED
requirements.txt ADDED
@@ -0,0 +1,3 @@
+ git+https://github.com/huggingface/evaluate@main
+ unbabel-comet
+ torch
sescore.py ADDED
@@ -0,0 +1,139 @@
+ # Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ """SEScore: a text generation evaluation metric"""
+
+ import evaluate
+ import datasets
+
+ import comet
+ from typing import Dict
+ import torch
+ from comet.encoders.base import Encoder
+ from comet.encoders.bert import BERTEncoder
+ from transformers import AutoModel, AutoTokenizer
+
+
+ class robertaEncoder(BERTEncoder):
+     def __init__(self, pretrained_model: str) -> None:
+         super(Encoder, self).__init__()
+         self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model)
+         self.model = AutoModel.from_pretrained(
+             pretrained_model, add_pooling_layer=False
+         )
+         self.model.encoder.output_hidden_states = True
+
+     @classmethod
+     def from_pretrained(cls, pretrained_model: str) -> Encoder:
+         return robertaEncoder(pretrained_model)
+
+     def forward(
+         self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs
+     ) -> Dict[str, torch.Tensor]:
+         last_hidden_states, _, all_layers = self.model(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             output_hidden_states=True,
+             return_dict=False,
+         )
+         return {
+             "sentemb": last_hidden_states[:, 0, :],
+             "wordemb": last_hidden_states,
+             "all_layers": all_layers,
+             "attention_mask": attention_mask,
+         }
+
+
+ _CITATION = """\
+ @inproceedings{xu-etal-2022-not,
+     title={Not All Errors are Equal: Learning Text Generation Metrics using Stratified Error Synthesis},
+     author={Xu, Wenda and Tuan, Yi-lin and Lu, Yujie and Saxon, Michael and Li, Lei and Wang, William Yang},
+     booktitle={Proceedings of the 2022 Conference on Empirical Methods in Natural Language Processing},
+     month={dec},
+     year={2022},
+     url={https://arxiv.org/abs/2210.05035}
+ }
+ """
+
+ _DESCRIPTION = """\
+ SEScore is an evaluation metric that tries to compute an overall score to measure text generation quality.
+ """
+
+ _KWARGS_DESCRIPTION = """
+ Calculates how good the predictions are, given some references.
+ Args:
+     predictions: list of candidate outputs
+     references: list of references
+ Returns:
+     {"mean_score": mean_score, "scores": scores}
+
+ Examples:
+     >>> import evaluate
+     >>> sescore = evaluate.load("xu1998hz/sescore")
+     >>> score = sescore.compute(
+     ...     references=['sescore is a simple but effective next-generation text evaluation metric'],
+     ...     predictions=['sescore is simple effective text evaluation metric for next generation']
+     ... )
+ """
+
+ # TODO: Define external resources urls if needed
+ BAD_WORDS_URL = "http://url/to/external/resource/bad_words.txt"
+
+
+ @evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
+ class SEScore(evaluate.Metric):
+     """SEScore"""
+
+     def _info(self):
+         # TODO: Specifies the evaluate.EvaluationModuleInfo object
+         return evaluate.MetricInfo(
+             # This is the description that will appear on the modules page.
+             module_type="metric",
+             description=_DESCRIPTION,
+             citation=_CITATION,
+             inputs_description=_KWARGS_DESCRIPTION,
+             # This defines the format of each prediction and reference
+             features=datasets.Features({
+                 'predictions': datasets.Value("string", id="sequence"),
+                 'references': datasets.Value("string", id="sequence"),
+             }),
+             # Homepage of the module for documentation
+             homepage="http://module.homepage",
+             # Additional links to the codebase or references
+             codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
+             reference_urls=["http://path.to.reference.url/new_module"]
+         )
+
+     def _download_and_prepare(self, dl_manager):
+         """Download SEScore checkpoints to compute the scores."""
+         from comet import load_from_checkpoint
+         from huggingface_hub import snapshot_download
+         # register the RoBERTa encoder so COMET can rebuild the SEScore checkpoint
+         comet.encoders.str2encoder['RoBERTa'] = robertaEncoder
+         print("config name: ", self.config_name)
+         if self.config_name == "default":
+             destination = snapshot_download(repo_id="xu1998hz/sescore_english_mt", revision="main")
+             self.scorer = load_from_checkpoint(f'{destination}/checkpoint/sescore_english_mt.ckpt')
+         else:
+             print("Config name is not supported!")
+
+     def _compute(self, predictions, references, gpus=None, progress_bar=False):
+         if gpus is None:
+             gpus = 1 if torch.cuda.is_available() else 0
+
+         # convert the column format {"src": [...], "mt": [...]} into a list of row dicts
+         data = {"src": references, "mt": predictions}
+         data = [dict(zip(data, t)) for t in zip(*data.values())]
+         scores, mean_score = self.scorer.predict(data, gpus=gpus, progress_bar=progress_bar)
+         return {"mean_score": mean_score, "scores": scores}
tests.py ADDED
@@ -0,0 +1,17 @@
+ test_cases = [
+     {
+         "predictions": [0, 0],
+         "references": [1, 1],
+         "result": {"metric_score": 0}
+     },
+     {
+         "predictions": [1, 1],
+         "references": [1, 1],
+         "result": {"metric_score": 1}
+     },
+     {
+         "predictions": [1, 0],
+         "references": [1, 1],
+         "result": {"metric_score": 0.5}
+     }
+ ]