Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions lightautoml/text/nn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any
from typing import Callable
from typing import Dict
from typing import List
from typing import Optional
from typing import Union
import numpy as np
Expand Down Expand Up @@ -74,6 +75,18 @@ def __getitem__(self, index: int) -> Dict[str, np.ndarray]:

return res

def __getitems__(self, indices) -> Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]]:
    """Fetch a whole batch of rows in one call.

    When ``self.tokenizer`` is set, rows are produced one at a time through
    ``__getitem__`` and returned as a list. Otherwise each stored array is
    indexed once with the full index array and a single dict of batched
    arrays is returned.

    Args:
        indices: Positions of the rows to fetch.

    Returns:
        A list of per-row dicts when a tokenizer is set. Otherwise one dict
        holding ``"label"``, every key of ``self.data`` except ``"text"``,
        and ``"weight"`` when ``self.w`` is not ``None``.

    """
    if self.tokenizer is not None:
        return [self[index] for index in indices]

    indices = np.asarray(indices)
    if indices.size == 0:
        # np.asarray([]) defaults to float64, which NumPy refuses to use as
        # an index; cast only the empty case so it yields empty arrays.
        indices = indices.astype(np.intp)

    res = {"label": self.y[indices]}
    res.update({key: value[indices] for key, value in self.data.items() if key != "text"})
    if self.w is not None:
        res["weight"] = self.w[indices]

    return res


class Clump(nn.Module):
"""Clipping input tensor.
Expand Down
23 changes: 18 additions & 5 deletions lightautoml/text/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from typing import List
from typing import Dict
from typing import Sequence
from typing import Union


_dtypes_mapping = {
Expand Down Expand Up @@ -136,24 +137,36 @@ def parse_devices(dvs, is_dp: bool = False) -> tuple:
return device[0], ids if (len(device) > 1) and is_dp else None


def _cast_collated_tensor(tensor: torch.Tensor, dtype_name: str) -> torch.Tensor:
    """Convert a collated tensor to the dtype selected by ``dtype_name``.

    Args:
        tensor: Tensor to convert.
        dtype_name: ``"long"`` selects ``torch.long``; any other value
            selects ``torch.float32``.

    Returns:
        ``tensor`` itself when it already has the target dtype, otherwise
        a converted tensor.

    """
    target = torch.long if dtype_name == "long" else torch.float32
    return tensor if tensor.dtype == target else tensor.to(target)


def custom_collate(batch: List[np.ndarray], dtype_name: str = "float") -> torch.Tensor:
    """Puts each data field into a tensor with outer dimension batch size.

    Args:
        batch: Non-empty list of per-row values, either all tensors or
            values ``np.array`` can stack into one numeric array.
        dtype_name: ``"long"`` for an int64 result, anything else for
            float32.

    Returns:
        Tensor whose first dimension is ``len(batch)``.

    """
    from torch.utils.data import get_worker_info

    elem = batch[0]
    if not isinstance(elem, torch.Tensor):
        return _cast_collated_tensor(torch.from_numpy(np.array(batch)), dtype_name)

    out = None
    if get_worker_info() is not None:
        # Inside a DataLoader worker, stack straight into shared memory to
        # avoid an extra copy when the batch is handed to the main process.
        numel = sum(x.numel() for x in batch)
        storage = elem.storage()._new_shared(numel)
        # Give the buffer its final shape up front; letting torch.stack
        # resize a non-empty ``out`` tensor is deprecated.
        out = elem.new(storage).resize_(len(batch), *elem.size())
    return _cast_collated_tensor(torch.stack(batch, 0, out=out), dtype_name)


def collate_dict(batch: Union[Dict[str, np.ndarray], List[Dict[str, np.ndarray]]]) -> Dict[str, torch.Tensor]:
    """custom_collate for dicts.

    Accepts either one dict of already-batched values or a list of per-row
    dicts. In both cases the target dtype for each key is looked up in
    ``_dtypes_mapping`` and defaults to ``"float"`` for unknown keys.

    Args:
        batch: A dict mapping each key to a batched array, or a non-empty
            list of row dicts that all contain the keys of the first row.

    Returns:
        Dict mapping each key to a tensor with the batch as first dimension.

    """
    if isinstance(batch, dict):
        return {
            key: _cast_collated_tensor(torch.as_tensor(value), _dtypes_mapping.get(key, "float"))
            for key, value in batch.items()
        }

    # Gather values by key rather than by position, so a row whose keys are
    # ordered differently cannot have its values attached to the wrong key.
    return {
        key: custom_collate([row[key] for row in batch], _dtypes_mapping.get(key, "float"))
        for key in batch[0]
    }


def single_text_hash(x: str) -> str:
Expand Down
94 changes: 94 additions & 0 deletions tests/unit/test_text/test_universal_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import numpy as np
import torch

from lightautoml.text.nn_model import UniversalDataset
from lightautoml.text.utils import collate_dict


def test_universal_dataset_batched_getitems_matches_rowwise_batch():
    """Batched ``__getitems__`` must collate to the same tensors as row-by-row access."""
    features = {
        "cat": np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]], dtype=np.int64),
        "cont": np.array([[0.1, 0.2], [1.1, 1.2], [2.1, 2.2], [3.1, 3.2]], dtype=np.float32),
    }
    targets = np.array([0.0, 1.0, 0.0, 1.0], dtype=np.float32)
    weights = np.array([1.0, 0.5, 2.0, 1.5], dtype=np.float32)
    picked = [0, 2, 3]

    dataset = UniversalDataset(data=features, y=targets, w=weights, tokenizer=None)

    # Build the same batch two ways: one row at a time, and in a single call.
    per_row = [dataset[position] for position in picked]
    batched = dataset.__getitems__(picked)
    row_collated = collate_dict(per_row)
    fast_collated = collate_dict(batched)

    assert row_collated.keys() == fast_collated.keys()
    for name, expected in row_collated.items():
        actual = fast_collated[name]
        assert expected.shape == actual.shape
        assert torch.equal(expected, actual)

    expected_dtypes = {
        "cat": torch.int64,
        "cont": torch.float32,
        "label": torch.float32,
        "weight": torch.float32,
    }
    for name, dtype in expected_dtypes.items():
        assert fast_collated[name].dtype == dtype


def test_collate_dict_supports_rowwise_and_batched_dict_inputs():
    """``collate_dict`` must give equal tensors for a list of rows and for one batched dict."""
    first_row = {
        "cat": np.array([1, 2], dtype=np.int64),
        "cont": np.array([0.1, 0.2], dtype=np.float32),
        "label": np.array(0.0, dtype=np.float32),
        "weight": np.array(1.0, dtype=np.float32),
    }
    second_row = {
        "cat": np.array([3, 4], dtype=np.int64),
        "cont": np.array([1.1, 1.2], dtype=np.float32),
        "label": np.array(1.0, dtype=np.float32),
        "weight": np.array(0.5, dtype=np.float32),
    }
    # The same two rows, already stacked along a leading batch axis.
    batched_dict = {
        "cat": np.array([[1, 2], [3, 4]], dtype=np.int64),
        "cont": np.array([[0.1, 0.2], [1.1, 1.2]], dtype=np.float32),
        "label": np.array([0.0, 1.0], dtype=np.float32),
        "weight": np.array([1.0, 0.5], dtype=np.float32),
    }

    row_collated = collate_dict([first_row, second_row])
    batched_collated = collate_dict(batched_dict)

    assert row_collated.keys() == batched_collated.keys()
    for name, expected in row_collated.items():
        actual = batched_collated[name]
        assert expected.shape == actual.shape
        assert torch.equal(expected, actual)

    expected_dtypes = {
        "cat": torch.int64,
        "cont": torch.float32,
        "label": torch.float32,
        "weight": torch.float32,
    }
    for name, dtype in expected_dtypes.items():
        assert batched_collated[name].dtype == dtype


def test_universal_dataset_getitems_falls_back_to_rowwise_with_tokenizer():
    """With a tokenizer set, ``__getitems__`` must return the same per-row dicts as indexing."""

    class DummyTokenizer:
        """Tokenizer stub that returns one fixed three-token encoding."""

        def encode_plus(self, *args, **kwargs):
            encoding = {}
            encoding["input_ids"] = [1, 2, 3]
            encoding["attention_mask"] = [1, 1, 1]
            encoding["token_type_ids"] = [0, 0, 0]
            return encoding

    texts = np.array([["hello"], ["world"]])
    targets = np.array([0.0, 1.0], dtype=np.float32)
    dataset = UniversalDataset(data={"text": texts}, y=targets, tokenizer=DummyTokenizer())

    batch = dataset.__getitems__([0, 1])
    row_batch = [dataset[position] for position in (0, 1)]

    assert isinstance(batch, list)
    assert len(batch) == 2
    for batched_item, row_item in zip(batch, row_batch):
        assert batched_item.keys() == row_item.keys()
        for name, value in batched_item.items():
            assert np.array_equal(value, row_item[name])