# coding=utf-8
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import logging
from enum import Enum
from typing import Any, Dict, List, Optional, Tuple, Type, Union

import pandas as pd
from pkg_resources import parse_version

from argilla.client.models import (
    Record,
    Text2TextRecord,
    TextClassificationRecord,
    TokenAttributions,
    TokenClassificationRecord,
)
from argilla.client.sdk.datasets.models import TaskType
from argilla.utils.span_utils import SpanUtils

_LOGGER = logging.getLogger(__name__)


def _requires_datasets(func):
    @functools.wraps(func)
    def check_if_datasets_installed(*args, **kwargs):
        try:
            import datasets
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                f"'datasets' must be installed to use `{func.__name__}`! "
                "You can install 'datasets' with the command: `pip install datasets>1.17.0`"
            )
        if not (parse_version(datasets.__version__) > parse_version("1.17.0")):
            raise ModuleNotFoundError(
                "Version >1.17.0 of 'datasets' must be installed to use `to_datasets`! "
                "You can update 'datasets' with the command: `pip install -U datasets>1.17.0`"
            )
        return func(*args, **kwargs)

    return check_if_datasets_installed


def _requires_spacy(func):
    @functools.wraps(func)
    def check_if_spacy_installed(*args, **kwargs):
        try:
            import spacy
        except ModuleNotFoundError:
            raise ModuleNotFoundError(
                f"'spacy' must be installed to use `{func.__name__}`! "
                "You can install 'spacy' with the command: `pip install spacy`"
            )
        return func(*args, **kwargs)

    return check_if_spacy_installed


class DatasetBase:
    """The Dataset classes are containers for argilla records.

    This is the base class to facilitate the implementation for each record type.

    Args:
        records: A list of argilla records.

    Raises:
        WrongRecordTypeError: When the record type in the provided
            list does not correspond to the dataset type.
    """

    _RECORD_TYPE = None

    # record fields that can hold multiple input columns
    # from a datasets.Dataset or a pandas.DataFrame
    _RECORD_FIELDS_WITH_MULTIPLE_INPUT_COLUMNS = ["inputs", "metadata"]

    @classmethod
    def _record_init_args(cls) -> List[str]:
        """Helper that returns the field list available for the creation of inner records.

        The ``_RECORD_TYPE.__fields__`` will be returned by default.
        """
        return [field for field in cls._RECORD_TYPE.__fields__]

    def __init__(self, records: Optional[List[Record]] = None):
        if self._RECORD_TYPE is None:
            raise NotImplementedError(
                "A Dataset implementation has to define a `_RECORD_TYPE`!"
            )

        self._records = records or []
        if self._records:
            self._validate_record_type()

    def _validate_record_type(self):
        """Validates the record type.

        Raises:
            WrongRecordTypeError: When the record type in the provided
                list does not correspond to the dataset type.
        """
        record_types = {type(rec): None for rec in self._records}
        if len(record_types) > 1:
            raise WrongRecordTypeError(
                f"A {type(self).__name__} must only contain {self._RECORD_TYPE.__name__}s, "
                f"but you provided various types: {[rt.__name__ for rt in record_types.keys()]}"
            )
        elif next(iter(record_types)) is not self._RECORD_TYPE:
            raise WrongRecordTypeError(
                f"A {type(self).__name__} must only contain {self._RECORD_TYPE.__name__}s, "
                f"but you provided {list(record_types.keys())[0].__name__}s."
            )

    def __iter__(self):
        return self._records.__iter__()

    def __getitem__(self, key):
        return self._records[key]

    def __setitem__(self, key, value):
        if type(value) is not self._RECORD_TYPE:
            raise WrongRecordTypeError(
                f"You are only allowed to set a record of type {self._RECORD_TYPE} "
                f"in this dataset, but you provided {type(value)}"
            )
        self._records[key] = value

    def __delitem__(self, key):
        del self._records[key]

    def __len__(self) -> int:
        return len(self._records)

    @_requires_datasets
    def to_datasets(self) -> "datasets.Dataset":
        """Exports your records to a `datasets.Dataset`.

        Returns:
            A `datasets.Dataset` containing your records.
        """
        import datasets

        ds_dict = self._to_datasets_dict()
        # TODO: this field only exists at the client API level, it does not make sense here for now
        if "search_keywords" in ds_dict:
            del ds_dict["search_keywords"]

        try:
            dataset = datasets.Dataset.from_dict(ds_dict)
        # try without metadata, since it is more prone to incompatible structures
        except Exception:
            del ds_dict["metadata"]
            dataset = datasets.Dataset.from_dict(ds_dict)
            _LOGGER.warning(
                "The 'metadata' of the records was removed, since it was incompatible with the 'datasets' format."
            )

        return dataset

    def _to_datasets_dict(self) -> Dict:
        """Helper method to transform an argilla dataset into a dict that is compatible with `datasets.Dataset`"""
        raise NotImplementedError

    @classmethod
    def from_datasets(cls, dataset: "datasets.Dataset", **kwargs) -> "Dataset":
        """Imports records from a `datasets.Dataset`.

        Columns that are not supported are ignored.

        Args:
            dataset: A datasets Dataset from which to import the records.

        Returns:
            The imported records in an argilla Dataset.
        """
        raise NotImplementedError

    @classmethod
    def _prepare_dataset_and_column_mapping(
        cls,
        dataset: "datasets.Dataset",
        column_mapping: Dict[str, Union[str, List[str]]],
    ) -> Tuple["datasets.Dataset", Dict[str, List[str]]]:
        """Renames and removes columns, and extracts the mapping of the columns to be joined.

        Args:
            dataset: A datasets Dataset from which to import the records.
            column_mapping: Mappings from record fields to column names.

        Returns:
            The prepared dataset and a mapping for the columns to be joined.
        """
        import datasets

        if isinstance(dataset, datasets.DatasetDict):
            raise ValueError(
                "`datasets.DatasetDict` are not supported. Please select the dataset split before."
            )

        # clean column mappings
        column_mapping = {
            key: val for key, val in column_mapping.items() if val is not None
        }

        cols_to_be_renamed, cols_to_be_joined = {}, {}
        for field, col in column_mapping.items():
            if field in cls._RECORD_FIELDS_WITH_MULTIPLE_INPUT_COLUMNS:
                cols_to_be_joined[field] = [col] if isinstance(col, str) else col
            else:
                cols_to_be_renamed[col] = field

        dataset = dataset.rename_columns(cols_to_be_renamed)

        dataset = cls._remove_unsupported_columns(
            dataset,
            extra_columns=[col for cols in cols_to_be_joined.values() for col in cols],
        )

        return dataset, cols_to_be_joined

    @classmethod
    def _remove_unsupported_columns(
        cls,
        dataset: "datasets.Dataset",
        extra_columns: List[str],
    ) -> "datasets.Dataset":
        """Helper function to remove unsupported columns from the `datasets.Dataset` following the record type.

        Args:
            dataset: The dataset.
            extra_columns: Extra columns to be kept.

        Returns:
            The dataset with unsupported columns removed.
        """
        not_supported_columns = [
            col
            for col in dataset.column_names
            if col not in cls._record_init_args() + extra_columns
        ]
        if not_supported_columns:
            _LOGGER.warning(
                f"The following columns are not supported by the {cls._RECORD_TYPE.__name__}"
                f" model and are ignored: {not_supported_columns}"
            )
            dataset = dataset.remove_columns(not_supported_columns)

        return dataset

    @staticmethod
    def _join_datasets_columns_and_delete(
        row: Dict[str, Any], columns: List[str]
    ) -> Dict[str, Any]:
        """Joins columns of a `datasets.Dataset` row into a dict, and deletes the single columns.

        Updates the ``row`` dictionary!

        Args:
            row: A row of a `datasets.Dataset`.
            columns: Name of the columns to be joined and deleted from the row.

        Returns:
            A dict containing the columns and their values.
        """
        joined_cols = {}
        for col in columns:
            joined_cols[col] = row[col]
            del row[col]

        return joined_cols

    @staticmethod
    def _parse_datasets_column_with_classlabel(
        column_value: Union[str, List[str], int, List[int]],
        feature: Optional[Any],
    ) -> Optional[Union[str, List[str], int, List[int]]]:
        """Helper function to parse a datasets.Dataset column with a potential ClassLabel feature.

        Args:
            column_value: The value from the datasets Dataset column.
            feature: The feature of the annotation column to optionally convert ints to strs.

        Returns:
            The column value optionally converted to str, or None if the conversion fails.
        """
        import datasets

        # extract ClassLabel feature
        if isinstance(feature, list):
            feature = feature[0]
        if isinstance(feature, datasets.Sequence):
            feature = feature.feature
        if not isinstance(feature, datasets.ClassLabel):
            feature = None

        if feature is None:
            return column_value

        try:
            return feature.int2str(column_value)
        # integers don't have to map to the names ...
        # it seems that sometimes -1 is used to denote "no label"
        except ValueError:
            return None

    def to_pandas(self) -> pd.DataFrame:
        """Exports your records to a `pandas.DataFrame`.

        Returns:
            A `pandas.DataFrame` containing your records.
        """
        return pd.DataFrame(map(dict, self._records))

    @classmethod
    def from_pandas(cls, dataframe: pd.DataFrame) -> "Dataset":
        """Imports records from a `pandas.DataFrame`.

        Columns that are not supported are ignored.

        Args:
            dataframe: A pandas DataFrame from which to import the records.

        Returns:
            The imported records in an argilla Dataset.
        """
        not_supported_columns = [
            col for col in dataframe.columns if col not in cls._record_init_args()
        ]
        if not_supported_columns:
            _LOGGER.warning(
                f"The following columns are not supported by the {cls._RECORD_TYPE.__name__} model "
                f"and are ignored: {not_supported_columns}"
            )
            dataframe = dataframe.drop(columns=not_supported_columns)

        return cls._from_pandas(dataframe)

    @classmethod
    def _from_pandas(cls, dataframe: pd.DataFrame) -> "Dataset":
        """Helper method to create an argilla Dataset from a pandas DataFrame.

        Must be implemented by the child class.

        Args:
            dataframe: A pandas DataFrame.

        Returns:
            An argilla Dataset.
        """
        raise NotImplementedError

    @_requires_datasets
    def prepare_for_training(self, **kwargs) -> "datasets.Dataset":
        """Prepares the dataset for training.

        Args:
            **kwargs: Specific to the task of the dataset.

        Returns:
            A datasets Dataset.
        """
        raise NotImplementedError


def _prepend_docstring(record_type: Type[Record]):
    docstring = f"""This Dataset contains {record_type.__name__} records.

    It allows you to export/import records into/from different formats,
    loop over the records, and access them by index.

    Args:
        records: A list of `{record_type.__name__}`s.

    Raises:
        WrongRecordTypeError: When the record type in the provided
            list does not correspond to the dataset type.
    """

    def docstring_decorator(cls):
        cls.__doc__ = docstring + (cls.__doc__ or "")
        return cls

    return docstring_decorator


@_prepend_docstring(TextClassificationRecord)
class DatasetForTextClassification(DatasetBase):
    """
    Examples:
        >>> # Import/export records:
        >>> import argilla as rg
        >>> dataset = rg.DatasetForTextClassification.from_pandas(my_dataframe)
        >>> dataset.to_datasets()
        >>>
        >>> # Looping over the dataset:
        >>> for record in dataset:
        ...     print(record)
        >>>
        >>> # Passing in a list of records:
        >>> records = [
        ...     rg.TextClassificationRecord(text="example"),
        ...     rg.TextClassificationRecord(text="another example"),
        ... ]
        >>> dataset = rg.DatasetForTextClassification(records)
        >>> assert len(dataset) == 2
        >>>
        >>> # Indexing into the dataset:
        >>> dataset[0]
        ... rg.TextClassificationRecord(text="example")
        >>> dataset[0] = rg.TextClassificationRecord(text="replaced example")
    """

    _RECORD_TYPE = TextClassificationRecord

    def __init__(self, records: Optional[List[TextClassificationRecord]] = None):
        # we implement this to have more specific type hints
        super().__init__(records=records)

    @classmethod
    def from_datasets(
        cls,
        dataset: "datasets.Dataset",
        text: Optional[str] = None,
        id: Optional[str] = None,
        inputs: Optional[Union[str, List[str]]] = None,
        annotation: Optional[str] = None,
        metadata: Optional[Union[str, List[str]]] = None,
    ) -> "DatasetForTextClassification":
        """Imports records from a `datasets.Dataset`.

        Columns that are not supported are ignored.

        Args:
            dataset: A datasets Dataset from which to import the records.
            text: The field name used as record text. Default: `None`
            id: The field name used as record id. Default: `None`
            inputs: A list of field names used for record inputs. Default: `None`
            annotation: The field name used as record annotation. Default: `None`
            metadata: The field name used as record metadata. Default: `None`

        Returns:
            The imported records in an argilla Dataset.

        Examples:
            >>> import datasets
            >>> ds = datasets.Dataset.from_dict({
            ...     "inputs": ["example"],
            ...     "prediction": [
            ...         [{"label": "LABEL1", "score": 0.9}, {"label": "LABEL2", "score": 0.1}]
            ...     ]
            ... })
            >>> DatasetForTextClassification.from_datasets(ds)
        """
        dataset, cols_to_be_joined = cls._prepare_dataset_and_column_mapping(
            dataset,
            dict(
                text=text,
                id=id,
                inputs=inputs,
                annotation=annotation,
                metadata=metadata,
            ),
        )

        records = []
        for row in dataset:
            row["inputs"] = cls._parse_inputs_field(
                row, cols_to_be_joined.get("inputs")
            )
            if row.get("inputs") is not None and row.get("text") is not None:
                del row["text"]

            if row.get("annotation") is not None:
                row["annotation"] = cls._parse_datasets_column_with_classlabel(
                    row["annotation"], dataset.features["annotation"]
                )

            if row.get("prediction"):
                row["prediction"] = (
                    [
                        (
                            pred["label"],
                            pred["score"],
                        )
                        for pred in row["prediction"]
                    ]
                    if row["prediction"] is not None
                    else None
                )

            if row.get("explanation"):
                row["explanation"] = (
                    {
                        key: [
                            TokenAttributions(**tokattr_kwargs)
                            for tokattr_kwargs in val
                        ]
                        for key, val in row["explanation"].items()
                    }
                    if row["explanation"] is not None
                    else None
                )

            if cols_to_be_joined.get("metadata"):
                row["metadata"] = cls._join_datasets_columns_and_delete(
                    row, cols_to_be_joined["metadata"]
                )

            records.append(TextClassificationRecord.parse_obj(row))

        return cls(records)

    @classmethod
    def _parse_inputs_field(
        cls,
        row: Dict[str, Any],
        columns: Optional[List[str]],
    ) -> Optional[Union[Dict[str, str], str]]:
        """Helper function to parse the inputs field.

        Args:
            row: A row of the datasets.Dataset.
            columns: A list of columns to be joined for the inputs field, optional.

        Returns:
            None, a dictionary or a string as input for the inputs field.
        """
        inputs = row.get("inputs")
        if columns is not None:
            inputs = cls._join_datasets_columns_and_delete(row, columns)
        if isinstance(inputs, dict):
            inputs = {key: val for key, val in inputs.items() if val is not None}

        return inputs

    @classmethod
    def from_pandas(
        # we implement this to have more specific type hints
        cls,
        dataframe: pd.DataFrame,
    ) -> "DatasetForTextClassification":
        return super().from_pandas(dataframe)

    def _to_datasets_dict(self) -> Dict:
        # create a dict first, where we make the necessary transformations
        ds_dict = {}
        for key in self._RECORD_TYPE.__fields__:
            if key == "prediction":
                ds_dict[key] = [
                    [{"label": pred[0], "score": pred[1]} for pred in rec.prediction]
                    if rec.prediction is not None
                    else None
                    for rec in self._records
                ]
            elif key == "explanation":
                ds_dict[key] = [
                    {
                        key: list(map(dict, tokattrs))
                        for key, tokattrs in rec.explanation.items()
                    }
                    if rec.explanation is not None
                    else None
                    for rec in self._records
                ]
            elif key == "id":
                ds_dict[key] = [
                    None if rec.id is None else str(rec.id) for rec in self._records
                ]
            elif key == "metadata":
                ds_dict[key] = [getattr(rec, key) or None for rec in self._records]
            else:
                ds_dict[key] = [getattr(rec, key) for rec in self._records]

        return ds_dict

    @classmethod
    def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTextClassification":
        return cls(
            [TextClassificationRecord(**row) for row in dataframe.to_dict("records")]
        )

    @_requires_datasets
    def prepare_for_training(self) -> "datasets.Dataset":
        """Prepares the dataset for training.

        This will return a ``datasets.Dataset`` with a *label* column,
        and one column for each key in the *inputs* dictionary of the records:

        - Records without an annotation are removed.
        - The *label* column corresponds to the annotations of the records.
        - Labels are transformed to integers.

        Returns:
            A datasets Dataset with a *label* column and several *inputs* columns.

        Examples:
            >>> import argilla as rg
            >>> rb_dataset = rg.DatasetForTextClassification([
            ...     rg.TextClassificationRecord(
            ...         inputs={"header": "my header", "content": "my content"},
            ...         annotation="SPAM",
            ...     )
            ... ])
            >>> rb_dataset.prepare_for_training().features
            {'header': Value(dtype='string'),
             'content': Value(dtype='string'),
             'label': ClassLabel(num_classes=1, names=['SPAM'])}
        """
        import datasets

        inputs_keys = {
            key: None
            for rec in self._records
            for key in rec.inputs
            if rec.annotation is not None
        }.keys()

        ds_dict = {**{key: [] for key in inputs_keys}, "label": []}
        for rec in self._records:
            if rec.annotation is None:
                continue
            for key in inputs_keys:
                ds_dict[key].append(rec.inputs.get(key))
            ds_dict["label"].append(rec.annotation)

        if self._records[0].multi_label:
            labels = {label: None for labels in ds_dict["label"] for label in labels}
        else:
            labels = {label: None for label in ds_dict["label"]}

        class_label = (
            datasets.ClassLabel(names=sorted(labels.keys()))
            if ds_dict["label"]
            # in case we don't have any labels, ClassLabel fails with Dataset.from_dict({"labels": []})
            else datasets.Value("string")
        )
        feature_dict = {
            **{key: datasets.Value("string") for key in inputs_keys},
            "label": [class_label] if self._records[0].multi_label else class_label,
        }

        return datasets.Dataset.from_dict(
            ds_dict, features=datasets.Features(feature_dict)
        )


class Framework(Enum):
    TRANSFORMERS = "transformers"
    SPACY = "spacy"

    @classmethod
    def _missing_(cls, value):
        raise ValueError(
            f"{value} is not a valid {cls.__name__}, "
            f"please select one of {list(cls._value2member_map_.keys())}"
        )


@_prepend_docstring(TokenClassificationRecord)
class DatasetForTokenClassification(DatasetBase):
    """
    Examples:
        >>> # Import/export records:
        >>> import argilla as rg
        >>> dataset = rg.DatasetForTokenClassification.from_pandas(my_dataframe)
        >>> dataset.to_datasets()
        >>>
        >>> # Looping over the dataset:
        >>> for record in dataset:
        ...     print(record)
        >>>
        >>> # Passing in a list of records:
        >>> import argilla as rg
        >>> records = [
        ...     rg.TokenClassificationRecord(text="example", tokens=["example"]),
        ...     rg.TokenClassificationRecord(text="another example", tokens=["another", "example"]),
        ... ]
        >>> dataset = rg.DatasetForTokenClassification(records)
        >>> assert len(dataset) == 2
        >>>
        >>> # Indexing into the dataset:
        >>> dataset[0]
        ... rg.TokenClassificationRecord(text="example", tokens=["example"])
        >>> dataset[0] = rg.TokenClassificationRecord(text="replace example", tokens=["replace", "example"])
    """

    _RECORD_TYPE = TokenClassificationRecord

    def __init__(self, records: Optional[List[TokenClassificationRecord]] = None):
        # we implement this to have more specific type hints
        super().__init__(records=records)

    @classmethod
    def _record_init_args(cls) -> List[str]:
        """Adds the `tags` argument to the default record init arguments."""
        parent_fields = super(DatasetForTokenClassification, cls)._record_init_args()
        return parent_fields + ["tags"]  # compute annotation from tags

    @classmethod
    def from_datasets(
        cls,
        dataset: "datasets.Dataset",
        text: Optional[str] = None,
        id: Optional[str] = None,
        tokens: Optional[str] = None,
        tags: Optional[str] = None,
        metadata: Optional[Union[str, List[str]]] = None,
    ) -> "DatasetForTokenClassification":
        """Imports records from a `datasets.Dataset`.

        Columns that are not supported are ignored.

        Args:
            dataset: A datasets Dataset from which to import the records.
            text: The field name used as record text. Default: `None`
            id: The field name used as record id. Default: `None`
            tokens: The field name used as record tokens. Default: `None`
            tags: The field name used as record tags. Default: `None`
            metadata: The field name used as record metadata.
                Default: `None`

        Returns:
            The imported records in an argilla Dataset.

        Examples:
            >>> import datasets
            >>> ds = datasets.Dataset.from_dict({
            ...     "text": ["my example"],
            ...     "tokens": [["my", "example"]],
            ...     "prediction": [
            ...         [{"label": "LABEL1", "start": 3, "end": 10, "score": 1.0}]
            ...     ]
            ... })
            >>> DatasetForTokenClassification.from_datasets(ds)
        """
        dataset, cols_to_be_joined = cls._prepare_dataset_and_column_mapping(
            dataset,
            dict(
                text=text,
                tokens=tokens,
                tags=tags,
                id=id,
                metadata=metadata,
            ),
        )

        records = []
        for row in dataset:
            # TODO: fails with a KeyError if no tokens column is present and no mapping is indicated
            if not row["tokens"]:
                _LOGGER.warning("Ignoring row with no tokens.")
                continue

            if row.get("tags"):
                row["tags"] = cls._parse_datasets_column_with_classlabel(
                    row["tags"], dataset.features["tags"]
                )

            if row.get("prediction"):
                row["prediction"] = cls.__entities_to_tuple__(row["prediction"])

            if row.get("annotation"):
                row["annotation"] = cls.__entities_to_tuple__(row["annotation"])

            if cols_to_be_joined.get("metadata"):
                row["metadata"] = cls._join_datasets_columns_and_delete(
                    row, cols_to_be_joined["metadata"]
                )

            records.append(TokenClassificationRecord.parse_obj(row))

        return cls(records)

    @classmethod
    def from_pandas(
        # we implement this to have more specific type hints
        cls,
        dataframe: pd.DataFrame,
    ) -> "DatasetForTokenClassification":
        return super().from_pandas(dataframe)

    def prepare_for_training(
        self,
        framework: Union[Framework, str] = "transformers",
        lang: Optional["spacy.Language"] = None,
    ) -> Union["datasets.Dataset", "spacy.tokens.DocBin"]:
        """Prepares the dataset for training.

        This will return a ``datasets.Dataset`` with all columns returned by the
        ``to_datasets`` method and an additional *ner_tags* column:

        - Records without an annotation are removed.
        - The *ner_tags* column corresponds to the IOB tag sequences for the annotations of the records.
        - The IOB tags are transformed to integers.

        Args:
            framework: A string or enum specifying the framework for the training.
                "transformers" and "spacy" are currently supported. Default: `transformers`
            lang: The spacy nlp Language pipeline used to process the dataset.
                (Only for the spacy framework)

        Returns:
            A datasets Dataset with a *ner_tags* column and all columns returned by
            ``to_datasets`` for the "transformers" framework.
            A spacy DocBin ready to use for training a spacy NER model for the "spacy" framework.

        Examples:
            >>> import argilla as rg
            >>> rb_dataset = rg.DatasetForTokenClassification([
            ...     rg.TokenClassificationRecord(
            ...         text="The text",
            ...         tokens=["The", "text"],
            ...         annotation=[("TAG", 0, 2)],
            ...     )
            ... ])
            >>> rb_dataset.prepare_for_training().features
            {'text': Value(dtype='string'),
             'tokens': Sequence(feature=Value(dtype='string'), length=-1),
             'prediction': Value(dtype='null'),
             'prediction_agent': Value(dtype='null'),
             'annotation': [{'end': Value(dtype='int64'),
                             'label': Value(dtype='string'),
                             'start': Value(dtype='int64')}],
             'annotation_agent': Value(dtype='null'),
             'id': Value(dtype='null'),
             'metadata': Value(dtype='null'),
             'status': Value(dtype='string'),
             'event_timestamp': Value(dtype='null'),
             'metrics': Value(dtype='null'),
             'ner_tags': [ClassLabel(num_classes=3, names=['O', 'B-TAG', 'I-TAG'])]}
        """
        # turn the string into a Framework instance and trigger an error if the string is not valid
        if isinstance(framework, str):
            framework = Framework(framework)

        if framework is Framework.TRANSFORMERS:
            return self._prepare_for_training_with_transformers()
        # else: must be spacy for sure
        if lang is None:
            raise ValueError(
                "Please provide a spacy language model to prepare the dataset for training with the spacy framework."
            )
        return self._prepare_for_training_with_spacy(nlp=lang)

    @_requires_datasets
    def _prepare_for_training_with_transformers(self):
        import datasets

        has_annotations = False
        for rec in self._records:
            if rec.annotation is not None:
                has_annotations = True
                break

        if not has_annotations:
            return datasets.Dataset.from_dict({})

        class_tags = ["O"]
        class_tags.extend(
            [
                f"{pre}-{label}"
                for label in sorted(self.__all_labels__())
                for pre in ["B", "I"]
            ]
        )
        class_tags = datasets.ClassLabel(names=class_tags)

        def spans2iob(example):
            span_utils = SpanUtils(example["text"], example["tokens"])
            entity_spans = self.__entities_to_tuple__(example["annotation"])
            tags = span_utils.to_tags(entity_spans)
            return class_tags.str2int(tags)

        ds = (
            self.to_datasets()
            .filter(self.__only_annotations__)
            .map(lambda example: {"ner_tags": spans2iob(example)})
        )
        new_features = ds.features.copy()
        new_features["ner_tags"] = [class_tags]

        return ds.cast(new_features)

    @_requires_spacy
    def _prepare_for_training_with_spacy(
        self, nlp: "spacy.Language"
    ) -> "spacy.tokens.DocBin":
        from spacy.tokens import DocBin

        db = DocBin()

        # Creating the DocBin object as in https://spacy.io/usage/training#training-data
        for record in self._records:
            if record.annotation is None:
                continue

            doc = nlp.make_doc(record.text)
            entities = []

            for anno in record.annotation:
                span = doc.char_span(anno[1], anno[2], label=anno[0])
                # There is a misalignment between record tokenization and spaCy tokenization
                if span is None:
                    # TODO(@dcfidalgo): Do we want to warn and continue or should we stop the training set generation?
                    raise ValueError(
                        "The following annotation does not align with the tokens produced "
                        f"by the provided spacy language model: {(anno[0], record.text[anno[1]:anno[2]])}, {list(doc)}"
                    )
                else:
                    entities.append(span)

            doc.ents = entities
            db.add(doc)

        return db

    def __all_labels__(self):
        all_labels = set()
        for record in self._records:
            if record.annotation:
                all_labels.update([label for label, _, _ in record.annotation])

        return list(all_labels)

    def __only_annotations__(self, data) -> bool:
        return data["annotation"] is not None

    def _to_datasets_dict(self) -> Dict:
        """Helper method to put token classification records in a `datasets.Dataset`"""

        # create a dict first, where we make the necessary transformations
        def entities_to_dict(
            entities: Optional[
                List[Union[Tuple[str, int, int, float], Tuple[str, int, int]]]
            ]
        ) -> Optional[List[Dict[str, Union[str, int, float]]]]:
            if entities is None:
                return None
            return [
                {"label": ent[0], "start": ent[1], "end": ent[2]}
                if len(ent) == 3
                else {"label": ent[0], "start": ent[1], "end": ent[2], "score": ent[3]}
                for ent in entities
            ]

        ds_dict = {}
        for key in self._RECORD_TYPE.__fields__:
            if key == "prediction":
                ds_dict[key] = [
                    entities_to_dict(rec.prediction) for rec in self._records
                ]
            elif key == "annotation":
                ds_dict[key] = [
                    entities_to_dict(rec.annotation) for rec in self._records
                ]
            elif key == "id":
                ds_dict[key] = [
                    None if rec.id is None else str(rec.id) for rec in self._records
                ]
            elif key == "metadata":
                ds_dict[key] = [getattr(rec, key) or None for rec in self._records]
            else:
                ds_dict[key] = [getattr(rec, key) for rec in self._records]

        return ds_dict

    @staticmethod
    def __entities_to_tuple__(
        entities,
    ) -> List[Union[Tuple[str, int, int], Tuple[str, int, int, float]]]:
        return [
            (ent["label"], ent["start"], ent["end"])
            if len(ent) == 3
            else (ent["label"], ent["start"], ent["end"], ent["score"] or 0.0)
            for ent in entities
        ]

    @classmethod
    def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForTokenClassification":
        return cls(
            [TokenClassificationRecord(**row) for row in dataframe.to_dict("records")]
        )


@_prepend_docstring(Text2TextRecord)
class DatasetForText2Text(DatasetBase):
    """
    Examples:
        >>> # Import/export records:
        >>> import argilla as rg
        >>> dataset = rg.DatasetForText2Text.from_pandas(my_dataframe)
        >>> dataset.to_datasets()
        >>>
        >>> # Passing in a list of records:
        >>> records = [
        ...     rg.Text2TextRecord(text="example"),
        ...     rg.Text2TextRecord(text="another example"),
        ... ]
        >>> dataset = rg.DatasetForText2Text(records)
        >>> assert len(dataset) == 2
        >>>
        >>> # Looping over the dataset:
        >>> for record in dataset:
        ...     print(record)
        >>>
        >>> # Indexing into the dataset:
        >>> dataset[0]
        ... rg.Text2TextRecord(text="example")
        >>> dataset[0] = rg.Text2TextRecord(text="replaced example")
    """

    _RECORD_TYPE = Text2TextRecord

    def __init__(self, records: Optional[List[Text2TextRecord]] = None):
        # we implement this to have more specific type hints
        super().__init__(records=records)

    @classmethod
    def from_datasets(
        cls,
        dataset: "datasets.Dataset",
        text: Optional[str] = None,
        annotation: Optional[str] = None,
        metadata: Optional[Union[str, List[str]]] = None,
        id: Optional[str] = None,
    ) -> "DatasetForText2Text":
        """Imports records from a `datasets.Dataset`.

        Columns that are not supported are ignored.

        Args:
            dataset: A datasets Dataset from which to import the records.
            text: The field name used as record text. Default: `None`
            annotation: The field name used as record annotation. Default: `None`
            metadata: The field name used as record metadata.
                Default: `None`

        Returns:
            The imported records in an argilla Dataset.

        Examples:
            >>> import datasets
            >>> ds = datasets.Dataset.from_dict({
            ...     "text": ["my example"],
            ...     "prediction": [["mi ejemplo", "ejemplo mio"]]
            ... })
            >>> # or
            >>> ds = datasets.Dataset.from_dict({
            ...     "text": ["my example"],
            ...     "prediction": [[{"text": "mi ejemplo", "score": 0.9}]]
            ... })
            >>> DatasetForText2Text.from_datasets(ds)
        """
        dataset, cols_to_be_joined = cls._prepare_dataset_and_column_mapping(
            dataset,
            dict(
                text=text,
                annotation=annotation,
                id=id,
                metadata=metadata,
            ),
        )

        records = []
        for row in dataset:
            if row.get("prediction"):
                row["prediction"] = cls._parse_prediction_field(row["prediction"])

            if cols_to_be_joined.get("metadata"):
                row["metadata"] = cls._join_datasets_columns_and_delete(
                    row, cols_to_be_joined["metadata"]
                )

            records.append(Text2TextRecord.parse_obj(row))

        return cls(records)

    @staticmethod
    def _parse_prediction_field(predictions: List[Union[str, Dict[str, str]]]):
        def extract_prediction(prediction: Union[str, Dict]):
            if isinstance(prediction, str):
                return prediction
            if prediction["score"] is None:
                return prediction["text"]
            return prediction["text"], prediction["score"]

        return [extract_prediction(pred) for pred in predictions]

    @classmethod
    def from_pandas(
        # we implement this to have more specific type hints
        cls,
        dataframe: pd.DataFrame,
    ) -> "DatasetForText2Text":
        return super().from_pandas(dataframe)

    def _to_datasets_dict(self) -> Dict:
        # create a dict first, where we make the necessary transformations
        def pred_to_dict(pred: Union[str, Tuple[str, float]]):
            if isinstance(pred, str):
                return {"text": pred, "score": None}
            return {"text": pred[0], "score": pred[1]}

        ds_dict = {}
        for key in self._RECORD_TYPE.__fields__:
            if key == "prediction":
                ds_dict[key] = [
                    [pred_to_dict(pred) for pred in rec.prediction]
                    if rec.prediction is not None
                    else None
                    for rec in self._records
                ]
            elif key == "id":
                ds_dict[key] = [
                    None if rec.id is None else str(rec.id) for rec in self._records
                ]
            elif key == "metadata":
                ds_dict[key] = [getattr(rec, key) or None for rec in self._records]
            else:
                ds_dict[key] = [getattr(rec, key) for rec in self._records]

        return ds_dict

    @classmethod
    def _from_pandas(cls, dataframe: pd.DataFrame) -> "DatasetForText2Text":
        return cls([Text2TextRecord(**row) for row in dataframe.to_dict("records")])


Dataset = Union[
    DatasetForTextClassification, DatasetForTokenClassification, DatasetForText2Text
]


def read_datasets(
    dataset: "datasets.Dataset", task: Union[str, TaskType], **kwargs
) -> Dataset:
    """Reads a datasets Dataset and returns an argilla Dataset

    Args:
        dataset: Dataset to be read in.
        task: Task for the dataset, one of: ["TextClassification", "TokenClassification", "Text2Text"].
        **kwargs: Passed on to the task-specific ``DatasetFor*.from_datasets()`` method.

    Returns:
        An argilla dataset for the given task.

    Examples:
        >>> # Read text classification records from a datasets Dataset
        >>> import datasets
        >>> ds = datasets.Dataset.from_dict({
        ...     "inputs": ["example"],
        ...     "prediction": [
        ...         [{"label": "LABEL1", "score": 0.9}, {"label": "LABEL2", "score": 0.1}]
        ...     ]
        ... })
        >>> read_datasets(ds, task="TextClassification")
        >>>
        >>> # Read token classification records from a datasets Dataset
        >>> ds = datasets.Dataset.from_dict({
        ...     "text": ["my example"],
        ...     "tokens": [["my", "example"]],
        ...     "prediction": [
        ...         [{"label": "LABEL1", "start": 3, "end": 10}]
        ...     ]
        ... })
        >>> read_datasets(ds, task="TokenClassification")
        >>>
        >>> # Read text2text records from a datasets Dataset
        >>> ds = datasets.Dataset.from_dict({
        ...     "text": ["my example"],
        ...     "prediction": [["mi ejemplo", "ejemplo mio"]]
        ... })
        >>> # or
        >>> ds = datasets.Dataset.from_dict({
        ...     "text": ["my example"],
        ...     "prediction": [[{"text": "mi ejemplo", "score": 0.9}]]
        ... })
        >>> read_datasets(ds, task="Text2Text")
    """
    if isinstance(task, str):
        task = TaskType(task)

    if task is TaskType.text_classification:
        return DatasetForTextClassification.from_datasets(dataset, **kwargs)
    if task is TaskType.token_classification:
        return DatasetForTokenClassification.from_datasets(dataset, **kwargs)
    if task is TaskType.text2text:
        return DatasetForText2Text.from_datasets(dataset, **kwargs)

    raise NotImplementedError(
        "Reading a datasets Dataset is not implemented for the given task!"
    )


def read_pandas(dataframe: pd.DataFrame, task: Union[str, TaskType]) -> Dataset:
    """Reads a pandas DataFrame and returns an argilla Dataset

    Args:
        dataframe: Dataframe to be read in.
        task: Task for the dataset, one of: ["TextClassification", "TokenClassification", "Text2Text"]

    Returns:
        An argilla dataset for the given task.

    Examples:
        >>> # Read text classification records from a pandas DataFrame
        >>> import pandas as pd
        >>> df = pd.DataFrame({
        ...     "inputs": ["example"],
        ...     "prediction": [
        ...         [("LABEL1", 0.9), ("LABEL2", 0.1)]
        ...     ]
        ... })
        >>> read_pandas(df, task="TextClassification")
        >>>
        >>> # Read token classification records from a pandas DataFrame
        >>> df = pd.DataFrame({
        ...     "text": ["my example"],
        ...     "tokens": [["my", "example"]],
        ...     "prediction": [
        ...         [("LABEL1", 3, 10)]
        ...     ]
        ... })
        >>> read_pandas(df, task="TokenClassification")
        >>>
        >>> # Read text2text records from a pandas DataFrame
        >>> df = pd.DataFrame({
        ...     "text": ["my example"],
        ...     "prediction": [["mi ejemplo", "ejemplo mio"]]
        ... })
        >>> # or
        >>> df = pd.DataFrame({
        ...     "text": ["my example"],
        ...     "prediction": [[("mi ejemplo", 0.9)]]
        ... })
        >>> read_pandas(df, task="Text2Text")
    """
    if isinstance(task, str):
        task = TaskType(task)

    if task is TaskType.text_classification:
        return DatasetForTextClassification.from_pandas(dataframe)
    if task is TaskType.token_classification:
        return DatasetForTokenClassification.from_pandas(dataframe)
    if task is TaskType.text2text:
        return DatasetForText2Text.from_pandas(dataframe)

    raise NotImplementedError(
        "Reading a pandas DataFrame is not implemented for the given task!"
    )


class WrongRecordTypeError(Exception):
    pass
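

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the public API).
# It assumes the optional `datasets` dependency (>1.17.0) is installed, as
# enforced by the `@_requires_datasets` decorators above, and uses made-up
# example texts and labels.
if __name__ == "__main__":
    sketch_records = [
        TextClassificationRecord(text="I love this movie", annotation="positive"),
        TextClassificationRecord(text="I hate this movie", annotation="negative"),
    ]
    sketch_dataset = DatasetForTextClassification(sketch_records)

    # Iterate over and index into the dataset like a list of records.
    for sketch_record in sketch_dataset:
        print(sketch_record)
    print(sketch_dataset[0])

    # Export to pandas and to the Hugging Face `datasets` format.
    print(sketch_dataset.to_pandas().shape)
    print(sketch_dataset.to_datasets())

    # `prepare_for_training` returns a `datasets.Dataset` whose *label* column
    # is a ClassLabel built from the annotations.
    print(sketch_dataset.prepare_for_training().features)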