详解Megatron中的数据混合算法(BlendableDataset)

03-01 1070阅读

🧑‍💻 本文主要讲解Megatron早期版本中的数据混合算法。

目录

  • 1. 数据混合
  • 2. 源码解析
  • 3. 证明部分&讨论
  • 4. 进一步优化

    1. 数据混合

    在谈源码之前,我们有必要先了解一下Megatron中的数据混合思想。

    给定 n n n 个数据集 D 1 , D 2 , ⋯   , D n \mathcal{D}_1,\mathcal{D}_2,\cdots,\mathcal{D}_n D1​,D2​,⋯,Dn​ 和对应的 n n n 个权重 w 1 , w 2 , ⋯   , w n w_1,w_2,\cdots,w_n w1​,w2​,⋯,wn​,我们要按照这些权重去混合 n n n 个数据集,设混合后的数据集为 D \mathcal{D} D。

    Megatron假定:

    • ∣ D ∣ = ∑ i = 1 n ∣ D i ∣ |\mathcal{D}|=\sum_{i=1}^n|\mathcal{D}_i| ∣D∣=∑i=1n​∣Di​∣。即混合后的数据集大小等于混合前的各数据集大小之和。
    • D \mathcal{D} D 中有约 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i ∣D∣⋅wi​ 个样本来自 D i \mathcal{D}_i Di​。

      那如何确定 D \mathcal{D} D 中到底有多少个样本是来自 D i \mathcal{D}_i Di​ 的呢?一种最直观的做法是,计算 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i ∣D∣⋅wi​,然后进行取整,但这种操作无法保证所有取整后的 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i ∣D∣⋅wi​ 相加起来恰好是 ∣ D ∣ |\mathcal{D}| ∣D∣。 如果总和大于 ∣ D ∣ |\mathcal{D}| ∣D∣,说明某些数据集被过采样了,应当减少相应数据集的采样数;如果总和小于 ∣ D ∣ |\mathcal{D}| ∣D∣,说明某些数据集被欠采样了,应当增加相应数据集的采样数。可问题是,如何确定这些被过采样/欠采样的数据集呢?显然我们需要一个更加公平的算法。

      我们可以把获取数据集 D \mathcal{D} D 看作是一个采样过程:一开始有 n n n 个数据源 { D i } i = 1 n \{\mathcal{D}_i\}_{i=1}^n {Di​}i=1n​,每一轮迭代,我们需要先从这 n n n 个数据源中选出一个数据源 D i \mathcal{D}_i Di​,然后再从这个数据源中选出一个样本 S \mathcal{S} S。 由于每一轮迭代只会选出一个样本,因此 ∣ D ∣ |\mathcal{D}| ∣D∣ 轮迭代结束后,我们便得到了 ∣ D ∣ |\mathcal{D}| ∣D∣ 个样本,这些样本构成了混合后的数据集 D \mathcal{D} D。

      每一轮迭代都会产生两个信息:要选取的数据源 D i \mathcal{D}_i Di​,要从 D i \mathcal{D}_i Di​ 中选取的样本。我们可以考虑构造两个整数序列 P , S \mathcal{P},\mathcal{S} P,S,它们的长度均为 ∣ D ∣ |\mathcal{D}| ∣D∣,含义如下:

      • P j \mathcal{P}_j Pj​ 代表的是第 j j j 轮迭代时,选取的数据源的下标。例如 P 10 = 3 \mathcal{P}_{10}=3 P10​=3 意味着第 10 10 10 轮迭代选取的数据源是 D 3 \mathcal{D}_3 D3​。
      • S j \mathcal{S}_j Sj​ 代表的是第 j j j 轮迭代时,从数据源 D P j \mathcal{D}_{\mathcal{P}_j} DPj​​ 选取的样本的下标。

        由以上定义知, ∀ j \forall j ∀j,都有 1 ≤ P j ≤ n 1\leq \mathcal{P}_j\leq n 1≤Pj​≤n, 1 ≤ S j ≤ ∣ D P j  ⁣ ∣ 1\leq \mathcal{S}_j\leq|\mathcal{D}_{\mathcal{P}_j}\!| 1≤Sj​≤∣DPj​​∣(下标均从 1 1 1 开始)。

        接下来的问题是,如何确定每一轮的 P j \mathcal{P}_j Pj​ 和 S j \mathcal{S}_j Sj​ 呢?

        先谈 P j \mathcal{P}_j Pj​。因为是一个从 1 1 1 到 ∣ D ∣ |\mathcal{D}| ∣D∣ 的一个逐步采样过程,在第 j j j 轮迭代时,我们已经抽取了 j − 1 j-1 j−1 个样本,接下来要确定第 j j j 个样本。根据Megatron的假定,在确定下来第 j j j 个样本后,这 j j j 个样本中应当有约 j ⋅ w i j\cdot w_i j⋅wi​ 个样本是来自 D i \mathcal{D}_i Di​ 的。

        考虑构造一个长度为 n n n 的序列 C \mathcal{C} C,该序列随着迭代不断更新。 C i \mathcal{C}_i Ci​ 代表当前已经从 D i \mathcal{D}_i Di​ 抽取了多少个样本。显然可知,第一轮迭代开始时,有 C i = 0 ,   i = 1 , 2 , ⋯   , n \mathcal{C}_i=0,\,i=1,2,\cdots,n Ci​=0,i=1,2,⋯,n。最后一轮迭代结束后,有 ∑ i = 1 n C i = ∣ D ∣ \sum_{i=1}^n\mathcal{C}_i=|\mathcal{D}| ∑i=1n​Ci​=∣D∣,并且

        C i = { ∑ t = 1 j − 1 I ( P t = i ) , P j 确定前 ∑ t = 1 j I ( P t = i ) , P j 确定后 , ∀ i \mathcal{C}_i=\begin{cases} \sum_{t=1}^{j-1} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定前} \\ \sum_{t=1}^{j} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定后} \\ \end{cases},\quad \forall i Ci​={∑t=1j−1​I(Pt​=i),∑t=1j​I(Pt​=i),​Pj​确定前Pj​确定后​,∀i

        回到对 P j \mathcal{P}_j Pj​ 的讨论中。假设在确定第 j j j 个样本前已经从 D i \mathcal{D}_i Di​ 中抽取了 C i \mathcal{C}_i Ci​ 个样本,在确定第 j j j 个样本后,诸 C i \mathcal{C}_i Ci​ 中有且仅有一个的值会增加 1 1 1,不妨记为 C k \mathcal{C}_k Ck​,这个过程可以形容为

        [ C 1 , ⋯   , C k , ⋯   , C n ] ⏟ 第 j 轮迭代开始时 → [ C 1 , ⋯   , C k + 1 , ⋯   , C n ] ⏟ 第 j 轮迭代结束时 [ j ⋅ w 1 , j ⋅ w 2 , ⋯   , j ⋅ w n ] ⏟ 理论值 \underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_k,\cdots,\mathcal{C}_n]}_{第j轮迭代开始时}\to\underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n]}_{第j轮迭代结束时}\qquad \underbrace{[j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n]}_{理论值} 第j轮迭代开始时 [C1​,⋯,Ck​,⋯,Cn​]​​→第j轮迭代结束时 [C1​,⋯,Ck​+1,⋯,Cn​]​​理论值 [j⋅w1​,j⋅w2​,⋯,j⋅wn​]​​

        我们期望第 j j j 轮迭代结束时,诸 C i \mathcal{C}_i Ci​ 应当尽可能地接近理论值(在MSE下)。由于只能让其中一个 C k \mathcal{C}_k Ck​ 自增 1 1 1,显然有 k = arg max ⁡ i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi​(j⋅wi​−Ci​)。

        再谈 S j \mathcal{S}_j Sj​。在确定了数据源是 D k \mathcal{D}_k Dk​ 后,为了避免重复,我们应当做到不放回、随机地从中采样。如何做到这两点呢?我们可以在一开始就对 n n n 个数据源进行打乱,然后在采样的时候只需要从前往后进行,就可以做到以上两点。注意到 C i \mathcal{C}_i Ci​ 的值是从 0 0 0 开始,以步长为 1 1 1 依次递增,所以我们可以用每次更新完的 C i \mathcal{C}_i Ci​ 赋值给相应的 S j \mathcal{S}_j Sj​,即 S j = 第 j 轮迭代结束时的 C i \mathcal{S}_j=第j轮迭代结束时的\mathcal{C}_i Sj​=第j轮迭代结束时的Ci​。

        由此我们可以得到整个算法的伪代码:

        详解Megatron中的数据混合算法(BlendableDataset)

        2. 源码解析

        Python部分:

        class BlendableDataset(torch.utils.data.Dataset):
            def __init__(self, datasets, weights):
                self.datasets = datasets
                num_datasets = len(datasets)
                assert num_datasets == len(weights), "The number of datasets and weights must match."
                self.size = sum(len(dataset) for dataset in self.datasets)
                # Normalize weights.
                weights = np.array(weights, dtype=np.float64)
                sum_weights = np.sum(weights)
                assert sum_weights > 0.0, "Sum of weights must be positive."
                weights /= sum_weights
                # Build indices.
                start_time = time.time()
                assert num_datasets  elapsed time for building blendable dataset indices: '
                             f'{time.time() - start_time:.2f} sec')
            def __len__(self):
                return self.size
            def __getitem__(self, idx):
                dataset_idx = self.dataset_index[idx]
                sample_idx = self.dataset_sample_index[idx]
                return {
                    "dataset_idx": dataset_idx,
                    **self.datasets[dataset_idx][sample_idx],
                }
        

        C++部分:

        void build_blending_indices(
            py::array_t &dataset_index,
            py::array_t &dataset_sample_index,
            const py::array_t &weights,
            const int32_t num_datasets,
            const int64_t size,
            const bool verbose
        ) {
            /* Given multiple datasets and a weighting array, build samples
               such that it follows those weights. */
            if (verbose) {
                std::cout 
                current_samples[i] = 0;
            }
            // For each sample:
            for (int64_t sample_idx = 0; sample_idx 
VPS购买请点击我

文章版权声明:除非注明,否则均为主机测评原创文章,转载或复制请以超链接形式并注明出处。

目录[+]