diff --git a/lightautoml/text/nn_model.py b/lightautoml/text/nn_model.py index 42265bd5..e47d4dbb 100644 --- a/lightautoml/text/nn_model.py +++ b/lightautoml/text/nn_model.py @@ -5,6 +5,7 @@ from typing import Any from typing import Callable from typing import Dict +from typing import List from typing import Optional from typing import Union import numpy as np @@ -74,6 +75,18 @@ def __getitem__(self, index: int) -> Dict[str, np.ndarray]: return res + def __getitems__(self, indices) -> Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]]: + if self.tokenizer is not None: + return [self.__getitem__(index) for index in indices] + + indices = np.asarray(indices) + res = {"label": self.y[indices]} + res.update({key: value[indices] for key, value in self.data.items() if key != "text"}) + if self.w is not None: + res["weight"] = self.w[indices] + + return res + class Clump(nn.Module): """Clipping input tensor. diff --git a/lightautoml/text/utils.py b/lightautoml/text/utils.py index 991689f9..61910f57 100644 --- a/lightautoml/text/utils.py +++ b/lightautoml/text/utils.py @@ -10,6 +10,7 @@ from typing import List from typing import Dict from typing import Sequence +from typing import Union _dtypes_mapping = { @@ -136,7 +137,13 @@ def parse_devices(dvs, is_dp: bool = False) -> tuple: return device[0], ids if (len(device) > 1) and is_dp else None -def custom_collate(batch: List[np.ndarray]) -> torch.Tensor: +def _cast_collated_tensor(tensor: torch.Tensor, dtype_name: str) -> torch.Tensor: + if dtype_name == "long": + return tensor if tensor.dtype == torch.long else tensor.long() + return tensor if tensor.dtype == torch.float32 else tensor.float() + + +def custom_collate(batch: List[np.ndarray], dtype_name: str = "float") -> torch.Tensor: """Puts each data field into a tensor with outer dimension batch size.""" elem = batch[0] if isinstance(elem, torch.Tensor): @@ -144,16 +151,22 @@ def custom_collate(batch: List[np.ndarray]) -> torch.Tensor: numel = sum([x.numel() for x in batch]) storage = elem.storage()._new_shared(numel) out = elem.new(storage) - return torch.stack(batch, 0, out=out) + return _cast_collated_tensor(torch.stack(batch, 0, out=out), dtype_name) else: - return torch.from_numpy(np.array(batch)).float() + return _cast_collated_tensor(torch.from_numpy(np.array(batch)), dtype_name) -def collate_dict(batch: List[Dict[str, np.ndarray]]) -> Dict[str, torch.Tensor]: +def collate_dict(batch: Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]]) -> Dict[str, torch.Tensor]: """custom_collate for dicts.""" + if isinstance(batch, dict): + return { + key: _cast_collated_tensor(torch.as_tensor(value), _dtypes_mapping.get(key, "float")) + for key, value in batch.items() + } + keys = list(batch[0].keys()) transposed_data = list(map(list, zip(*[tuple([i[name] for name in i.keys()]) for i in batch]))) - return {key: custom_collate(transposed_data[n]) for n, key in enumerate(keys)} + return {key: custom_collate(transposed_data[n], _dtypes_mapping.get(key, "float")) for n, key in enumerate(keys)} def single_text_hash(x: str) -> str: diff --git a/tests/unit/test_text/test_universal_dataset.py b/tests/unit/test_text/test_universal_dataset.py new file mode 100644 index 00000000..96b96479 --- /dev/null +++ b/tests/unit/test_text/test_universal_dataset.py @@ -0,0 +1,94 @@ +import numpy as np +import torch + +from lightautoml.text.nn_model import UniversalDataset +from lightautoml.text.utils import collate_dict + + +def test_universal_dataset_batched_getitems_matches_rowwise_batch(): + data = { + "cat": np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=np.int64), + "cont": np.array([[0.1, 0.2], [1.1, 1.2], [2.1, 2.2], [3.1, 3.2]], dtype=np.float32), + } + y = np.array([0.0, 1.0, 0.0, 1.0], dtype=np.float32) + w = np.array([1.0, 0.5, 2.0, 1.5], dtype=np.float32) + indices = [0, 2, 3] + + dataset = UniversalDataset(data=data, y=y, w=w, tokenizer=None) + + row_batch = [dataset[index] for index in indices] + fast_batch = dataset.__getitems__(indices) + row_collated = collate_dict(row_batch) + fast_collated = collate_dict(fast_batch) + + assert row_collated.keys() == fast_collated.keys() + for key in row_collated: + assert row_collated[key].shape == fast_collated[key].shape + assert torch.equal(row_collated[key], fast_collated[key]) + + assert fast_collated["cat"].dtype == torch.int64 + assert fast_collated["cont"].dtype == torch.float32 + assert fast_collated["label"].dtype == torch.float32 + assert fast_collated["weight"].dtype == torch.float32 + + +def test_collate_dict_supports_rowwise_and_batched_dict_inputs(): + row_batch = [ + { + "cat": np.array([1, 2], dtype=np.int64), + "cont": np.array([0.1, 0.2], dtype=np.float32), + "label": np.array(0.0, dtype=np.float32), + "weight": np.array(1.0, dtype=np.float32), + }, + { + "cat": np.array([3, 4], dtype=np.int64), + "cont": np.array([1.1, 1.2], dtype=np.float32), + "label": np.array(1.0, dtype=np.float32), + "weight": np.array(0.5, dtype=np.float32), + }, + ] + batched_dict = { + "cat": np.array([[1, 2], [3, 4]], dtype=np.int64), + "cont": np.array([[0.1, 0.2], [1.1, 1.2]], dtype=np.float32), + "label": np.array([0.0, 1.0], dtype=np.float32), + "weight": np.array([1.0, 0.5], dtype=np.float32), + } + + row_collated = collate_dict(row_batch) + batched_collated = collate_dict(batched_dict) + + assert row_collated.keys() == batched_collated.keys() + for key in row_collated: + assert row_collated[key].shape == batched_collated[key].shape + assert torch.equal(row_collated[key], batched_collated[key]) + + assert batched_collated["cat"].dtype == torch.int64 + assert batched_collated["cont"].dtype == torch.float32 + assert batched_collated["label"].dtype == torch.float32 + assert batched_collated["weight"].dtype == torch.float32 + + +def test_universal_dataset_getitems_falls_back_to_rowwise_with_tokenizer(): + class DummyTokenizer: + def encode_plus(self, *args, **kwargs): + return { + "input_ids": [1, 2, 3], + "attention_mask": [1, 1, 1], + "token_type_ids": [0, 0, 0], + } + + dataset = UniversalDataset( + data={"text": np.array([["hello"], ["world"]])}, + y=np.array([0.0, 1.0], dtype=np.float32), + tokenizer=DummyTokenizer(), + ) + + batch = dataset.__getitems__([0, 1]) + row_batch = [dataset[0], dataset[1]] + + assert isinstance(batch, list) + assert len(batch) == 2 + for batched_item, row_item in zip(batch, row_batch): + assert batched_item.keys() == row_item.keys() + for key in batched_item: + assert np.array_equal(batched_item[key], row_item[key])