#!/usr/bin/env python
# -*- coding: utf-8 -*-
#
# Copyright (c) 2022 Intel Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
#

import os
import tensorflow as tf

from tlt import TLT_BASE_DIR
from tlt.datasets.tf_dataset import TFDataset
from tlt.datasets.text_classification.text_classification_dataset import TextClassificationDataset
from tlt.utils.dataset_utils import prepare_huggingface_input_data
from tlt.utils.file_utils import read_json_file
from tlt.utils.inc_utils import INCTFDataLoader
from downloader.datasets import DataDownloader

DATASET_CONFIG_DIR = os.path.join(TLT_BASE_DIR, "datasets/configs")
config_file = os.path.join(DATASET_CONFIG_DIR, "tf_text_classification_datasets.json")
config_dict = read_json_file(config_file)
DATASETS = list(config_dict.keys())


class TFDSTextClassificationDataset(TFDataset, TextClassificationDataset):
    """
    A text classification dataset from the TensorFlow datasets catalog
    """

    def __init__(self, dataset_dir, dataset_name, split=["train"], shuffle_files=True, **kwargs):
        if not isinstance(split, list):
            raise ValueError("Value of split argument must be a list.")

        TextClassificationDataset.__init__(self, dataset_dir, dataset_name, "tf_datasets")

        if dataset_name not in DATASETS:
            raise ValueError("Dataset name is not supported. Choose from: {}".format(DATASETS))
Choose from: {}".format(DATASETS)) # as_supervised gives us the (input, label) structure that the model expects as_supervised = True # Glue datasets don't support as_supervised=True, so we need to set as_supervised=False, and then fix # the data format after loading if "glue" in dataset_name: as_supervised = False downloader = DataDownloader(dataset_name, dataset_dir=dataset_dir, catalog='tfds', as_supervised=as_supervised, shuffle_files=shuffle_files, with_info=True) data, self._info = downloader.download(split=split) # Since glue datasets don't support the supervised (input, label) structure, we have to manually format it if "glue" in dataset_name: for split_id in range(len(data)): data[split_id] = data[split_id].map(lambda x: (x['sentence'], x['label'])) self._dataset = None self._train_subset = None self._validation_subset = None self._test_subset = None self._preprocessed = None if len(split) == 1: self._validation_type = None # Train & evaluate on the whole dataset self._dataset = data[0] else: self._validation_type = 'defined_split' # Defined by user or TFDS for i, s in enumerate(split): if s == 'train': self._train_subset = data[i] elif s == 'validation': self._validation_subset = data[i] elif s == 'test': self._test_subset = data[i] self._dataset = data[i] if self._dataset is None else self._dataset.concatenate(data[i]) @property def class_names(self): if "label" in self._info.features.keys(): return self._info.features["label"].names else: return [] @property def info(self): return {'dataset_info': self._info, 'preprocessing_info': self._preprocessed} @property def dataset(self): return self._dataset def preprocess(self, batch_size): """ Batch the dataset Args: batch_size (int): desired batch size Raises: TypeError: if the batch_size is not a positive integer ValueError: if the dataset is not defined or has already been processed """ if not isinstance(batch_size, int) or batch_size < 1: raise ValueError("batch_size should be a positive integer") if self._preprocessed: raise ValueError("Data has already been preprocessed: {}".format(self._preprocessed)) # Get the non-None splits split_list = ['_dataset', '_train_subset', '_validation_subset', '_test_subset'] subsets = [s for s in split_list if getattr(self, s, None)] for subset in subsets: setattr(self, subset, getattr(self, subset).cache()) setattr(self, subset, getattr(self, subset).batch(batch_size)) setattr(self, subset, getattr(self, subset).prefetch(tf.data.AUTOTUNE)) self._preprocessed = {'batch_size': batch_size} def get_inc_dataloaders(self, hub_name, max_seq_length): calib_data, calib_labels = prepare_huggingface_input_data(self.train_subset, hub_name, max_seq_length) calib_data['label'] = tf.convert_to_tensor(calib_labels) eval_data, eval_labels = prepare_huggingface_input_data(self.validation_subset, hub_name, max_seq_length) eval_data['label'] = tf.convert_to_tensor(eval_labels) calib_data.pop('token_type_ids') eval_data.pop('token_type_ids') calib_dataloader = INCTFDataLoader(calib_data, batch_size=self._preprocessed['batch_size']) eval_dataloader = INCTFDataLoader(eval_data, batch_size=self._preprocessed['batch_size']) return calib_dataloader, eval_dataloader