From eb62ae7b1186d647d776c47e5025adf795e2a728 Mon Sep 17 00:00:00 2001 From: "chenxun.p" <759046501@qq.com> Date: Fri, 2 Aug 2024 13:57:39 +0800 Subject: [PATCH] add packed dataset support --- .../megatron/internlm/__init__.py | 13 + .../megatron/internlm/batch_sampler.py | 352 ++++++++++ .../megatron/internlm/collaters.py | 102 +++ .../megatron/internlm/dataset.py | 56 ++ .../megatron/internlm/dummy_dataset.py | 49 ++ .../internlm/dummy_dataset_multimodal.py | 61 ++ .../megatron/internlm/packed_dataset.py | 612 ++++++++++++++++++ .../megatron/internlm/single_dataset.py | 117 ++++ .../megatron/internlm/utils.py | 44 ++ .../language_modeling/megatron_gpt_model.py | 445 +++++++++++-- 10 files changed, 1789 insertions(+), 62 deletions(-) create mode 100644 nemo/collections/nlp/data/language_modeling/megatron/internlm/__init__.py create mode 100644 nemo/collections/nlp/data/language_modeling/megatron/internlm/batch_sampler.py create mode 100644 nemo/collections/nlp/data/language_modeling/megatron/internlm/collaters.py create mode 100644 nemo/collections/nlp/data/language_modeling/megatron/internlm/dataset.py create mode 100644 nemo/collections/nlp/data/language_modeling/megatron/internlm/dummy_dataset.py create mode 100644 nemo/collections/nlp/data/language_modeling/megatron/internlm/dummy_dataset_multimodal.py create mode 100644 nemo/collections/nlp/data/language_modeling/megatron/internlm/packed_dataset.py create mode 100644 nemo/collections/nlp/data/language_modeling/megatron/internlm/single_dataset.py create mode 100644 nemo/collections/nlp/data/language_modeling/megatron/internlm/utils.py diff --git a/nemo/collections/nlp/data/language_modeling/megatron/internlm/__init__.py b/nemo/collections/nlp/data/language_modeling/megatron/internlm/__init__.py new file mode 100644 index 000000000000..65b7eba12ea4 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/internlm/__init__.py @@ -0,0 +1,13 @@ +from .batch_sampler import get_dpsampler_dataloader +from .collaters import jsonl_ds_collate_fn, packed_collate_fn +from .dummy_dataset import RandomDataset +from .packed_dataset import PackedDatasetWithCut, PackedDatasetWithoutCuSeqlen + +__all__ = [ + "jsonl_ds_collate_fn", + "packed_collate_fn", + "RandomDataset", + "PackedDatasetWithCut", + "PackedDatasetWithoutCuSeqlen", + "get_dpsampler_dataloader", +] diff --git a/nemo/collections/nlp/data/language_modeling/megatron/internlm/batch_sampler.py b/nemo/collections/nlp/data/language_modeling/megatron/internlm/batch_sampler.py new file mode 100644 index 000000000000..779681c493e2 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/internlm/batch_sampler.py @@ -0,0 +1,352 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import math +import random +from typing import Iterator, TypeVar + +import numpy as np +import torch +from torch.utils.data import DataLoader, Dataset, Sampler + +from megatron.core import parallel_state +# from nemo.utils import logging as logger + +T_co = TypeVar("T_co", covariant=True) + + +class DataParallelSampler(Sampler): + """A data sampler for distributed data parallelism. + + Args: + dataset (:class:`torch.utils.data.Dataset`): The Dataset for sampling. + shuffle (bool, optional): Whether to shuffle data, defaults to False. + seed (int, optional): The random seed used for sampling, defaults to 0. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. 
If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + """ + + def __init__( + self, + dataset: Dataset, + shuffle: bool = False, + seed: int = 0, + drop_last: bool = False, + ) -> None: + self.dataset = dataset + self.num_replicas = parallel_state.get_data_parallel_world_size() + self.rank = parallel_state.get_data_parallel_rank() + self.epoch = 0 + self.drop_last = drop_last + # If the dataset length is evenly divisible by # of replicas, then there + # is no need to drop any data, since the dataset will be split equally. + # type: ignore[arg-type] + if self.drop_last and len(self.dataset) % self.num_replicas != 0: + # Split to nearest available length that is evenly divisible. + # This is to ensure each rank receives the same amount of data when + # using this Sampler. + self.num_samples = math.ceil( + # `type:ignore` is required because Dataset cannot provide a default __len__ + # see NOTE in pytorch/torch/utils/data/sampler.py + (len(self.dataset) - self.num_replicas) + / self.num_replicas # type: ignore[arg-type] + ) + else: + self.num_samples = math.ceil(len(self.dataset) / self.num_replicas) # type: ignore[arg-type] + self.total_size = self.num_samples * self.num_replicas + self.shuffle = shuffle + self.seed = seed + + def __iter__(self) -> Iterator[T_co]: + if self.shuffle: + # deterministically shuffle based on epoch and seed + g = torch.Generator() + g.manual_seed(self.seed + self.epoch) + # type: ignore[arg-type] + indices = torch.randperm(len(self.dataset), generator=g).tolist() + + # update for next epoch so that there is no need to call + # set_epoch manually + self.epoch += 1 + else: + indices = list(range(len(self.dataset))) # type: ignore[arg-type] + + if not self.drop_last: + # add extra samples to make it evenly divisible + padding_size = self.total_size - len(indices) + if padding_size <= len(indices): + indices += indices[:padding_size] + else: + indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size] + else: + # remove tail of data to make it evenly divisible. + indices = indices[: self.total_size] + assert len(indices) == self.total_size + + # subsample + indices = indices[self.rank : self.total_size : self.num_replicas] + assert len(indices) == self.num_samples + + return iter(indices) + + def __len__(self) -> int: + return self.num_samples + + def set_epoch(self, epoch: int) -> None: + r"""Sets the epoch for this sampler. When :attr:`shuffle=True`, this ensures all replicas + use a different random ordering for each epoch. Otherwise, the next iteration of this + sampler will yield the same ordering. + + Args: + epoch (int): Epoch number. + """ + self.epoch = epoch + + +class StaticBatchSampler: + """ + A static batch sampler that generates batches with a fixed micro-batch size. + + Args: + num_samples (int): The total number of samples in the dataset. + batch_size (int): The batch size for the current rank. Defaults to 192. + rampup_batch_size (str): A string with three space-separated integers representing the + starting batch size, the increment, and the number of steps between + each increment. For example, "192 24 8" means that the batch size + starts at 192 and increases by 24 every 8 steps. Defaults to + "6 2 8", which corresponds to a batch size of 2 for the first 6 steps. + micro_bsz (int): The micro-batch size. Defaults to 2. + seed (int): The random seed for shuffling the indices. Defaults to 0. + drop_last (bool): If True, drop the last incomplete batch. 
Currently only supports True. Defaults to True. + data_rank (int): The rank of the current process in the data parallel group. Defaults to 0. + data_world_size (int): The number of processes in the data parallel group. Defaults to 1. + """ + + def __init__( + self, + datasets, + batch_size=192, + rampup_batch_size="6 2 8", + micro_bsz=2, + seed=0, + drop_last=True, + data_rank=0, + data_world_size=1, + ): + assert drop_last is True, "Currently only support drop last" + if rampup_batch_size: + # In the process increase to batch_size + start_bsz, bsz_incre, incre_every = map(int, rampup_batch_size.split()) + else: + start_bsz, bsz_incre, incre_every = batch_size, batch_size, 1 + self.raw_rampup_batch_size = rampup_batch_size + self.start_bsz = start_bsz + self.bsz_incre = bsz_incre + self.incre_every = incre_every + if parallel_state._PIPELINE_MODEL_PARALLEL_GROUP is not None: + assert ( + batch_size - self.start_bsz + ) % self.bsz_incre == 0, f"{batch_size} - {self.start_bsz} should be multiple of {self.bsz_incre}" + assert batch_size % micro_bsz == 0, f"batch_size({batch_size}) should be multiple of micro_bsz({micro_bsz})" + assert ( + self.start_bsz % micro_bsz == 0 + ), f"start_bsz({self.start_bsz}) should be multiple of micro_bsz({micro_bsz})" + assert ( + self.bsz_incre % micro_bsz == 0 + ), f"bsz_incre({self.bsz_incre}) should be multiple of micro_bsz({micro_bsz})" + + self.batch_size = batch_size + self.epoch = 0 + self.seed = seed + self.rng = np.random.RandomState(seed) + self.batch_count = 0 + self.micro_bsz = micro_bsz + self.data_rank = data_rank + self.data_world_size = data_world_size + self.num_consumed_samples_in_epoch = 0 + self.datasets = datasets + self.num_samples = sum([len(ds) for ds in datasets]) + + self.get_indices() # get data + + def get_indices(self, old_indices=None): + if old_indices is not None: + assert ( + len(old_indices) <= self.num_samples + ), f"The checkpoint has {len(old_indices)} samples, \ +while the new restart use less samples ({self.num_samples})" + + else: + old_indices = np.array([]) + + # indices includes len(old_indices) but not self.num_samples + indices = np.arange(len(old_indices), self.num_samples) + self.rng_state = self.rng.get_state() + self.rng.shuffle(indices) + # Need to consider drop_last + ramp_steps = (self.batch_size - self.start_bsz) // self.bsz_incre + if self.batch_count < ramp_steps * self.incre_every: + rampup_samples = 0 + for i in range(ramp_steps): + rampup_samples += (i * self.bsz_incre + self.start_bsz) * self.incre_every + assert ( + rampup_samples * self.data_world_size <= self.num_samples + ), f"Too much rampup samples: \ +{rampup_samples*self.data_world_size} Vs. 
self.num_samples: {self.num_samples}" + + num_samples = (self.num_samples - rampup_samples * self.data_world_size) // ( + self.batch_size * self.data_world_size + ) + num_samples = num_samples * self.batch_size * self.data_world_size + rampup_samples * self.data_world_size + else: + num_samples = self.num_samples // (self.batch_size * self.data_world_size) + num_samples = num_samples * self.batch_size * self.data_world_size + indices = np.concatenate([old_indices, indices]).astype(int) # It needs to be spliced with the previous + indices = indices[:num_samples] + self.indices = indices + assert len(self.indices) >= self.batch_size, "The number of samples should be larger than batch_size" + self.num_consumed_samples_in_epoch = 0 + + def set_epoch(self, epoch): + self.epoch = epoch + self.rng = np.random.RandomState(self.seed + self.epoch) + + def __len__(self): + ramp_steps = (self.batch_size - self.start_bsz) // self.bsz_incre + if self.batch_count < ramp_steps * self.incre_every: + rampup_samples = 0 + for i in range(ramp_steps): + rampup_samples += (i * self.bsz_incre + self.start_bsz) * self.incre_every + assert ( + rampup_samples * self.data_world_size <= self.num_samples + ), f"Too much rampup samples: {rampup_samples*self.data_world_size} \ +Vs. self.num_samples: {self.num_samples}" + + num_batches = (self.num_samples - rampup_samples * self.data_world_size) // self.batch_size + num_batches = num_batches // self.data_world_size + self.incre_every * ramp_steps + else: + num_batches = self.num_samples // self.batch_size // self.data_world_size + + return num_batches + + def __iter__(self): + indices = self.indices[self.data_rank :: self.data_world_size] + while self.num_consumed_samples_in_epoch < len(indices): + batch_rampup_idx = self.batch_count // self.incre_every + cur_batch_size = batch_rampup_idx * self.bsz_incre + self.start_bsz + cur_batch_size = min(cur_batch_size, self.batch_size) + batch = indices[self.num_consumed_samples_in_epoch : self.num_consumed_samples_in_epoch + cur_batch_size] + self.num_consumed_samples_in_epoch += len(batch) # Consider multiple processes. 
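+            # batch_count drives the ramp-up schedule above: every incre_every batches the
+            # effective batch size grows by bsz_incre from start_bsz until it reaches batch_size.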
+ self.batch_count += 1 + yield batch + + self.get_indices() # get a new round + + def state_dict(self): + states = { + "batch_size": self.batch_size, + "raw_rampup_batch_size": self.raw_rampup_batch_size, + "rng_state": self.rng_state, + "epoch": self.epoch, + "seed": self.seed, + "data_world_size": self.data_world_size, + "num_consumed_samples_in_epoch": self.num_consumed_samples_in_epoch, + "batch_count": self.batch_count, # The batch_count here is due to the existence of multiple processes, + # the batch may be oversent, and it needs to be overwritten by the external batch_count + "indices": self.indices, # The sequence used to breakpoint retraining is the same as before + } + + return states + + def load_state_dict(self, states): + for name in ("data_world_size", "raw_rampup_batch_size", "seed"): # 'batch_size' + assert states[name] == getattr(self, name), (name, states[name], getattr(self, name)) # should not change + self.rng.set_state(states["rng_state"]) + self.get_indices(old_indices=None) # Regenerate indices based on random state + self.epoch = states["epoch"] + self.batch_count = states["batch_count"] + self.num_consumed_samples_in_epoch = states["num_consumed_samples_in_epoch"] + + def copy(self): + copy_sampler = StaticBatchSampler( + self.datasets, + self.batch_size, + self.raw_rampup_batch_size, + self.micro_bsz, + self.seed, + drop_last=True, + data_rank=self.data_rank, + data_world_size=self.data_world_size, + ) + + copy_sampler.load_state_dict(self.state_dict()) + return copy_sampler + + +def get_dpsampler_dataloader( + dataset, + shuffle=False, + seed=1024, + add_sampler=True, + drop_last=False, + pin_memory=False, + num_workers=0, + **kwargs, +): + r"""Set up a deterministic dataloader (also configure seed workers, samplers and whether shuffle or not) + + Note: + When pipeline parallel is enabled, shuffle cannot be True as it will result in mismatch between input data + on the 1st stage and label on the last stage. + + Args: + dataset (:class:`torch.utils.data.Dataset`): The dataset to be loaded. + shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False. + seed (int, optional): Random worker seed for sampling, defaults to 1024. + add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True. + drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size + is not divisible by the batch size. If False and the size of dataset is not divisible by + the batch size, then the last batch will be smaller, defaults to False. + pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False. + num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0. + kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in + `DataLoader `_. + + Returns: + :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing. 
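+
+    Example (illustrative sketch only; assumes megatron-core parallel state has already been initialized):
+
+        from nemo.collections.nlp.data.language_modeling.megatron.internlm.dummy_dataset import RandomDataset
+
+        ds = RandomDataset(num_samples=1000, max_len=1024)
+        loader = get_dpsampler_dataloader(ds, shuffle=False, num_workers=2, batch_size=4)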
+ """ + _kwargs = kwargs.copy() + + if add_sampler and parallel_state.get_data_parallel_world_size() > 1 and parallel_state._DATA_PARALLEL_GROUP is not None: + sampler = DataParallelSampler(dataset, shuffle=shuffle, drop_last=drop_last) + else: + sampler = None + + # Deterministic dataloader + def seed_worker(): + worker_seed = seed + np.random.seed(worker_seed) + torch.manual_seed(worker_seed) + random.seed(worker_seed) + + if sampler is None: + return DataLoader( + dataset, + worker_init_fn=seed_worker, + shuffle=shuffle, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) + else: + return DataLoader( + dataset, + sampler=sampler, + worker_init_fn=seed_worker, + drop_last=drop_last, + pin_memory=pin_memory, + num_workers=num_workers, + **_kwargs, + ) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/internlm/collaters.py b/nemo/collections/nlp/data/language_modeling/megatron/internlm/collaters.py new file mode 100644 index 000000000000..785ecc60fe7e --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/internlm/collaters.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + + +def packed_collate_fn(batch, packed_length): + + """ + Collate function for packed input sequences. + + Args: + batch (List[Dict]): List of dictionaries representing each sample in batch. + Each dictionary contains "tokens", "labels", "type_ids", "cu_seqlens", and "indexes" keys. + packed_length (int): The length of packed sequence. + + Returns: + Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids", + "cu_seqlens", "indexes", and "type_ids" keys, and the tensor of padded "labels". + + Raises: + AssertionError: If the length of a sample is not equal to packed_length. + AssertionError: If the shape of the padded "input_ids" tensor does not have the correct shape. 
+ """ + have_image = False + xs, ys, cu_seqlens, indexes, ts, images = [], [], [], [], [], [] + for b in batch: + assert ( + len(b["tokens"]) == packed_length + ), f"length of a sample should be equal to packed_length, but got {len(b['tokens'])} and {packed_length})" + assert ( + len(b["labels"]) == packed_length + ), f"length of a sample should be equal to packed_length, but got {len(b['labels'])} and {packed_length})" + assert ( + len(b["type_ids"]) == packed_length + ), f"length of a sample should be equal to packed_length, but got {len(b['type_ids'])} and {packed_length})" + + tokens = [abs(w) for w in b["tokens"]] + labels = [w if w > 0 else -100 for w in b["labels"]] + + if b.get("images", None) is not None: + have_image = True + cur_images = torch.stack(b["images"]) + images.append(cur_images) + + xs.append(torch.LongTensor(tokens)) + # The labels have been shifted here, so they are aligned with the output corresponding to the token + ys.append(torch.LongTensor(labels)) + ts.append(torch.LongTensor(b["type_ids"])) + cu_seqlens.append(torch.IntTensor(b["cu_seqlens"])) + indexes.append(torch.LongTensor(b["indexes"])) + + xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True) + ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100) + ts = torch.nn.utils.rnn.pad_sequence(ts, batch_first=True, padding_value=0) + indexes = torch.stack(indexes, dim=0) + if len(set(map(len, cu_seqlens))) == 1: # if has uniform length, then stack to save device transfer time + cu_seqlens = torch.stack(cu_seqlens, dim=0) + + assert xs.shape[1] == packed_length, (xs.shape[1], packed_length) + if have_image: + return {"input_ids": xs, "cu_seqlens": cu_seqlens, "indexes": indexes, "type_ids": ts, "images": images}, ys + else: + return {"input_ids": xs, "cu_seqlens": cu_seqlens, "indexes": indexes, "type_ids": ts}, ys + + +def jsonl_ds_collate_fn(batch, max_length_per_sample): + """ + Collate function for json dataset. + + Args: + batch (List[Dict]): List of dictionaries representing each sample in batch. + Each dictionary contains "tokens". + max_length_per_sample (int): The length of output sequence. + + Returns: + Tuple[Dict[str, torch.Tensor], torch.Tensor]: A tuple containing a dictionary of tensors with "input_ids", + and the tensor of padded "labels". 
+ + """ + xs, ys, images = [], [], [] + have_image = False + for x in batch: + x["tokens"] = x["tokens"][:max_length_per_sample] + tokens = [abs(w) for w in x["tokens"]] + labels = [w if w > 0 else -100 for w in x["tokens"]] + labels = labels[1:] + [-100] + if x.get("images", None) is not None: + have_image = True + cur_images = torch.stack(x["images"]) + images.append(cur_images) + xs.append(torch.as_tensor(tokens)) + ys.append(torch.as_tensor(labels)) # y has been shifted + xs = torch.nn.utils.rnn.pad_sequence(xs, batch_first=True) + ys = torch.nn.utils.rnn.pad_sequence(ys, batch_first=True, padding_value=-100) + + xs = torch.cat([xs, xs.new_zeros(len(xs), max_length_per_sample - len(xs[0]))], dim=-1) + ys = torch.cat([ys, ys.new_full((len(ys), max_length_per_sample - len(ys[0])), fill_value=-100)], dim=-1) + if have_image: + return {"input_ids": xs, "images": images}, ys + else: + return {"input_ids": xs}, ys diff --git a/nemo/collections/nlp/data/language_modeling/megatron/internlm/dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/internlm/dataset.py new file mode 100644 index 000000000000..6ce0168d2043 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/internlm/dataset.py @@ -0,0 +1,56 @@ +import os +from typing import Dict + +from torch.utils.data import ConcatDataset + +from nemo.collections.nlp.data.language_modeling.megatron.internlm.single_dataset import JsonlDataset + + +def get_dataset_dict(folder, split="valid") -> Dict: + """ + Return a dictionary of Datasets from a folder containing data files for validation. + + Args: + folder (str): The path to the folder containing data files. + split (str): The split of the data files to be used, default is "valid". + + Returns: + A dictionary containing Datasets for each folder in the given path + that contains data files with the specified split. + + Raises: + AssertionError: If the given folder does not exist. + + Example: + If the given folder is as follows, + - data + - zhihu + - xxx.bin + - valid.bin + - baike + - xxx.bin + - valid.bin + + The returned dictionary will be, + { + 'zhihu': Dataset, + 'baike': Dataset + } + """ + + assert os.path.exists(folder), f"folder `{folder}` not exists" + data_dict = {} + + for root, dirs, files in os.walk(folder, followlinks=True): + dirs.sort() # The order is guaranteed, and the newly added data starting with z needs to be ranked behind + datasets = [] + for fn in sorted(files): # Need sorted to ensure that the order is consistent + if fn.endswith(".bin") and split in fn: + fp = os.path.join(root, fn) + ds = JsonlDataset(fp) + datasets.append(ds) + if datasets: + ds = ConcatDataset(datasets=datasets) + data_dict[os.path.basename(root)] = ds + + return data_dict diff --git a/nemo/collections/nlp/data/language_modeling/megatron/internlm/dummy_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/internlm/dummy_dataset.py new file mode 100644 index 000000000000..ab5012f482aa --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/internlm/dummy_dataset.py @@ -0,0 +1,49 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import numpy as np +from torch.utils.data import Dataset + + +class RandomDataset(Dataset): + """ + RandomDataset for generating random dataset. + + Args: + num_samples (int): The number of samples to generate. + max_len (int): The maximum length of each sample. 
+ + """ + + def __init__(self, num_samples=10000, max_len=1024, fixed_seqlen: bool = False) -> None: + super().__init__() + rng = np.random.RandomState(1999) + max_num = rng.randint(1, 30, size=(num_samples,)) + rep_num = rng.randint(10, 200, size=(num_samples,)) + data = [] + lengths = [] + for n, r in zip(max_num, rep_num): + d = list(range(n)) * r + if fixed_seqlen: + while len(d) < max_len: + r *= 2 + d = list(range(n)) * r + + d = [n, r] + d + d = d[:max_len] + data.append(d) + lengths.append(len(d)) + self.data = data + self.max_len = max_len + self.lengths = np.array(lengths, dtype=int) + + def __getitem__(self, index): + d = self.data[index] + input_ids = np.array(d, dtype=int) + return {"tokens": list(input_ids), "type_id": 0} + + def get_dataset_name(self): + return "dummy_path/dummy_lang/dummy_ds/train.bin" + + def __len__(self): + return len(self.data) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/internlm/dummy_dataset_multimodal.py b/nemo/collections/nlp/data/language_modeling/megatron/internlm/dummy_dataset_multimodal.py new file mode 100644 index 000000000000..28d2b38537d7 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/internlm/dummy_dataset_multimodal.py @@ -0,0 +1,61 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import numpy as np +import torch +from torch.utils.data import Dataset + + +class RandomDatasetMultimodal(Dataset): + """ + RandomDataset for generating random dataset. + + Args: + num_samples (int): The number of samples to generate. + max_len (int): The maximum length of each sample. + image_token_id (int): The placeholder of image. + image_size (int): The image size. + image_patch_size (int): The patch size of vit. + image_token_size (int): The number of placeholder of each image. 
+ + """ + + def __init__( + self, + num_samples=10000, + max_len=2048, + image_token_id=200000, + image_size=336, + image_token_size=(336 // 14) ** 2, + ) -> None: + super().__init__() + rng = np.random.RandomState(1999) + max_num = rng.randint(1, 30, size=(num_samples,)) + rep_num = rng.randint(10, 200, size=(num_samples,)) + data = [] + lengths = [] + images = [ + [torch.randn((3, image_size, image_size))] for _ in range(num_samples) + ] # num_samples x img_num x tensor(C x H x W) + for n, r in zip(max_num, rep_num): + d = list(range(n)) * r + d = [n, r] + [image_token_id] * image_token_size + d + d = d[:max_len] + data.append(d) + lengths.append(len(d)) + self.data = data + self.images = images + self.max_len = max_len + self.lengths = np.array(lengths, dtype=int) + + def __getitem__(self, index): + d = self.data[index] + input_ids = np.array(d, dtype=int) + images = self.images[index] + return {"tokens": list(input_ids), "images": images, "type_id": 0} + + def get_dataset_name(self): + return "dummy_path/dummy_lang/dummy_ds/train.bin" + + def __len__(self): + return len(self.data) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/internlm/packed_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/internlm/packed_dataset.py new file mode 100644 index 000000000000..91304c716264 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/internlm/packed_dataset.py @@ -0,0 +1,612 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import itertools as it +import operator +import os +from copy import deepcopy +from typing import Dict + +import numpy as np +import torch +import torch.distributed as dist +from megatron.core import parallel_state +from torch.utils.data import ConcatDataset, Dataset +from tqdm import tqdm + +from nemo.collections.nlp.data.language_modeling.megatron.internlm.single_dataset import JsonlDataset +from nemo.collections.nlp.data.language_modeling.megatron.internlm.utils import get_dataset_type_id +from nemo.utils import logging as logger + +# from internlm.core.context import global_context as gpc +# from internlm.data.tokenized.single_dataset import JsonlDataset +# from internlm.data.utils import get_dataset_type_id, get_dataset_type_ids_map +# from internlm.utils.logger import get_logger + + +DEFAULT_SEED = 1024 + + +class PackedDataset(Dataset): + """ + The class PackedDataset takes in a dataset and aggregates samples of different + lengths together based on the packed_length. + + Args: + dataset: The original dataset to pack. + max_length_per_sample: The maximum length of each original sample. Default is 2048. + packed_length: The length of each packed sample. Default is 4096. 
+ """ + + def __init__( + self, + dataset, + max_length_per_sample: int = 2048, + packed_length: int = 4096, + use_packed_dataset: bool = True, + ): + assert hasattr(dataset, "lengths") + assert len(getattr(dataset, "lengths")) == len( + dataset + ), "The dataset must have lengths attribute and have the same length as the dataset" + self.dataset = dataset + self.max_length_per_sample = max_length_per_sample + self.lengths = getattr(self.dataset, "lengths") + self.packed_length = packed_length + self.use_packed_dataset = use_packed_dataset + # Force a seed to be fixed to prevent problems caused by the seed not being restored when restarting + + self.seed = DEFAULT_SEED + + def __getitem__(self, item: int) -> Dict: + """Given the index, it returns a dict as + { + 'tokens': List[int], + 'cu_seqlens': List[int], + 'indexes': List[int], # denotes positional vector as 'tokens' + 'labels': List[int], # corresponds to 'tokens' and shifted yet, -100 means skipping prediction + } + """ + + if self.use_packed_dataset: + return self.build_pack(item) + + return self.build_unpack(item) + + +class PackedDatasetWithoutCuSeqlen(torch.utils.data.Dataset): + """ + A dataset wrapper that aggregates samples with different lengths based on packed_length. + If a sample is shorter than max_length_per_sample, it will be merged with other samples. + For example, given a dataset with 10 samples: + [1, 2, 3, 4, 5] + [6, 7] + [8, 9, 10, 11] + [12, ..., 100] + ... + + Args: + dataset: The original dataset to be wrapped. + max_length_per_sample (int): The maximum length allowed for each sample. + packed_length (int): The desired length for each packed sample. + """ + + def __init__( + self, + dataset, + max_length_per_sample: int = 2048, + packed_length: int = 4096, + debug=False, + ): + assert packed_length % max_length_per_sample == 0 + assert hasattr(dataset, "lengths") + assert len(getattr(dataset, "lengths")) == len( + dataset + ), "The dataset must have lengths attribute and have the same length as the dataset" + self.dataset = dataset + self.max_length_per_sample = max_length_per_sample + self.lengths = getattr(self.dataset, "lengths") + self.bsz = packed_length // max_length_per_sample + self.packed_length = packed_length + self.debug = debug + # Force a seed to be fixed to prevent problems caused by the seed not being restored when restarting + + self.seed = DEFAULT_SEED + indices = np.arange(len(self.lengths)) + rng = np.random.RandomState(self.seed) + rng.shuffle(indices) + self.indices = indices + self.cum_lens = np.cumsum(self.lengths[self.indices]) + self.num_tokens = sum(self.lengths) + + def get_dataset_name(self): + return self.dataset.get_dataset_name() + + def __len__(self): + n_packs = self.num_tokens // self.packed_length + return n_packs + + def find_offset(self, offset): + idx = np.searchsorted(self.cum_lens, offset, side="right") + if idx == 0: + return idx, offset + length = offset - self.cum_lens[idx - 1] + return idx, length + + def pdebug(self, line): + if self.debug: + print(line, flush=True) + + def __getitem__(self, item: int) -> Dict: + """Given the index, it returns a dict as + { + 'tokens': List[int], + 'cu_seqlens': List[int], + 'indexes': List[int], # denotes positional vector as 'tokens' + 'labels': List[int], # corresponds to 'tokens' and shifted yet, -100 means skipping prediction + } + """ + + start_idx, start_length = self.find_offset(item * self.packed_length) + end_idx, end_length = self.find_offset((item + 1) * self.packed_length) + pack_tokens = [] + pack_labels = [] + 
type_ids = [] + + self.pdebug(f"item : {item}, start_idx:{start_idx}, start_length:{start_length} ") + self.pdebug(f"item : {item}, end_idx:{end_idx}, end_length:{end_length} ") + + if start_idx == end_idx: + idx = self.indices[start_idx] + sample = self.dataset[idx] + self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}") + tokens = sample["tokens"][start_length:end_length] + pack_tokens.extend(tokens) + pack_labels.extend(tokens[1:] + [-100]) + type_ids.extend([sample["type_id"]] * len(tokens)) + return { + "tokens": pack_tokens, + "cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)], + "indexes": list(range(self.max_length_per_sample)) * self.bsz, + "labels": pack_labels, + "type_ids": type_ids, + } + + idx = self.indices[start_idx] + sample = self.dataset[idx] + self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}") + tokens = sample["tokens"][start_length:] + pack_tokens.extend(tokens) + pack_labels.extend(tokens[1:] + [-100]) + type_ids.extend([sample["type_id"]] * len(tokens)) + + for i in range(start_idx + 1, end_idx): + idx = self.indices[i] + sample = self.dataset[idx] + self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}") + tokens = sample["tokens"] + pack_tokens.extend(tokens) + pack_labels.extend(tokens[1:] + [-100]) + type_ids.extend([sample.get("type_id")] * len(tokens)) + + # corner case, the last sample is useless + if end_length == 0: + pass + else: + idx = self.indices[end_idx] + sample = self.dataset[idx] + self.pdebug(f"item : {item}, idx: {idx}, len : {len(sample['tokens'])}") + tokens = sample["tokens"][:end_length] + pack_tokens.extend(tokens) + pack_labels.extend(tokens[1:] + [-100]) + type_ids.extend([sample.get("type_id")] * len(tokens)) + + return { + "tokens": pack_tokens, + "cu_seqlens": [i * self.max_length_per_sample for i in range(self.bsz + 1)], + "indexes": list(range(self.max_length_per_sample)) * self.bsz, + "labels": pack_labels, + "type_ids": type_ids, + } + + +class PackedDatasetWithCut(PackedDataset): + """ + The class PackedDataset takes in a dataset and aggregates samples of different + lengths together based on the packed_length using cut mode. + + + max_length_per_sample = 3 + packed_length = 5 + [1, 2] + [3, 4] + [5, 6, 7] + [8, 9, 10, 11, 12, 13] + + ---> + [1, 2, 3, 4, 5] + [6, 7, 8, 9, 10] + [11, 12, 13, 0, 0] + + Args: + dataset: The original dataset to pack. + max_length_per_sample: The maximum length of each original sample. Default is 2048. + packed_length: The length of each packed sample. Default is 4096. 
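+        use_packed_dataset: If True, samples are packed via build_pack; otherwise build_unpack
+            pads each micro batch of raw samples to packed_length. Default is True.
+        micro_bsz: The number of raw samples consumed per index in unpacked mode. Default is 1.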
+ """ + + def __init__( + self, + dataset, + max_length_per_sample: int = 2048, + packed_length: int = 4096, + use_packed_dataset: bool = True, + micro_bsz: int = 1, + ): + super().__init__(dataset, max_length_per_sample, packed_length, use_packed_dataset) + self.sample_indices, self.len_samples_shuffled, self.acm_len_samples = self.accu_sample_len(seed=self.seed) + self.num_tokens = sum(self.lengths) + self.micro_bsz = micro_bsz + + def get_dataset_name(self): + return self.dataset.get_dataset_name() + + def accu_sample_len(self, seed=None): + """accumulative length of samples""" + if seed is not None: + rng = np.random.RandomState(seed) + else: + rng = np.random.RandomState(self.seed - 1) + + sample_indices = np.arange(len(self.lengths)) + rng.shuffle(sample_indices) + len_samples_shuffled = list(map(self.lengths.__getitem__, sample_indices)) + acm_len_samples = list(it.accumulate(len_samples_shuffled, operator.add)) + return sample_indices, len_samples_shuffled, acm_len_samples + + def __len__(self): + # Line 405 of document_to_sequence.py in metaseq is directly spliced, + # without additional consideration of sos or eos + n_packs = self.num_tokens // self.packed_length + return n_packs + + def cal_map(self, carriage_idx: int = 0): + assert carriage_idx >= 0 + length_train = (carriage_idx + 1) * self.packed_length + post_pos = np.searchsorted(self.acm_len_samples, length_train, side="left") + return post_pos + + def mapping(self, pack_idx: int = 0): + # pack_idx is zero-based + pre_pos, pre_token_id = 0, 0 + if pack_idx > 0: + pre_pos = self.cal_map(pack_idx - 1) + pre_token_id = self.len_samples_shuffled[pre_pos] - ( + self.acm_len_samples[pre_pos] - (pack_idx) * self.packed_length + ) + if pre_token_id == self.len_samples_shuffled[pre_pos]: + pre_pos += 1 + pre_token_id = 0 + + pos = self.cal_map(pack_idx) + token_id = self.len_samples_shuffled[pos] - (self.acm_len_samples[pos] - (pack_idx + 1) * self.packed_length) + return pre_pos, pre_token_id, pos, token_id + + def build_pack(self, item): + pre_pos, pre_token_id, pos, token_id = self.mapping(item) + pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], [] + + while pre_pos < pos: + sample_idx = self.sample_indices[pre_pos] + sample = self.dataset[sample_idx] + chunk = sample["tokens"][pre_token_id:] + pack.extend(chunk) + _labels = deepcopy(chunk) + _labels = list(_labels[1:]) + [-100] + assert len(_labels) == len(chunk), (_labels, chunk) + labels.extend(_labels) + type_ids.extend([sample.get("type_id", 0)] * len(chunk)) + num_new_samples, tokens_left = divmod(len(chunk), self.max_length_per_sample) + for _ in range(num_new_samples): + cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample) + indexes.extend(list(range(self.max_length_per_sample))) + if tokens_left > 0: + cu_seqlens.append(cu_seqlens[-1] + tokens_left) + indexes.extend(list(range(tokens_left))) + pre_pos = pre_pos + 1 + pre_token_id = 0 + + sample_idx = self.sample_indices[pos] + sample = self.dataset[sample_idx] + chunk = sample["tokens"][pre_token_id:token_id] # fragement of a sample + pack.extend(chunk) + _labels = deepcopy(chunk) + if token_id == len(sample["tokens"]): + _labels = list(_labels[1:]) + [-100] + else: + if token_id > len(sample["tokens"]): + print(f"token_id {token_id}, len of sample {len(sample['tokens'])}") + _labels = list(_labels[1:]) + [sample["tokens"][token_id]] + assert len(_labels) == len(chunk), (_labels, chunk) + labels.extend(_labels) + type_ids.extend([sample.get("type_id", 0)] * len(chunk)) + num_new_samples, 
tokens_left = divmod(len(chunk), self.max_length_per_sample) + for _ in range(num_new_samples): + cu_seqlens.append(cu_seqlens[-1] + self.max_length_per_sample) + indexes.extend(list(range(self.max_length_per_sample))) + if tokens_left > 0: + cu_seqlens.append(cu_seqlens[-1] + tokens_left) + indexes.extend(list(range(tokens_left))) + + out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids} + return out + + def cal_pos_unpack(self, index): + if index == 0: + pre_pos = 0 + else: + pre_pos = index * self.micro_bsz + + pos = (index + 1) * self.micro_bsz + return pre_pos, pos + + def build_unpack(self, index): + """ + max_length_per_sample = 3 + packed_length = 6 + micro_bsz = 2 + [1, 2] + [3, 4] + [5, 6, 7] + [8, 9, 10, 11, 12, 13] + [14, 15, 16, 17] + + ---> + [1, 2, 3, 4, 0, 0] + [5, 6, 7, 8, 9, 10] + [14, 15, 16, 0, 0, 0] + + """ + + pre_pos, pos = self.cal_pos_unpack(index) + + pack, cu_seqlens, indexes, labels, type_ids = [], [0], [], [], [] + + while pre_pos < pos and pre_pos < len(self.dataset): + sample_idx = self.sample_indices[pre_pos] + sample = self.dataset[sample_idx] + length = min(len(sample["tokens"]), self.max_length_per_sample) + chunk = sample["tokens"][0:length] + pack.extend(chunk) + _labels = deepcopy(chunk) + _labels = list(_labels[1:]) + [-100] + assert len(_labels) == len(chunk), (_labels, chunk) + labels.extend(_labels) + type_ids.extend([sample.get("type_id", 0)] * len(chunk)) + cu_seqlens.append(cu_seqlens[-1] + len(chunk)) + indexes.extend(list(range(length))) + pre_pos = pre_pos + 1 + + if cu_seqlens[-1] != self.packed_length: + pack = pack + [0] * (self.packed_length - cu_seqlens[-1]) + labels = labels + [-100] * (self.packed_length - cu_seqlens[-1]) + type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1]) + indexes.extend(list(range(self.packed_length - cu_seqlens[-1]))) + cu_seqlens.append(self.packed_length) + + assert len(pack) == self.packed_length + + out = {"tokens": pack, "cu_seqlens": cu_seqlens, "indexes": indexes, "labels": labels, "type_ids": type_ids} + return out + + +def get_packed_dataset_without_short_length( + folder, + max_length_per_sample=2048, + packed_length=4096, + show_progress=False, + min_length=50, + min_length_dict=None, + pack_sample_into_one=False, +): + """ + Given a folder, combine all the .bin files into a single large dataset. + And filter out short samples with length less than 'min_length'. + + Each .bin file is treated as a separate dataset. + + Args: + folder (str): Path to the folder containing the .bin files. + max_length_per_sample (int): Maximum length of each sample. + packed_length (int): Length to pack samples to. + show_progress (bool): Whether to show the progress bar. + min_length (int): The minimum length of the sample. + min_length_dict (dict): The minimum length of the sample for each dataset. + The format is something like {'pile-arxiv': 50} + dataset_backend (Optional[str]): Dataset storage location. Optional parameters are local, local-shm, kv + + Returns: + A packed dataset containing all the data from the .bin files. + """ + + data_parallel_rank = parallel_state.get_data_parallel_rank() + + assert os.path.exists(folder), f"{folder} does not exist." 
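+    # Rank 0 walks the folder and broadcasts the listing below, so every rank sees the
+    # .bin files in the same order and builds identical dataset lists.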
+ datasets = [] + delete_samples = 0 + + if dist.get_rank() == 0: + triples = [list(os.walk(folder, followlinks=True))] + else: + triples = [None] + dist.broadcast_object_list(triples, src=0) + triples = triples[0] + + for root, dirs, files in triples: + dirs.sort() # Let the folder need to be returned in a fixed order + if data_parallel_rank == 0: + logger.info(f"Reading {root}...") + num_token_in_folder = 0 + + for fn in tqdm(sorted(files), total=len(files), leave=False, disable=not show_progress): + if fn.endswith(".bin"): + fp = os.path.join(root, fn) + catch_ml_keys = [] + min_length_num = min_length + if min_length_dict is not None: + for k, v in min_length_dict.items(): + if k in fp: + min_length_num = v + catch_ml_keys.append(k) + assert ( + len(catch_ml_keys) < 2 + ), f"The file name `{fp}` matched the following resample keys:{catch_ml_keys}" + + ds_type_id = get_dataset_type_id(path=fp) + ds = JsonlDataset(fp, ds_type_id, min_length=min_length_num) + + if hasattr(ds, "old_length"): + delete_samples += ds.old_length - len(ds) + if len(ds) == 0: + if data_parallel_rank == 0: + logger.info(f"None of the data in `{fp}` is longer than {min_length}") + continue + + if pack_sample_into_one: + ds = PackedDatasetWithoutCuSeqlen(ds, max_length_per_sample, packed_length) + else: + ds = PackedDatasetWithCut(ds, max_length_per_sample, packed_length) + + num_token_in_folder += len(ds) * packed_length + datasets.append(ds) + + dataset = ConcatDataset(datasets=datasets) + if data_parallel_rank == 0: + logger.info( + f"Find `{len(datasets)}` datasets, \ + {len(dataset)} samples, \ + delete `{delete_samples}` because of short length", + ) + + return dataset + + +class PackedDatasetWithPadForMultimodal(PackedDataset): + """ + The class PackedDataset takes in a dataset and aggregates samples of different + lengths together based on the packed_length using pad mode. + + packed_length = 5 + max_length_per_sample = 3 + + [1, 2] + [6, 7] + [3, 4, 5] + [8, 9, 10, 11, 12, 13] + + ---> + [1, 2, 6, 7, 0] + [3 ,4, 5, 0, 0] + [8, 9, 10, 0, 0] + + Args: + dataset: The original dataset to pack. + max_length_per_sample: The maximum length of each original sample. Default is 2048. + packed_length: The length of each packed sample. Default is 4096. + padding_idx: The token id of padding. Default is 0. 
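+        image_token_id: The token id of the image placeholder; label positions holding this id
+            are replaced with -100 so they are excluded from the loss. Default is 200000.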
+ """ + + def __init__( + self, + dataset, + max_length_per_sample: int = 2048, + packed_length: int = 4096, + padding_idx: int = 0, + image_token_id: int = 200000, + ): + super().__init__(dataset, max_length_per_sample, packed_length) + self.padding_idx = padding_idx + self.sample_indices, self.belongs = self.accu_sample_len(self.seed) + self.num_tokens = sum(self.lengths) + self.image_token_id = image_token_id + + def get_dataset_name(self): + return self.dataset.get_dataset_name() + + def accu_sample_len(self, seed=None): + """accumulative length of samples""" + if seed is not None: + rng = np.random.RandomState(seed) + else: + rng = np.random.RandomState(self.seed - 1) + + sample_indices = np.arange(len(self.lengths)) + rng.shuffle(sample_indices) + len_samples_shuffled = list(map(self.lengths.__getitem__, sample_indices)) + belongs = np.zeros(len(self.lengths), dtype=np.int32) + cur_num = 0 + cur_tot_len = 0 + last_pos = 0 + for idx, cur_sample_len in enumerate(len_samples_shuffled): + cur_sample_len = min(cur_sample_len, self.max_length_per_sample) + if cur_tot_len + cur_sample_len > self.packed_length: + cur_tot_len = 0 + belongs[last_pos:idx] = cur_num + cur_num += 1 + last_pos = idx + cur_tot_len += cur_sample_len + if cur_tot_len != 0: + belongs[last_pos:] = cur_num + cur_tot_len = 0 + cur_num += 1 + return sample_indices, belongs + + def __len__(self): + return self.belongs[-1] + + def build_pack(self, index): + + pack, cu_seqlens, indexes, labels, type_ids, images = [], [0], [], [], [], [] + + start_pos = np.searchsorted(self.belongs, index, "left") + end_pos = np.searchsorted(self.belongs, index, "right") + assert self.belongs[end_pos - 1] == self.belongs[start_pos] and ( + end_pos >= len(self.belongs) or self.belongs[end_pos] == self.belongs[start_pos] + 1 + ) + cur_samples = self.sample_indices[start_pos:end_pos] + + for sample_idx in cur_samples: + sample = self.dataset[sample_idx] + length = min(len(sample["tokens"]), self.max_length_per_sample) + cur_images = sample["images"] + images.extend(cur_images) + chunk = sample["tokens"][:length] + pack.extend(chunk) + cu_seqlens.append(cu_seqlens[-1] + len(chunk)) + _labels = deepcopy(chunk) + _labels = list(_labels[1:]) + [-100] + for i in range(len(_labels)): + if _labels[i] == self.image_token_id: + _labels[i] = -100 + labels.extend(_labels) + type_ids.extend([sample.get("type_id", 0)] * len(chunk)) + indexes.extend(list(range(length))) + + if cu_seqlens[-1] != self.packed_length: + pack = pack + [self.padding_idx] * (self.packed_length - cu_seqlens[-1]) + labels = labels + [-100] * (self.packed_length - cu_seqlens[-1]) + type_ids = type_ids + [0] * (self.packed_length - cu_seqlens[-1]) + indexes.extend([0] * (self.packed_length - cu_seqlens[-1])) + cu_seqlens.append(self.packed_length) + + out = { + "tokens": pack, + "images": images, + "cu_seqlens": cu_seqlens, + "indexes": indexes, + "labels": labels, + "type_ids": type_ids, + } + return out + + def build_unpack(self, index): + raise NotImplementedError diff --git a/nemo/collections/nlp/data/language_modeling/megatron/internlm/single_dataset.py b/nemo/collections/nlp/data/language_modeling/megatron/internlm/single_dataset.py new file mode 100644 index 000000000000..5477d34ce268 --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/internlm/single_dataset.py @@ -0,0 +1,117 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +""" +A .bin file corresponds to a Dataset instance here. 
+""" + +import json +import mmap +import os +import threading +from pathlib import Path + +import numpy as np +import torch + + +class JsonlDataset(torch.utils.data.Dataset): + """ + + JSONL format is expected to roughly follow that of The Pile. + One-line-per-document of the form: + ``` + { + "tokens": List[int], + } + ``` + + Note that only the "tokens" key is used. + """ + + def __init__(self, path: str, dataset_type_id: int = 0, min_length=50): + self.path = path + self.threadlocal = threading.local() + resolved_path = Path(path).resolve() + self.resolved_path = resolved_path + self.meta = Path(f"{resolved_path}.meta") + self.type_id = dataset_type_id + + # only build the cache in on the primary worker to prevent overloading nfs + assert os.path.exists(self.meta), f"The cache file:{self.meta} is not found for file:{self.path}" + try: + with open(self.meta, "rb") as f: + meta = np.load(f) + except Exception as e: + print(f"Cannot load file {self.meta}...") + raise e + self.offsets = meta[:, 0] + self.lengths = meta[:, -1] + + if min_length > 0: + mask = self.lengths >= min_length + self.old_lengths = self.lengths.copy() + self.old_length = len(self.offsets) + self.offsets = self.offsets[mask] + self.lengths = self.lengths[mask] + + def __getitem__(self, idx): + f = self._get_mmap() + position = self.offsets[idx] + f.seek(position) + item = f.readline().decode("utf-8") + try: + item = json.loads(item) + item["length"] = len(item["tokens"]) # add a length info + item["type_id"] = self.type_id + except Exception as err: + raise json.decoder.JSONDecodeError( + doc=self.path, + pos=position, + msg=( + f"Error while loading JSONL line in file {self.path} at byte " + f"{position}. Contents of line:\n{item}\n{err}" + ), + ) + return item + + def get_dataset_name(self): + return str(self.resolved_path) + + def _get_mmap(self): + if not hasattr(self.threadlocal, "handles"): + with open(self.path, "rb") as f: + mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) + self.threadlocal.handles = [f, mm] + if self.path.endswith(".gz") or self.path.endswith(".bz") or self.path.endswith(".bz2"): + raise NotImplementedError( + "Compressed files are not supported because .seek() would require " + "rereading the entire file, making performance too slow." 
+ ) + return self.threadlocal.handles[-1] + + def __setstate__(self, state): + self.__dict__ = state + self.threadlocal = threading.local() + + def __getstate__(self): + d = {} + for i, v in self.__dict__.items(): + if i != "threadlocal": + d[i] = v + return d + + def __del__(self): + if hasattr(self.threadlocal, "handles"): + # cleanup files we opened on initialization + while self.threadlocal.handles: + self.threadlocal.handles.pop().close() + + @staticmethod + def exists(path): + return os.path.exists(path) + + def __len__(self): + # Virtual length of the dataset depends on the epoch number if the number of documents + # is not perfectly divisible by the data_subshard_count + return len(self.offsets) diff --git a/nemo/collections/nlp/data/language_modeling/megatron/internlm/utils.py b/nemo/collections/nlp/data/language_modeling/megatron/internlm/utils.py new file mode 100644 index 000000000000..ec01502141be --- /dev/null +++ b/nemo/collections/nlp/data/language_modeling/megatron/internlm/utils.py @@ -0,0 +1,44 @@ +#!/usr/bin/env python +# -*- encoding: utf-8 -*- + +import torch + +DATASET_TYPE_IDS_MAP = {"en": 0, "cn": 1, "code": 2} + + +def get_dataset_type_id(path): + import re + + match_idxes = [] + for key, idx in DATASET_TYPE_IDS_MAP.items(): + if re.search(rf"/[z_]*{key}/", path): + match_idxes.append(idx) + assert len(match_idxes) == 1, f"{path}, match_idxes should be 1, but got {match_idxes} from {DATASET_TYPE_IDS_MAP}" + return match_idxes[0] + + +def unpack_data(input_ids, cu_seqlens): + """ + input_ids: (n, packed_length) + Return: + output: (batch_size, max_length) + """ + + bsz = input_ids.shape[0] + + num_sequence = 1 # gpc.config.data["micro_bsz"] + + outputs = torch.zeros(bsz, num_sequence, 4096, device=input_ids.device, dtype=input_ids.dtype) + + for i in range(bsz): + output = torch.zeros(num_sequence, 4096, device=input_ids.device, dtype=input_ids.dtype) + cu_seqlens_slice = cu_seqlens[i] + for j in range(num_sequence): + seq_length = cu_seqlens_slice[j + 1] - cu_seqlens_slice[j] + output[j, 0:seq_length] = input_ids[0, cu_seqlens_slice[j] : cu_seqlens_slice[j + 1]] + outputs[i] = output + + if bsz == 1: + outputs = outputs.squeeze(0) + + return outputs \ No newline at end of file diff --git a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py index 1cb2239388ef..1b1857667488 100644 --- a/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py +++ b/nemo/collections/nlp/models/language_modeling/megatron_gpt_model.py @@ -22,32 +22,46 @@ from importlib.metadata import version from typing import Any, Dict, Iterator, List, Optional, Union +import numpy as np +import scipy import torch +from omegaconf import OmegaConf +from omegaconf.dictconfig import DictConfig +from pkg_resources import packaging +from pytorch_lightning.accelerators import CPUAccelerator +from pytorch_lightning.loops.fetchers import _DataFetcherWrapper +from pytorch_lightning.trainer.trainer import Trainer +from torch import distributed as dist + from nemo.collections.common.parts.utils import extend_instance from nemo.collections.nlp.data.language_modeling.megatron.data_samplers import ( MegatronPretrainingRandomSampler, MegatronPretrainingSampler, ) -from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import ( - build_train_valid_test_datasets, +from nemo.collections.nlp.data.language_modeling.megatron.gpt_dataset import build_train_valid_test_datasets +from 
nemo.collections.nlp.data.language_modeling.megatron.gpt_fim_dataset import GPTFIMDataset, GPTFIMDatasetConfig +from nemo.collections.nlp.data.language_modeling.megatron.internlm.batch_sampler import ( + StaticBatchSampler, + get_dpsampler_dataloader, ) -from nemo.collections.nlp.data.language_modeling.megatron.gpt_fim_dataset import ( - GPTFIMDataset, - GPTFIMDatasetConfig, +from nemo.collections.nlp.data.language_modeling.megatron.internlm.collaters import ( + jsonl_ds_collate_fn, + packed_collate_fn, ) -from nemo.collections.nlp.models.language_modeling.megatron.falcon.falcon_spec import ( - get_falcon_layer_spec, +from nemo.collections.nlp.data.language_modeling.megatron.internlm.dataset import get_dataset_dict +from nemo.collections.nlp.data.language_modeling.megatron.internlm.dummy_dataset import RandomDataset +from nemo.collections.nlp.data.language_modeling.megatron.internlm.packed_dataset import ( + PackedDatasetWithCut, + PackedDatasetWithoutCuSeqlen, + get_packed_dataset_without_short_length, ) +from nemo.collections.nlp.models.language_modeling.megatron.falcon.falcon_spec import get_falcon_layer_spec from nemo.collections.nlp.models.language_modeling.megatron.gpt_full_te_layer_autocast_spec import ( get_gpt_full_te_layer_autocast_spec, ) -from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import ( - get_gpt_layer_modelopt_spec, -) +from nemo.collections.nlp.models.language_modeling.megatron.gpt_layer_modelopt_spec import get_gpt_layer_modelopt_spec from nemo.collections.nlp.models.language_modeling.megatron.gpt_model import GPTModel -from nemo.collections.nlp.models.language_modeling.megatron_base_model import ( - MegatronBaseModel, -) +from nemo.collections.nlp.models.language_modeling.megatron_base_model import MegatronBaseModel from nemo.collections.nlp.modules.common.megatron.build_model import build_model from nemo.collections.nlp.modules.common.megatron.module import Float16Module from nemo.collections.nlp.modules.common.megatron.utils import ( @@ -57,9 +71,7 @@ get_ltor_masks_and_position_ids, get_params_for_weight_decay_optimization, ) -from nemo.collections.nlp.modules.common.text_generation_strategy import ( - TextGenerationStrategy, -) +from nemo.collections.nlp.modules.common.text_generation_strategy import TextGenerationStrategy from nemo.collections.nlp.modules.common.text_generation_utils import ( generate, get_computeprob_response, @@ -80,12 +92,6 @@ from nemo.core.neural_types import ChannelType, NeuralType from nemo.utils import logging from nemo.utils.te_utils import is_float8tensor -from omegaconf import OmegaConf -from omegaconf.dictconfig import DictConfig -from pkg_resources import packaging -from pytorch_lightning.accelerators import CPUAccelerator -from pytorch_lightning.loops.fetchers import _DataFetcherWrapper -from pytorch_lightning.trainer.trainer import Trainer try: import apex.transformer.pipeline_parallel.utils @@ -99,25 +105,13 @@ try: from megatron.core import InferenceParams, parallel_state, tensor_parallel - from megatron.core.datasets.blended_megatron_dataset_builder import ( - BlendedMegatronDatasetBuilder, - ) - from megatron.core.datasets.gpt_dataset import ( - GPTDataset, - GPTDatasetConfig, - MockGPTDataset, - ) + from megatron.core.datasets.blended_megatron_dataset_builder import BlendedMegatronDatasetBuilder + from megatron.core.datasets.gpt_dataset import GPTDataset, GPTDatasetConfig, MockGPTDataset from megatron.core.datasets.utils import get_blend_from_list from 
megatron.core.dist_checkpointing.dict_utils import dict_list_map_inplace - from megatron.core.dist_checkpointing.mapping import ( - LocalNonpersitentObject, - ShardedObject, - ) + from megatron.core.dist_checkpointing.mapping import LocalNonpersitentObject, ShardedObject from megatron.core.distributed import DistributedDataParallel as McoreDDP - from megatron.core.distributed import ( - DistributedDataParallelConfig, - finalize_model_grads, - ) + from megatron.core.distributed import DistributedDataParallelConfig, finalize_model_grads # NeMo's implementation of the get_gpt_layer_ammo_spec function is temporarily used # from megatron.core.inference.gpt.model_specs import get_gpt_layer_ammo_spec @@ -425,6 +419,43 @@ def __init__(self, cfg: DictConfig, trainer: Trainer): if self.use_loss_mask and self.transformer_config.sequence_parallel: raise ValueError('Loss mask is not supported with sequence parallelism.') + torch_profiler_config = cfg.get("pytorch_profiler", dict(enable=False, trace_path_prefix="./")) + + self._pytorch_current_step = 0 + self._pytorch_max_step = 0 + self._pytorch_profiler = None + self._pytorch_profiler_started = False + self._pytorch_profiler_enable = torch_profiler_config.get("enable", "False") + self._pytorch_profiler_prefix = torch_profiler_config.get("trace_path_prefix", "./") + + # for benchmark + self._BENCHMARK_MODE = True + self._static_attention_mask = None + self._static_position_ids = None + + def get_pytorch_profiler(self, trace_path_prefix): + if self._pytorch_profiler_enable == "False" or self._pytorch_profiler_enable is False: + return None, 20 # For Test + + schedule_config = {"wait": 1, "warmup": 1, "active": 1, "repeat": 1, "skip_first": 8} + trace_path = ( + f"{trace_path_prefix}/rank{dist.get_rank()}_" + f"dp{parallel_state.get_data_parallel_rank()}_" + f"tp{parallel_state.get_tensor_model_parallel_rank()}_" + f"pp{parallel_state.get_pipeline_model_parallel_rank()}" + ) + + llm_profile = torch.profiler.profile( + activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA], + schedule=torch.profiler.schedule(**schedule_config), + on_trace_ready=torch.profiler.tensorboard_trace_handler(trace_path), + with_stack=True, + with_modules=True, + profile_memory=True, + ) + + return llm_profile, 1 + 1 + 1 + 8 + def set_inference_config(self, inference_config): self._inference_config = inference_config @@ -704,17 +735,46 @@ def fwd_bwd_step(self, dataloader_iter, forward_only, first_val_step=None): # we do this inside training_step to support pipeline parallelism fwd_bwd_function = get_forward_backward_func() + def _inner_fwd_bwd_function_with_profiling(): + if self._pytorch_profiler_started is False: + self._pytorch_profiler_started = True + self._pytorch_profiler, self._pytorch_max_step = self.get_pytorch_profiler( + self._pytorch_profiler_prefix + ) + if self._pytorch_profiler is not None: + self._pytorch_profiler.__enter__() + # if dist.get_rank() == 0: + print(f"rank{dist.get_rank()}: Enable Pytorch Profiler!", flush=True) + + _ret = fwd_bwd_function( + forward_step_func=self.get_forward_output_and_loss_func(forward_only), + data_iterator=self._make_data_iterator_list(dataloader_iter), + model=self.model, + num_microbatches=get_num_microbatches(), + forward_only=forward_only, + seq_length=self.cfg.encoder_seq_length, + micro_batch_size=self.cfg.micro_batch_size, + first_val_step=first_val_step, + ) + + self._pytorch_current_step += 1 + if self._pytorch_profiler_started is True and self._pytorch_profiler is not None: + if 
self._pytorch_current_step > self._pytorch_max_step: + self._pytorch_profiler.__exit__(None, None, None) + self._pytorch_profiler.step() + + # if dist.get_rank() == 0: + + # torch.distributed.barrier() + print(f"rank{dist.get_rank()}: step {self._pytorch_current_step}", flush=True) + + if self._pytorch_current_step > self._pytorch_max_step: + import sys; sys.exit(-1) + + return _ret + # TODO @akhattar: add num_micro_batches_with_partial_activation_checkpoints when ready - losses_reduced_per_micro_batch = fwd_bwd_function( - forward_step_func=self.get_forward_output_and_loss_func(forward_only), - data_iterator=self._make_data_iterator_list(dataloader_iter), - model=self.model, - num_microbatches=get_num_microbatches(), - forward_only=forward_only, - seq_length=self.cfg.encoder_seq_length, - micro_batch_size=self.cfg.micro_batch_size, - first_val_step=first_val_step, - ) + losses_reduced_per_micro_batch = _inner_fwd_bwd_function_with_profiling() # only the last stages of the pipeline return losses if losses_reduced_per_micro_batch: @@ -958,7 +1018,11 @@ def training_step(self, dataloader_iter): batch_size=1, ) self.log( - 'num_consumed_tokens', num_consumed_tokens, prog_bar=True, rank_zero_only=True, batch_size=1, + 'num_consumed_tokens', + num_consumed_tokens, + prog_bar=True, + rank_zero_only=True, + batch_size=1, ) if self.rampup_batch_size: @@ -1189,8 +1253,7 @@ def get_batch_on_this_context_parallel_rank(self, batch): return batch def get_forward_output_and_loss_func(self, validation_step=False, tuning=False): - def fwd_output_and_loss_func(dataloader_iter, model, checkpoint_activations_all_layers=None): - + def fwd_output_and_loss_func_megatron(dataloader_iter, model, checkpoint_activations_all_layers=None): # Get data batch batch = self.get_batch(dataloader_iter, tuning) @@ -1311,7 +1374,196 @@ def loss_func(output_tensor): return output_tensor, loss_func - return fwd_output_and_loss_func + def fwd_output_and_loss_func_internlm( + dataloader_iter, + model, + checkpoint_activations_all_layers=None, + ): + # Get data batch + tmp_batch = next(dataloader_iter) + ins = tmp_batch[0][0] + + if self._BENCHMARK_MODE is True: + # print(f"cu_seqlens: {ins['cu_seqlens'][0]}", flush=True) + if self._static_position_ids is None: + if self.get_attention_mask_from_fusion: + self._static_attention_mask = None + else: + self._static_attention_mask = scipy.linalg.block_diag( + *[ + np.tril(np.ones((l2 - l1, l2 - l1), dtype=bool)) + for l1, l2 in zip(ins['cu_seqlens'][0][:-1], ins['cu_seqlens'][0][1:]) + ] + ) + self._static_attention_mask = torch.tensor( + self._static_attention_mask, dtype=torch.int32, device=ins['input_ids'].device + ).reshape(1, 1, len(self._static_attention_mask), len(self._static_attention_mask)) + + self._static_position_ids = list( + itertools.chain( + *[np.arange(l2 - l1) for l1, l2 in zip(ins['cu_seqlens'][0][:-1], ins['cu_seqlens'][0][1:])] + ) + ) + self._static_position_ids = torch.tensor( + self._static_position_ids, dtype=torch.int32, device=ins['input_ids'].device + ).reshape(1, len(self._static_position_ids)) + + attention_mask, position_ids = self._static_attention_mask, self._static_position_ids + else: + attention_mask = scipy.linalg.block_diag( + *[ + np.tril(np.ones((l2 - l1, l2 - l1), dtype=bool)) + for l1, l2 in zip(ins['cu_seqlens'][0][:-1], ins['cu_seqlens'][0][1:]) + ] + ) + attention_mask = torch.tensor(attention_mask, dtype=torch.int32).reshape( + 1, 1, len(attention_mask), len(attention_mask) + ) + position_ids = list( + itertools.chain( + *[np.arange(l2 - 
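The benchmark-mode branch above builds a block-diagonal causal attention mask and packed position ids from cu_seqlens once and caches them. A self-contained sketch of that construction, with illustrative cu_seqlens values (the real code reads them from the collated batch):

    # Illustrative reconstruction of the mask/position-id logic; cu_seqlens is made up.
    import itertools

    import numpy as np
    import scipy.linalg
    import torch

    cu_seqlens = [0, 3, 7, 12]  # three packed samples of lengths 3, 4 and 5

    # One lower-triangular (causal) block per packed sample, assembled block-diagonally
    # so tokens cannot attend across sample boundaries.
    attention_mask = scipy.linalg.block_diag(
        *[np.tril(np.ones((l2 - l1, l2 - l1), dtype=bool)) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])]
    )
    attention_mask = torch.tensor(attention_mask, dtype=torch.int32).reshape(
        1, 1, len(attention_mask), len(attention_mask)
    )

    # Position ids restart from 0 at every sample boundary.
    position_ids = list(
        itertools.chain(*[np.arange(l2 - l1) for l1, l2 in zip(cu_seqlens[:-1], cu_seqlens[1:])])
    )
    position_ids = torch.tensor(position_ids, dtype=torch.int32).reshape(1, len(position_ids))

    print(attention_mask.shape)   # torch.Size([1, 1, 12, 12])
    print(position_ids.tolist())  # [[0, 1, 2, 0, 1, 2, 3, 0, 1, 2, 3, 4]]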
l1) for l1, l2 in zip(ins['cu_seqlens'][0][:-1], ins['cu_seqlens'][0][1:])] + ) + ) + position_ids = torch.tensor(position_ids, dtype=torch.int32, device=ins['input_ids'].device).reshape( + 1, len(position_ids) + ) + + labels = tmp_batch[0][1] + loss_mask = labels != -100 # todo: check 1 is valid or 0 is valid + + batch = { + 'tokens': ins['input_ids'], + 'position_ids': position_ids, + 'attention_mask': attention_mask, + 'labels': labels, + 'loss_mask': loss_mask, + } + + # Transfer needed data to GPU + required_keys = set() + max_seqlen = batch['max_seqlen'].squeeze() if 'max_seqlen' in batch else None + cu_seqlens_argmin = batch['cu_seqlens_argmin'] if 'cu_seqlens_argmin' in batch else None + if parallel_state.get_pipeline_model_parallel_world_size() == 1: + required_keys.update(batch.keys()) + else: + required_keys.add('attention_mask') + if 'cu_seqlens' in batch: + required_keys.add('cu_seqlens') + if parallel_state.is_pipeline_first_stage(): + required_keys.update(('tokens', 'position_ids')) + if parallel_state.is_pipeline_last_stage(): + required_keys.update(('labels', 'loss_mask')) + if self.get_attention_mask_from_fusion and 'attention_mask' in required_keys: + required_keys.remove('attention_mask') + batch = {key: val.cuda(non_blocking=True) if key in required_keys else None for key, val in batch.items()} + + # slice batch along sequence dimension for context parallelism + batch = self.get_batch_on_this_context_parallel_rank(batch) + + # Model forward pass + forward_args = { + 'input_ids': batch['tokens'], + 'position_ids': batch['position_ids'], + 'attention_mask': None if self.get_attention_mask_from_fusion else batch['attention_mask'], + 'labels': batch['labels'] if 'labels' in batch else None, + 'loss_mask': batch['loss_mask'], + } + + if dist.get_rank() == 0: + print(f"forward_args: {forward_args}", flush=True) + + if not self.mcore_gpt: + forward_args['checkpoint_activations_all_layers'] = checkpoint_activations_all_layers + if not self.use_loss_mask: + forward_args.pop('loss_mask') + else: + # TODO: @eharper can we add this to mcore? + forward_args.pop('loss_mask') + + if 'cu_seqlens' in batch: # packed sequence from GPTSFTPackedDataset + # these args are passed eventually into TEDotProductAttention.forward() + cu_seqlens = batch['cu_seqlens'].squeeze() # remove batch size dimension (mbs=1) + # remove -1 "paddings" added in collate_fn + if cu_seqlens_argmin is not None: + cu_seqlens = cu_seqlens[: cu_seqlens_argmin.item()] + else: + cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)] + + try: + from megatron.core.packed_seq_params import PackedSeqParams + except (ImportError, ModuleNotFoundError) as e: + mcore_version = packaging.version.Version(version('megatron-core')) + logging.error( + f"megatron-core v{mcore_version} does not support training with packed sequence. 
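Before the packed-sequence arguments above reach TE, the collated cu_seqlens tensor is stripped of its -1 padding, and a max_seqlen is needed for the THD layout. A framework-free sketch of that trimming step with made-up values (the real code takes cu_seqlens_argmin/max_seqlen from the collate function when they are present):

    # Hypothetical trimming of a padded cu_seqlens row, mirroring the branch above.
    import torch

    # Collate functions often right-pad cu_seqlens with -1 so every micro-batch has the
    # same shape; the values here are illustrative.
    cu_seqlens = torch.tensor([0, 3, 7, 12, -1, -1])

    # Drop the "-1" padding: torch.argmin lands on the first padded position.
    cu_seqlens = cu_seqlens[: torch.argmin(cu_seqlens)]

    # Longest packed sample, usable as max_seqlen_q / max_seqlen_kv for THD attention.
    max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()

    print(cu_seqlens.tolist())  # [0, 3, 7, 12]
    print(int(max_seqlen))      # 5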
" + "Please use megatron-core >= 0.5.0, or set model.data.train_ds.packed_sequence=False" + ) + raise e + + forward_args['packed_seq_params'] = PackedSeqParams( + cu_seqlens_q=cu_seqlens, + cu_seqlens_kv=cu_seqlens, + max_seqlen_q=max_seqlen, + max_seqlen_kv=max_seqlen, + qkv_format='thd', + ) + + output_tensor = model(**forward_args) + + if dist.get_rank() == 0: + print(f"mem allocated: {torch.cuda.max_memory_allocated() / 1024 / 1024 / 1024}", flush=True) + torch.cuda.reset_peak_memory_stats() + + def loss_func(output_tensor): + # Loss for a micro-batch (ub) + loss_for_ub = self.loss_func(batch['loss_mask'], batch['num_valid_tokens_in_ub'], output_tensor) + cp_size = parallel_state.get_context_parallel_world_size() + if self.return_output_tensors: + # TODO: need a better way to check if loss_func is returning more stuff than just loss... (@adithyare) + loss_for_ub, q_hs, d_hs, pos_cs, neg_cs, diff_cs = loss_for_ub + reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) + pos_cs = average_losses_across_data_parallel_group([pos_cs]) + neg_cs = average_losses_across_data_parallel_group([neg_cs]) + diff_cs = average_losses_across_data_parallel_group([diff_cs]) + return ( + loss_for_ub * cp_size, + { + 'avg': reduced_loss, + 'query_hs': q_hs, + 'doc_hs': d_hs, + 'avg_pos_cs': pos_cs, + 'avg_neg_cs': neg_cs, + 'diff_cs': diff_cs, + }, + ) + elif validation_step and not self.validation_drop_last: + num_valid_tokens_in_ub = batch['num_valid_tokens_in_ub'] + if loss_for_ub.isnan(): + assert batch['loss_mask'].count_nonzero() == 0, 'Got NaN loss with non-empty input' + loss_sum_for_ub = torch.zeros_like(loss_for_ub) + num_valid_tokens_in_ub = 0 + else: + if self.sample_weight == 'constant': + num_valid_tokens_in_ub = 1 + loss_sum_for_ub = num_valid_tokens_in_ub * loss_for_ub + + loss_sum_and_ub_size_all_gpu = torch.cat( + [ + loss_sum_for_ub.clone().detach().view(1), + torch.tensor([num_valid_tokens_in_ub]).cuda().clone().detach(), + ] + ) + # Could potentially reduce num_valid_samples_in_microbatch and use that to aggregate instead of len(self._validation_ds) + torch.distributed.all_reduce( + loss_sum_and_ub_size_all_gpu, group=parallel_state.get_data_parallel_group() + ) + return loss_for_ub * cp_size, {'loss_sum_and_ub_size': loss_sum_and_ub_size_all_gpu} + else: + reduced_loss = average_losses_across_data_parallel_group([loss_for_ub]) + return loss_for_ub * cp_size, {'avg': reduced_loss} + + return output_tensor, loss_func + + return fwd_output_and_loss_func_internlm if self.cfg.data.use_internlm_dl else fwd_output_and_loss_func_megatron def get_forward_output_only_func(self): def fwd_output_only_func(dataloader_iter, model): @@ -1349,11 +1601,7 @@ def fwd_output_only_func(dataloader_iter, model): extra_arg['inference_max_sequence_len'] = inference_max_sequence_len[0].item() # Currently for all MCore transformer layer specs causal attention mask # is used so we can delegate creating it to MCore/TE and pass None below - if ( - isinstance(model, MCoreGPTModel) - or hasattr(model, "module") - and isinstance(model.module, MCoreGPTModel) - ): + if isinstance(model, MCoreGPTModel) or hasattr(model, "module") and isinstance(model.module, MCoreGPTModel): attention_mask = None output_tensor = model(tokens, position_ids, attention_mask, **extra_arg) @@ -1466,6 +1714,12 @@ def loss_func(self, loss_mask, num_valid_tokens_in_ub, output_tensor): return loss def build_train_valid_test_datasets(self): + if self.cfg.data.use_internlm_dl: + return 
self.build_train_valid_test_datasets_internlm() + else: + return self.build_train_valid_test_datasets_megatron() + + def build_train_valid_test_datasets_megatron(self): if self.trainer.limit_val_batches > 1.0 and isinstance(self.trainer.limit_val_batches, float): raise ValueError("limit_val_batches must be an integer or float less than or equal to 1.0.") logging.info('Building GPT datasets.') @@ -1561,8 +1815,49 @@ def build_train_valid_test_datasets(self): return self._train_ds, self._validation_ds, self._test_ds + def build_train_valid_test_datasets_internlm(self): + if self.cfg.data.train_folder is None or self.cfg.data.train_folder == "": + random_ds = RandomDataset(num_samples=1000000, max_len=self.cfg.encoder_seq_length) + if self.cfg.data.pack_sample_into_one: + self._train_ds = PackedDatasetWithoutCuSeqlen( + random_ds, + max_length_per_sample=self.cfg.encoder_seq_length, + packed_length=self.cfg.micro_batch_size * self.cfg.encoder_seq_length, + ) + else: + self._train_ds = PackedDatasetWithCut( + random_ds, + max_length_per_sample=self.cfg.encoder_seq_length, + packed_length=self.cfg.micro_batch_size * self.cfg.encoder_seq_length, + use_packed_dataset=self.cfg.data.use_packed_dataset, + micro_bsz=self.cfg.micro_batch_size, + ) + else: + self._train_ds = get_packed_dataset_without_short_length( + folder=self.cfg.data.train_folder, + packed_length=self.cfg.micro_batch_size * self.cfg.encoder_seq_length, + max_length_per_sample=self.cfg.encoder_seq_length, + show_progress=parallel_state.get_data_parallel_rank(), + min_length=self.cfg.data.min_length, + min_length_dict={}, + pack_sample_into_one=self.cfg.data.pack_sample_into_one, + ) + if self._train_ds is not None: + logging.info(f'Length of train dataset: {len(self._train_ds)}') + return self._train_ds + def build_pretraining_data_loader( self, dataset, consumed_samples, dataset_type=None, drop_last=True, pad_samples_to_global_batch_size=False + ): + if self.cfg.data.use_internlm_dl: + return self.build_pretraining_data_loader_internlm(dataset) + else: + return self.build_pretraining_data_loader_megatron( + dataset, consumed_samples, dataset_type=None, drop_last=True, pad_samples_to_global_batch_size=False + ) + + def build_pretraining_data_loader_megatron( + self, dataset, consumed_samples, dataset_type=None, drop_last=True, pad_samples_to_global_batch_size=False ): """Buld dataloader given an input dataset.""" @@ -1603,6 +1898,32 @@ def build_pretraining_data_loader( persistent_workers=True if self.cfg.data.num_workers > 0 else False, ) + def build_pretraining_data_loader_internlm(self, dataset): + assert isinstance(dataset, (PackedDatasetWithCut, PackedDatasetWithoutCuSeqlen, torch.utils.data.ConcatDataset)) + # Create the training dataset sampler + train_sampler = StaticBatchSampler( + dataset.datasets if isinstance(dataset, torch.utils.data.ConcatDataset) else [dataset], + batch_size=1, # for compatability with nemo-megatron input tensor shape + rampup_batch_size=self.cfg.get('rampup_batch_size', None), + micro_bsz=self.cfg.micro_batch_size, + seed=1024, + drop_last=True, + data_rank=parallel_state.get_data_parallel_rank(), + data_world_size=parallel_state.get_data_parallel_world_size(), + ) + train_collate_fn = partial( + packed_collate_fn, packed_length=self.cfg.micro_batch_size * self.cfg.encoder_seq_length + ) + train_dl = torch.utils.data.DataLoader( + dataset=dataset, + batch_sampler=train_sampler, + num_workers=self.cfg.data.num_workers, + pin_memory=True, + collate_fn=train_collate_fn, + 
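build_train_valid_test_datasets_internlm above packs samples to a fixed packed_length = micro_batch_size * encoder_seq_length, and the collater reports the sample boundaries as cu_seqlens. A simplified, dependency-free sketch of that packing idea (this is not the PackedDatasetWithCut implementation, only the shape of the output it produces, and it assumes every sample fits inside one pack):

    # Toy packing routine: greedily concatenate tokenized samples into fixed-size packs
    # and record per-pack cumulative sequence lengths. Illustrative only.
    from typing import Dict, List


    def pack_samples(samples: List[List[int]], packed_length: int, pad_id: int = 0) -> List[Dict]:
        packs, tokens, cu_seqlens = [], [], [0]
        for sample in samples:
            if len(tokens) + len(sample) > packed_length:  # current pack is full
                tokens += [pad_id] * (packed_length - len(tokens))
                packs.append({"input_ids": tokens, "cu_seqlens": cu_seqlens})
                tokens, cu_seqlens = [], [0]
            tokens += sample
            cu_seqlens.append(len(tokens))
        if tokens:
            tokens += [pad_id] * (packed_length - len(tokens))
            packs.append({"input_ids": tokens, "cu_seqlens": cu_seqlens})
        return packs


    packs = pack_samples([[1, 2, 3], [4, 5], [6, 7, 8, 9], [10]], packed_length=8)
    for p in packs:
        print(p["cu_seqlens"], p["input_ids"])
    # [0, 3, 5] [1, 2, 3, 4, 5, 0, 0, 0]
    # [0, 4, 5] [6, 7, 8, 9, 10, 0, 0, 0]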
persistent_workers=self.cfg.data.num_workers, + ) + return train_dl + def setup(self, stage=None): """ PTL hook that is executed after DDP spawns. @@ -1641,8 +1962,8 @@ def setup(self, stage=None): # allowing restored models to optionally setup datasets self.build_train_valid_test_datasets() self.setup_training_data(self.cfg.data) - self.setup_validation_data(self.cfg.data) - self.setup_test_data(self.cfg.data) + # self.setup_validation_data(self.cfg.data) + # self.setup_test_data(self.cfg.data) # Override limit_train_batches in terms of num of microbatches self._reconfigure_limit_batches(self.trainer.limit_train_batches, self._train_dl, 'train') # Override limit_val_batches to be a multiple of num microbatches to prevent val_step from exiting in between a step @@ -1659,7 +1980,7 @@ def setup_training_data(self, cfg): if hasattr(self, '_train_ds'): consumed_samples = self.compute_consumed_samples(0) logging.info( - f'Setting up train dataloader with len(len(self._train_ds)): {len(self._train_ds)} and consumed samples: {consumed_samples}' + f'Setting up train dataloader with len(self._train_ds): {len(self._train_ds)} and consumed samples: {consumed_samples}' ) self._train_dl = self.build_pretraining_data_loader(self._train_ds, consumed_samples) @@ -1667,7 +1988,7 @@ def setup_validation_data(self, cfg): if hasattr(self, '_validation_ds'): consumed_samples = 0 logging.info( - f'Setting up validation dataloader with len(len(self._validation_ds)): {len(self._validation_ds)} and consumed samples: {consumed_samples}' + f'Setting up validation dataloader with len(self._validation_ds): {len(self._validation_ds)} and consumed samples: {consumed_samples}' ) drop_last = True @@ -1688,7 +2009,7 @@ def setup_test_data(self, cfg): if self._test_ds is not None: consumed_samples = 0 logging.info( - f'Setting up test dataloader with len(len(self._test_ds)): {len(self._test_ds)} and consumed samples: {consumed_samples}' + f'Setting up test dataloader with len(self._test_ds): {len(self._test_ds)} and consumed samples: {consumed_samples}' ) self._test_dl = self.build_pretraining_data_loader(self._test_ds, consumed_samples) else:
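For orientation, a sketch of the configuration keys the new InternLM data path and the profiler hooks read in this file (use_internlm_dl, train_folder, pack_sample_into_one, use_packed_dataset, min_length, num_workers, and the pytorch_profiler block). The values below are placeholders, not recommended settings:

    # Illustrative config fragment; key names come from this patch, values are made up.
    from omegaconf import OmegaConf

    cfg = OmegaConf.create(
        {
            "micro_batch_size": 2,
            "encoder_seq_length": 4096,
            "rampup_batch_size": None,
            "pytorch_profiler": {"enable": False, "trace_path_prefix": "./"},
            "data": {
                "use_internlm_dl": True,
                "train_folder": "",             # empty -> RandomDataset fallback above
                "pack_sample_into_one": False,  # False -> PackedDatasetWithCut
                "use_packed_dataset": True,
                "min_length": 50,
                "num_workers": 2,
            },
        }
    )

    packed_length = cfg.micro_batch_size * cfg.encoder_seq_length
    print(packed_length)  # 8192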