详解Megatron中的数据混合算法(BlendableDataset)
🧑💻 本文主要讲解Megatron早期版本中的数据混合算法。
目录
- 1. 数据混合
- 2. 源码解析
- 3. 证明部分&讨论
- 4. 进一步优化
1. 数据混合
在谈源码之前,我们有必要先了解一下Megatron中的数据混合思想。
给定 n n n 个数据集 D 1 , D 2 , ⋯ , D n \mathcal{D}_1,\mathcal{D}_2,\cdots,\mathcal{D}_n D1,D2,⋯,Dn 和对应的 n n n 个权重 w 1 , w 2 , ⋯ , w n w_1,w_2,\cdots,w_n w1,w2,⋯,wn,我们要按照这些权重去混合 n n n 个数据集,设混合后的数据集为 D \mathcal{D} D。
Megatron假定:
- ∣ D ∣ = ∑ i = 1 n ∣ D i ∣ |\mathcal{D}|=\sum_{i=1}^n|\mathcal{D}_i| ∣D∣=∑i=1n∣Di∣。即混合后的数据集大小等于混合前的各数据集大小之和。
-
D
\mathcal{D}
D 中有约
∣
D
∣
⋅
w
i
|\mathcal{D}|\cdot w_i
∣D∣⋅wi 个样本来自
D
i
\mathcal{D}_i
Di。
那如何确定 D \mathcal{D} D 中到底有多少个样本是来自 D i \mathcal{D}_i Di 的呢?一种最直观的做法是,计算 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i ∣D∣⋅wi,然后进行取整,但这种操作无法保证所有取整后的 ∣ D ∣ ⋅ w i |\mathcal{D}|\cdot w_i ∣D∣⋅wi 相加起来恰好是 ∣ D ∣ |\mathcal{D}| ∣D∣。 如果总和大于 ∣ D ∣ |\mathcal{D}| ∣D∣,说明某些数据集被过采样了,应当减少相应数据集的采样数;如果总和小于 ∣ D ∣ |\mathcal{D}| ∣D∣,说明某些数据集被欠采样了,应当增加相应数据集的采样数。可问题是,如何确定这些被过采样/欠采样的数据集呢?显然我们需要一个更加公平的算法。
我们可以把获取数据集 D \mathcal{D} D 看作是一个采样过程:一开始有 n n n 个数据源 { D i } i = 1 n \{\mathcal{D}_i\}_{i=1}^n {Di}i=1n,每一轮迭代,我们需要先从这 n n n 个数据源中选出一个数据源 D i \mathcal{D}_i Di,然后再从这个数据源中选出一个样本 S \mathcal{S} S。 由于每一轮迭代只会选出一个样本,因此 ∣ D ∣ |\mathcal{D}| ∣D∣ 轮迭代结束后,我们便得到了 ∣ D ∣ |\mathcal{D}| ∣D∣ 个样本,这些样本构成了混合后的数据集 D \mathcal{D} D。
每一轮迭代都会产生两个信息:要选取的数据源 D i \mathcal{D}_i Di,要从 D i \mathcal{D}_i Di 中选取的样本。我们可以考虑构造两个整数序列 P , S \mathcal{P},\mathcal{S} P,S,它们的长度均为 ∣ D ∣ |\mathcal{D}| ∣D∣,含义如下:
- P j \mathcal{P}_j Pj 代表的是第 j j j 轮迭代时,选取的数据源的下标。例如 P 10 = 3 \mathcal{P}_{10}=3 P10=3 意味着第 10 10 10 轮迭代选取的数据源是 D 3 \mathcal{D}_3 D3。
-
S
j
\mathcal{S}_j
Sj 代表的是第
j
j
j 轮迭代时,从数据源
D
P
j
\mathcal{D}_{\mathcal{P}_j}
DPj 选取的样本的下标。
由以上定义知, ∀ j \forall j ∀j,都有 1 ≤ P j ≤ n 1\leq \mathcal{P}_j\leq n 1≤Pj≤n, 1 ≤ S j ≤ ∣ D P j ∣ 1\leq \mathcal{S}_j\leq|\mathcal{D}_{\mathcal{P}_j}\!| 1≤Sj≤∣DPj∣(下标均从 1 1 1 开始)。
接下来的问题是,如何确定每一轮的 P j \mathcal{P}_j Pj 和 S j \mathcal{S}_j Sj 呢?
先谈 P j \mathcal{P}_j Pj。因为是一个从 1 1 1 到 ∣ D ∣ |\mathcal{D}| ∣D∣ 的一个逐步采样过程,在第 j j j 轮迭代时,我们已经抽取了 j − 1 j-1 j−1 个样本,接下来要确定第 j j j 个样本。根据Megatron的假定,在确定下来第 j j j 个样本后,这 j j j 个样本中应当有约 j ⋅ w i j\cdot w_i j⋅wi 个样本是来自 D i \mathcal{D}_i Di 的。
考虑构造一个长度为 n n n 的序列 C \mathcal{C} C,该序列随着迭代不断更新。 C i \mathcal{C}_i Ci 代表当前已经从 D i \mathcal{D}_i Di 抽取了多少个样本。显然可知,第一轮迭代开始时,有 C i = 0 , i = 1 , 2 , ⋯ , n \mathcal{C}_i=0,\,i=1,2,\cdots,n Ci=0,i=1,2,⋯,n。最后一轮迭代结束后,有 ∑ i = 1 n C i = ∣ D ∣ \sum_{i=1}^n\mathcal{C}_i=|\mathcal{D}| ∑i=1nCi=∣D∣,并且
C i = { ∑ t = 1 j − 1 I ( P t = i ) , P j 确定前 ∑ t = 1 j I ( P t = i ) , P j 确定后 , ∀ i \mathcal{C}_i=\begin{cases} \sum_{t=1}^{j-1} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定前} \\ \sum_{t=1}^{j} I(\mathcal{P}_t=i),&\text{$\mathcal{P}_j$确定后} \\ \end{cases},\quad \forall i Ci={∑t=1j−1I(Pt=i),∑t=1jI(Pt=i),Pj确定前Pj确定后,∀i
回到对 P j \mathcal{P}_j Pj 的讨论中。假设在确定第 j j j 个样本前已经从 D i \mathcal{D}_i Di 中抽取了 C i \mathcal{C}_i Ci 个样本,在确定第 j j j 个样本后,诸 C i \mathcal{C}_i Ci 中有且仅有一个的值会增加 1 1 1,不妨记为 C k \mathcal{C}_k Ck,这个过程可以形容为
[ C 1 , ⋯ , C k , ⋯ , C n ] ⏟ 第 j 轮迭代开始时 → [ C 1 , ⋯ , C k + 1 , ⋯ , C n ] ⏟ 第 j 轮迭代结束时 [ j ⋅ w 1 , j ⋅ w 2 , ⋯ , j ⋅ w n ] ⏟ 理论值 \underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_k,\cdots,\mathcal{C}_n]}_{第j轮迭代开始时}\to\underbrace{[\mathcal{C}_1,\cdots,\mathcal{C}_{k}+1,\cdots,\mathcal{C}_n]}_{第j轮迭代结束时}\qquad \underbrace{[j\cdot w_1,j\cdot w_2,\cdots,j\cdot w_n]}_{理论值} 第j轮迭代开始时 [C1,⋯,Ck,⋯,Cn]→第j轮迭代结束时 [C1,⋯,Ck+1,⋯,Cn]理论值 [j⋅w1,j⋅w2,⋯,j⋅wn]
我们期望第 j j j 轮迭代结束时,诸 C i \mathcal{C}_i Ci 应当尽可能地接近理论值(在MSE下)。由于只能让其中一个 C k \mathcal{C}_k Ck 自增 1 1 1,显然有 k = arg max i ( j ⋅ w i − C i ) k=\argmax_i(j\cdot w_i-\mathcal{C}_i) k=argmaxi(j⋅wi−Ci)。
再谈 S j \mathcal{S}_j Sj。在确定了数据源是 D k \mathcal{D}_k Dk 后,为了避免重复,我们应当做到不放回、随机地从中采样。如何做到这两点呢?我们可以在一开始就对 n n n 个数据源进行打乱,然后在采样的时候只需要从前往后进行,就可以做到以上两点。注意到 C i \mathcal{C}_i Ci 的值是从 0 0 0 开始,以步长为 1 1 1 依次递增,所以我们可以用每次更新完的 C i \mathcal{C}_i Ci 赋值给相应的 S j \mathcal{S}_j Sj,即 S j = 第 j 轮迭代结束时的 C i \mathcal{S}_j=第j轮迭代结束时的\mathcal{C}_i Sj=第j轮迭代结束时的Ci。
由此我们可以得到整个算法的伪代码:
2. 源码解析
Python部分:
class BlendableDataset(torch.utils.data.Dataset): def __init__(self, datasets, weights): self.datasets = datasets num_datasets = len(datasets) assert num_datasets == len(weights), "The number of datasets and weights must match." self.size = sum(len(dataset) for dataset in self.datasets) # Normalize weights. weights = np.array(weights, dtype=np.float64) sum_weights = np.sum(weights) assert sum_weights > 0.0, "Sum of weights must be positive." weights /= sum_weights # Build indices. start_time = time.time() assert num_datasets elapsed time for building blendable dataset indices: ' f'{time.time() - start_time:.2f} sec') def __len__(self): return self.size def __getitem__(self, idx): dataset_idx = self.dataset_index[idx] sample_idx = self.dataset_sample_index[idx] return { "dataset_idx": dataset_idx, **self.datasets[dataset_idx][sample_idx], }
C++部分:
void build_blending_indices( py::array_t &dataset_index, py::array_t &dataset_sample_index, const py::array_t &weights, const int32_t num_datasets, const int64_t size, const bool verbose ) { /* Given multiple datasets and a weighting array, build samples such that it follows those weights. */ if (verbose) { std::cout current_samples[i] = 0; } // For each sample: for (int64_t sample_idx = 0; sample_idx