论文115:Reinforced GNNs for multiple instance learning (TNNLS‘24)
文章目录
- 1 要点
- 2 预备知识
- 2.1 MIL
- 2.2 MIL-GNN
- 2.3 Markov博弈
- 2.4 深度Q-Learning
- 3 方法
- 3.1 观测生成与交互
- 3.2 动作选择和指导
- 3.3 奖励计算
- 3.4 状态转移和终止
- 3.5 多智能体训练
1 要点
题目:用于MIL的强化GNN
代码:https://github.com/RingBDStack/RGMIL
背景:MIL是一种监督学习变体,它处理包含多个实例的包,其中训练阶段只有包级别的标签可用。MIL在现实世界的应用中很多,尤其是在医学领域;
挑战:现有的GNN在MIL中通常需要过滤实例间的低置信度边,并使用新的包结构来调整图神经网络架构。这样的调整过程频繁且忽视了结构和架构之间的相关性;
RGMIL框架:首次在MIL任务中利用多智能体深度强化学习 (MADRL)。MADRL允许灵活定义或扩展影响包图或GNN的因素,并同步控制它们;
贡献:
- 引入MADRL到MIL中,实现对包结构和GNN架构的自动化和同步控制;
- 使用边阈值和GNN层数作为因素案例来构建RGMIL,探索了以前在MIL研究中被忽视的边密度和聚合范围之间的相关性;
- 实验结果表明,RGMIL在多个MIL数据集上实现了最佳性能,并且具有出色的可解释性;
细节:
- RGMIL将训练过程建模为一个完全合作的马尔可夫博弈 (MG);
- 通过两个智能体搜索边过滤阈值和GNN层数;
- 利用反向分解网络 (VDN) 来衡量智能体的贡献和相关性;
- 引入图注意力网络 (GAT) 并设计参数共享机制以提高效率;
符号表:
符号 含义 B \mathcal{B} B 包集合 G \mathcal{G} G 与包相对应的、图的集合 Y \mathcal{Y} Y 包标签 M \mathcal{M} M Markov博弈的七元组 S \mathcal{S} S M \mathcal{M} M的状态空间 O \mathcal{O} O M \mathcal{M} M的观测空间 A \mathcal{A} A M \mathcal{M} M的动作空间 L \mathcal{L} L 智能体或者GNN模型的训练损失 N N N 包数量 M M M 包内实例数量 L L L GNN层的数量 T T T 时间步的数量 I I I 智能体的数量 D D D 特征表示的维度 A \mathbf{A} A 与图相对应的邻接矩阵 F \mathbf{F} F 与图相对应的实例特征矩阵 E \mathbf{E} E 与图相对应的包图特征矩阵 Z \mathbf{Z} Z 特征变换矩阵 C \mathbf{C} C 重要性系数矩阵 i ; j ; k ; l ; t i;j;k;l;t i;j;k;l;t 索引变量 s ; o ; a ; r s;o;a;r s;o;a;r 状态、观测、动作、奖励 v v v 注意力机制特征向量 γ \gamma γ 折扣系数 α \alpha α 智能体学习率 μ \mu μ 动作或者奖励的窗口大小 λ \lambda λ 终止条件的奖励阈值 & ; % \&;\% &;% 逻辑和取余运算 ⊕ \oplus ⊕ 拼接操作 ∥ ⋅ ∥ \|\cdot\| ∥⋅∥ 矩阵的Norm函数 σ ( ⋅ ) \sigma(\cdot) σ(⋅) 激活函数 π ( ⋅ ) \pi(\cdot) π(⋅) 智能体状态-动作函数 RWD ( ⋅ ) \text{RWD}(\cdot) RWD(⋅) 奖励函数 TRN ( ⋅ ) \text{TRN}(\cdot) TRN(⋅) 状态转移函数 AGG ( ⋅ ) \text{AGG}(\cdot) AGG(⋅) 特征聚合函数 POL ( ⋅ ) \text{POL}(\cdot) POL(⋅) 特征池化函数 EVL ( ⋅ ) \text{EVL}(\cdot) EVL(⋅) 分类性能评估函数 2 预备知识
2.1 MIL
令 B = { B i ∣ i = 1 , … , N } \mathcal{B}=\{\mathcal{B}_i|i=1,\dots,N\} B={Bi∣i=1,…,N}表示包含多个包 B i = { B i , j ∣ j = 1 … , M } \mathcal{B}_i=\{\mathcal{B}_{i,j}|j=1\dots,M\} Bi={Bi,j∣j=1…,M},其中 N N N和 M M M分别表示包和包中实例的数量 (通常 M M M是变化的)。每个包对应一个两类包标签 Y i = max ( Y i , 1 , … , Y i , M ) \mathcal{Y}_i=\max(\mathcal{Y}_{i,1},\dots,\mathcal{Y}_{i,M}) Yi=max(Yi,1,…,Yi,M),其中 Y i , j ∈ { 0 , 1 } \mathcal{Y}_{i,j}\in\{0,1\} Yi,j∈{0,1}是假设的实例标签。尽管数据集中少量的实例具有真实的标签,然而在MIL的训练过程中,实例标签是不可用的。因此,MIL的目标是学习一个将包映射为标签的映射函数 B → Y \mathcal{B\to Y} B→Y,其中 Y = { Y i ∣ i = 1 , … , N } \mathcal{Y}=\{ \mathcal{Y}_i | i=1,\dots,N \} Y={Yi∣i=1,…,N}。
2.2 MIL-GNN
对于MIL-GNN,其首先需要将所有的包转换为一个图的集合 G = { G i ∣ i = 1 , … , N } \mathcal{G}=\{ \mathcal{G}_i|i=1,\dots,N \} G={Gi∣i=1,…,N},其中每个包对应一个图 G i = ( A i , F i ) \mathcal{G}_i=(\mathbf{A}_i,\mathbf{F}_i) Gi=(Ai,Fi),此外,每个实例可以看作是一个节点。每个邻接矩阵 A i ∈ R M × M \mathbf{A}_i\in\mathbb{R}^{M\times M} Ai∈RM×M使用原始节点特征构建,并通过阈值来过滤边,其每个元素表示一跳邻域信息。 F i ∈ R M × D \mathbf{F}_i\in\mathbb{R}^{M\times D} Fi∈RM×D表示实例节点的特征矩阵。
基于此, L L L层GNN被用于传递节点特征信息,其中对于第 i i i个图 G i \mathcal{G}_i Gi,其在第 l l l层的聚合过程表示为:
F i l = σ ( AGG l ( A i , F i l − 1 ) ) (1) \tag{1} \mathbf{F}_i^l=\sigma\left( \text{AGG}^l (\mathbf{A}_i,\mathbf{F}_i^{l-1})\right) Fil=σ(AGGl(Ai,Fil−1))(1)其中 AGG l ( ⋅ ) \text{AGG}^l(\cdot) AGGl(⋅)表示在第 l l l层的聚合函数,例如卷积和注意力、 σ ( ⋅ ) \sigma(\cdot) σ(⋅)表示激活函数、 F i l \mathbf{F}_i^l Fil是更新后的特征矩阵。
接下来,一个节点特征池化函数 POL ( ⋅ ) \text{POL}(\cdot) POL(⋅)被用于GNN的最后一层,以获取最终的图级别特征矩阵 E ( i ) ∈ R 1 × D \mathbf{E}(i)\in\mathbb{R}^{1\times D} E(i)∈R1×D:
E ( i ) = POL ( { F i L ( j ) ∣ j = 1 , … , M } ) (2) \tag{2} \mathbf{E}(i)=\text{POL}(\{ \mathbf{F}_i^L(j) |j=1,\dots,M \}) E(i)=POL({FiL(j)∣j=1,…,M})(2)其中 F i L ( j ) ∈ R 1 × D \mathbf{F}_i^L(j)\in\mathbb{R}^{1\times D} FiL(j)∈R1×D是实例节点 B i , j \mathcal{B}_{i,j} Bi,j的特征向量。最后, E ( i ) \mathbf{E}(i) E(i)传递给一个包图分类器。因此,在MIL-GNN中,其映射过程为 B → G → Y \mathcal{B\to G\to Y} B→G→Y。
2.3 Markov博弈
在多智能体强化学习 (MARL) 中,Markov博弈 (MG) 是从Markov决策过程 (MDP) 扩展而来。特别地,一个MG包含多个能够共同影响奖励和状态转移的智能体。根据是否所有的智能体都能完全获得全局状态信息,已有的MG可被看作是完全或者部分可观测,其中后者则更为普遍。
部分可观测的MG可以被抽象为一个七元组 M = \mathcal{M}= M=,其中:
- S \mathcal{S} S:MG的全局状态空间;
- A i \mathcal{A}_i Ai:第 i i i个智能体的动作空间。在每个时间步 t ∈ [ 1 , T ] t\in[1,T] t∈[1,T],每个智能体根据其独有的状态动作函数 π i ( ⋅ ) \pi_i(\cdot) πi(⋅)来选择动作 a i t ∈ A i a_i^t\in\mathcal{A}_i ait∈Ai;
- 每个智能体会从全局状态获得一个独立的部分观察 o i t ∈ O i o_i^t\in\mathcal{O}_i oit∈Oi,因此, π i ( ⋅ ) \pi_i(\cdot) πi(⋅)可以表示为 S → O i → A i \mathcal{S\to O_i\to A_i} S→Oi→Ai;
- 每个智能体使用其奖励函数 RED i ( ⋅ ) \text{RED}_i(\cdot) REDi(⋅)获得即时奖励 r i t r_i^t rit,这种博弈也被称为分散的部分可观测MDP (Dec-POMDP),旨在最大化累积奖励 ∑ t = 1 T γ ( t − 1 ) r ∗ t \sum^T_{t=1}\gamma^{(t−1)}r^{*t} ∑t=1Tγ(t−1)r∗t,其中 γ γ γ表示控制后续奖励的折扣系数;
- 状态转移函数 TRN ( ⋅ ) \text{TRN}(\cdot) TRN(⋅)将当前状态 s t s^t st与联合动作 a ∗ t a^{*t} a∗t映射到下一个状态 s ( t + 1 ) s^{(t+1)} s(t+1),即 S × A ∗ → S \mathcal{S \times A^*\to S} S×A∗→S。
2.4 深度Q-Learning
作为基于价值的RL的基算法,Q-Learning非常适合实现单一智能体的顺序决策系统。QLearning包含一个状态-动作表 π ( ⋅ ) π(·) π(⋅),它记录了各种状态下所有可能动作的 Q Q Q值。初始化后,智能体不断与环境交互,并通过Bellman方程更新 π ( ⋅ ) π(·) π(⋅)直到收敛。 π ( ⋅ ) π(·) π(⋅)的更新过程可以表示如下:
x = x + α [ r t + γ max a π ( s t + 1 , a ) − x ] ] s.t. x = π ( s t , a t ) (3) \tag{3} \begin{aligned} & x = x + \alpha \left[ r_t + \gamma \max_{a} \pi(s_{t+1}, a) - x \right] ]\\ & \text{s.t. } x = \pi(s_t, a_t) \end{aligned} x=x+α[rt+γamaxπ(st+1,a)−x]]s.t. x=π(st,at)(3)其中: π ( s t , a t ) \pi(s_t, a_t) π(st,at)是预测的Q值,以及在状态 s t s_t st下选择动作 a t a_t at的预期奖励、 r t r_t rt表示时间步 t t t的即时奖励、 max a π ( s t + 1 , a ) \max_a \pi(s_{t+1}, a) maxaπ(st+1,a)是下一个状态 s t + 1 s_{t+1} st+1的最大Q值,以及 α \alpha α是 π ( ⋅ ) \pi(·) π(⋅)的学习率。
在实际应用中,许多环境的状态空间是无限的,记录所有状态-动作对的值是不可行的。受深度学习的启发,许多工作引入了深度神经网络 (DNN) 来近似返回值,其中深度Q-Learning (DQN) 是传统Q-Learning的直接扩展:
- DQN使用DNN构建动作-价值函数 π π π (亦称为 Q Q Q函数),该函数将每个状态向量映射到 Q Q Q值向量 π ( s ) ∈ R 1 × ∣ A ∣ \pi(s) \in \mathbb{R}^{1 \times |A|} π(s)∈R1×∣A∣,其中 ∣ A ∣ |A| ∣A∣表示动作空间 A A A的大小;
- DQN应用经验回放和目标网络技术来更新函数
π
(
⋅
)
\pi(·)
π(⋅)。例如,给定过去时间步
t
t
t的经验记录,其元组形式为
⟨
s
t
,
a
t
,
r
t
,
s
t
+
1
⟩
\langle s_t, a_t, r_t, s_{t+1} \rangle
⟨st,at,rt,st+1⟩,则
π
π
π的时序差分损失可以计算如下:
L π = E s , a , r , s ′ [ ( π ‾ ( s t , a t ) − π ( s t , a t ) ) 2 ] s.t. π ‾ ( s t , a t ) = r t + γ max a π ‾ ( s t + 1 , a ) (4) \tag{4} \begin{aligned} &L_\pi = \mathbb{E}_{s,a,r,s'} \left[ \left( \overline{\pi}(s_t, a_t) - \pi(s_t, a_t) \right)^2 \right]\\ &\text{s.t. } \overline{\pi}(s_t, a_t) = r_t + \gamma \max_a \overline{\pi}(s_{t+1}, a) \end{aligned} Lπ=Es,a,r,s′[(π(st,at)−π(st,at))2]s.t. π(st,at)=rt+γamaxπ(st+1,a)(4)其中: π ( ⋅ ) \pi(·) π(⋅)表示评估网络,其用于预测状态 s t s_t st和动作 a t a_t at的 Q Q Q值的评估网络、 π ‾ ( ⋅ ) \overline{\pi}(·) π(⋅)是一个目标网络,其架构与 π ( ⋅ ) \pi(·) π(⋅)相同。只有 π ( ⋅ ) \pi(·) π(⋅)被优化,并且其训练参数周期性复制到 π ‾ ( ⋅ ) \overline{\pi}(·) π(⋅)。由于 π ‾ \overline{\pi} π不更新时目标 Q Q Q值是稳定的,因此 π ( ⋅ ) \pi(·) π(⋅)的训练稳定性是极好的;
- 为了权衡探索新动作的概率,DQN应用了
ϵ
ϵ
ϵ-贪婪算法。因此,它并不总是选择
π
(
s
)
\pi(s)
π(s)中最大条目的对应动作,其可以表示如下:
a = { random action , w.p. ϵ argmax a π ( s t , a ) , w.p. 1 − ϵ (5) \tag{5} \begin{aligned} a = \begin{cases} \text{random action}, & \text{w.p.} \quad\epsilon \\ \text{argmax}_a \pi(s_t, a), & \text{w.p.} \quad 1 - \epsilon \end{cases} \end{aligned} a={random action,argmaxaπ(st,a),w.p.ϵw.p.1−ϵ(5)其中, ϵ \epsilon ϵ表示随机选择动作的概率,即探索,而 1 − ϵ 1-\epsilon 1−ϵ表示选择当前基于 π π π的最优动作,即利用。通过这样做,DQN避免了在强化学习任务中的探索-利用困境,避开了局部最优,并促进了更好的 π π π函数的发现。
3 方法
本节介绍RGMIL的细节,包括:1) 用于提升博弈公平性的观测生成与交互;2) 用于提升GNN效率的动作选择和指导技术;3) 用于提升博弈稳定性的奖励计算;4) 用于确保博弈收敛的状态转移和终止技术;以及5) 多智能体训练。
RGMIL的总览如图4所示,其中左子图对应章节3.1至3.4,右子图对应章节3.5。
图4:RGMIL总览。左右子图分别对应经验收集和代理优化:1) 每一个时间步,初始观测从当前的block导出;2) 观测作为代理的输入,用于选择当前的动作;3) 构建可信包图,并作为定制的GNN的输入;4) GNN训练后,通过动作组合来评估性能,并确定当前的奖励;5) 带有动作的转移函数作为输入,以生成下一次观测;6) 记录以上过程,到达一定数量后,由VDN执行代理优化3.1 观测生成与交互
在RGMIL中,我们将其训练过程建模为一个合作的马尔可夫博弈 (MG),涉及两个智能体,分别用于搜索最佳的边过滤阈值和GNN层数:
- 利用一个改进的VDN来实现MG:
- 将训练集划分为多个等大小的区块,其中一个区块作为验证集,其余区块用作构建MG状态空间 S S S;
- 在第一个时间步之前,随机选择一个训练区块作为全局状态;
- 由于边过滤阈值的选择通常与拓扑信息相关,我们随后指定当前状态中包图的结构特征作为第一个智能体的观察;
- 通过包图的成对相似性建立实例节点的初始边。以属于当前区块的第 i i i个包 B i \mathcal{B}_i Bi为例,它的包图 G i \mathcal{G}_i Gi可以被抽象为一个邻接矩阵 A i \mathbf{A}_i Ai以及一个特征矩阵 F i \mathbf{F}_i Fi;
- 给定初始矩阵
F
i
0
\mathbf{F}^0_i
Fi0,初始邻接矩阵
A
i
\mathbf{A}_i
Ai的计算如下:
A i ( j , j ′ ) = ∥ F i 0 ( j ) − F i 0 ( j ′ ) ∥ 2 (6) \tag{6} \mathbf{A}_i(j, j') = \|\mathbf{F}^0_i(j) - \mathbf{F}^0_i(j')\|_2 Ai(j,j′)=∥Fi0(j)−Fi0(j′)∥2(6)其中 ∥ ⋅ ∥ 2 \|\cdot\|_2 ∥⋅∥2表示矩阵的二范数、 A i ( j , j ′ ) \mathbf{A}_i(j, j') Ai(j,j′)编码了第 j j j个和第 j ′ j' j′个实例节点之间的欧式距离。
- 因此,第一个智能体的观察计算如下:
o 1 ( d ) = 1 N d ∑ i = 1 N d exp ( − A i ) s.t. M i = d , d ∈ [ 1 , max M i ] (7) \tag{7} \begin{aligned} &o_1(d) = \frac{1}{N_d} \sum_{i=1}^{N_d} \exp(-\mathbf{A}_i)\\ & \text{s.t. } M_i = d, \quad d \in [1, \max M_i] \end{aligned} o1(d)=Nd1i=1∑Ndexp(−Ai)s.t. Mi=d,d∈[1,maxMi](7)其中 o 1 ( d ) o_1(d) o1(d)表示向量 o 1 o_1 o1的第 d d d个条目、 N d N_d Nd是当前区块中包的数量,并且它包含的实例数量等于 d d d、 M i M_i Mi是包图 G i G_i Gi的实例节点数量;
- 由于GNN层数控制特征聚合的迭代,随后从初始节点特征
F
i
0
\mathbf{F}^0_i
Fi0中获取第二个智能体的观察:
o 2 = 1 N ∑ i = 1 N ( 1 M i ∑ j = 1 M i F i 0 ( j ) ) (8) \tag{8} o_2 = \frac{1}{N} \sum_{i=1}^{N} \left( \frac{1}{M_i} \sum_{j=1}^{M_i} F^0_i(j) \right) o2=N1i=1∑N(Mi1j=1∑MiFi0(j))(8)其中 F i 0 ( j ) \mathbf{F}^0_i(j) Fi0(j)是第 j j j个实例节点的特征向量、 N N N是当前区块中包图的总数;
- 为了进一步探索边密度和聚合迭代之间的潜在相关性,引入了观察信息交互:
o 1 = o 1 ⊕ σ ( ( o 1 ⊕ o 2 ) ( o 2 ⊕ o 1 ) T o 1 ) o 2 = o 2 ⊕ σ ( ( o 1 ⊕ o 2 ) ( o 2 ⊕ o 1 ) T o 2 ) (9) \tag{9} \begin{aligned} &o_1 = o_1 \oplus \sigma((o_1 \oplus o_2)(o_2 \oplus o_1)^T {o_1})\\ &o_2 = o_2 \oplus \sigma((o_1 \oplus o_2)(o_2 \oplus o_1)^T {o_2}) \end{aligned} o1=o1⊕σ((o1⊕o2)(o2⊕o1)To1)o2=o2⊕σ((o1⊕o2)(o2⊕o1)To2)(9)其中 ⊕ ( ⋅ ) \oplus(\cdot) ⊕(⋅)是向量的连接操作。通过此操作,观察 o 1 o_1 o1和 o 2 o_2 o2具有相同的维度,并且都编码了来自对方的信息;
RGMIL减轻了由于观察的特征维度或信息量的变化可能导致的MG中的不公平博弈。此外,为了提高这部分的效率,RGMIL只为每个数据区块一次性计算并记录这些初始邻接矩阵和观察。
3.2 动作选择和指导
当输入当前的观察向量 o i o_i oi后,每个智能体将其映射为一个 Q Q Q值向量 π i ( o i ) ∈ R 1 × ∣ A i ∣ \pi_i(o_i) \in \mathbb{R}^{1 \times |\mathcal{A}_i|} πi(oi)∈R1×∣Ai∣,并基于最大的 Q Q Q值条目或随机选择一个动作 a i a_i ai (如公式5):
- 第一个阈值动作 a 1 ∈ [ 0 , 1 ] a_1 \in [0, 1] a1∈[0,1]是一个小数,而第二个层数动作 a 2 a_2 a2是一个整数;
- 在
a
1
a_1
a1的指导下,可以获得一个更可靠的邻接矩阵
A
i
\mathbf{A}_i
Ai:
A i ( j , j ′ ) = { 1 , if exp ( − A i ( j , j ′ ) ) ≥ a 1 0 , if exp ( − A i ( j , j ′ ) )

