我们提供安全,免费的手游软件下载!

安卓手机游戏下载_安卓手机软件下载_安卓手机应用免费下载-先锋下载

当前位置: 主页 > 软件教程 > 软件教程

变分信息瓶颈 (Variational Information Bottleneck) 公式推导

来源:网络 更新时间:2024-08-21 09:30:45

互信息

互信息用于表示两个随机变量 相互依赖 的程度。随机变量 \(X\) \(Y\) 的互信息定义为

其中 \(p(\boldsymbol{x}, \boldsymbol{y})\) 表示 \(X\) \(Y\) 的联合概率密度, \(p(\boldsymbol{x})\) \(p(\boldsymbol{y})\) 分别表示 \(X\) \(Y\) 的边缘概率密度。

互信息是一个非负的量,当且仅当 \(X\) \(Y\) 相互独立 时(此时 \(p(\boldsymbol{x}, \boldsymbol{y}) = p(\boldsymbol{x})p(\boldsymbol{y})\) 恒成立)取到最小值 \(0\)

在机器学习中,联合分布 \(p(\boldsymbol{x}, \boldsymbol{y})\) 通常是难以得到的,因此通常会用贝叶斯公式转换一下,使用以下两种形式的互信息:

信息瓶颈

令随机变量 \(X\) 表示输入数据, \(Z\) 表示编码后的特征, \(Y\) 表示标签。信息瓶颈 (Information Bottleneck) 理论认为,神经网络的优化存在两阶段性:

  1. 快速拟合阶段:增加 \(I(Z, X)\)
  2. 压缩阶段:减少 \(I(Z, X)\) 并增加 \(I(Z, Y)\)

上面这幅插图可视化了神经网络训练过程中互信息的变化轨迹,横轴表示特征与输入的互信息 \(I(Z, X)\) ,纵轴表示特征与标签的互信息 \(I(Z, Y)\) (图中用 \(T\) 表示特征),从紫色到黄色表示从 0 epoch 到 10000 epoch。从图中可见,随着训练的进行, \(I(Z, X)\) 有一个先增大再减小的过程。

插图出自论文 [1703.00810] Opening the Black Box of Deep Neural Networks via Information 。参考阅读: Anatomize Deep Learning with Information Theory | Lil'Log 。

那么能不能利用这个现象对神经网络的训练进行正则化呢,于是有人提出了变分信息瓶颈 (Variational Information Bottleneck, VIB) 方法,优化的目标为:

我们希望 \(Z\) 能尽量准确地预测 \(Y\) ,同时尽量地遗忘 \(X\) 中的信息。换句话说,我们希望 \(Z\) 遗忘 \(X\) 中的冗余信息,只保留那些对预测 \(Y\) 有用的信息 。这里的 最小化 \(I(Z, X; \boldsymbol{\theta})\) 起到了正则化的效果

遗憾的是,从高维数据中直接估计互信息是很困难的,变分信息瓶颈的解决思路是通过变分近似实现对互信息的估计。

最小化 I(Z, X)

使用如下形式的互信息 \(I(Z, X)\)

注意到这里需要 \(p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})\) ,一种比较方便的处理方法是像 VAE 那样使用 概率编码器 (probabilistic encoder) ,而不是传统的确定性编码器 (deterministic encoder),即 \(X \mapsto Z\) 是一个随机函数而不是传统的确定性函数。参考 VAE 中的做法,我们将 \(p(\boldsymbol{z}|\boldsymbol{x})\) 预定义为参数化的高斯分布,并用神经网络输出这个高斯分布的参数:

解决了 \(p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})\) ,接下来的问题是如何求解 \(p(\boldsymbol{z})\) 。可能会想到采样估计的办法,即蒙特卡洛 (Monte Carlo, MC) 估计:

但是论文作者并没有使用这种方法,可能是认为在这里用 MC 估计的方差太大了,需要大量采样才能估得准,效率太低。为了估计期望 \(\mathbb{E}_{(\boldsymbol{x}, \boldsymbol{z}) \sim p(\boldsymbol{x}, \boldsymbol{z})}\left[\log\frac{p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})}{p(\boldsymbol{z})}\right]\) ,就先要从 \(p(\boldsymbol{x})\) 中采样 \(\boldsymbol{x}\) ,然后从 \(p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})\) 中采样 \(\boldsymbol{z}\) 。更麻烦的是方括号内的函数值也无法直接解析求解,需要先采样估计出 \(p(\boldsymbol{z})\) 才能计算。采样估计的过程太多,估计的方差自然会大。

变分信息瓶颈,顾名思义,就是通过 变分近似 的方法来解决无法获得 \(p(\boldsymbol{z})\) 的问题。假如有一个形式已知的无参分布 \(q(\boldsymbol{z})\) ,它跟 \(p(\boldsymbol{z})\) 非常接近,那我们用这个 \(q(\boldsymbol{z})\) 替换掉公式里的 \(p(\boldsymbol{z})\) ,不就能近似地计算互信息 \(I(Z, X)\) 吗?这里不妨将 \(q(\boldsymbol{z})\) 定义为标准高斯分布,即 \(q(\boldsymbol{z}) := N(\boldsymbol{z}, \boldsymbol{0}, \boldsymbol{I})\)

接下来需要证明这种替换是有道理的,参考 VAE 中推导的经验,我们尝试用 \(q(\boldsymbol{z})\) 替换 \(p(\boldsymbol{z})\) ,并尝试把额外的部分凑出一个 KL:

对于第一项, \(p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})\) \(q(\boldsymbol{z})\) 都有解析式,因此方括号内的函数可以算出解析解。利用 \(p(\boldsymbol{x}, \boldsymbol{z}) = p(\boldsymbol{x})p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})\) ,可以把第一项写得好看些:

这个 \(R(Z, X; \boldsymbol{\theta}) := \mathbb{E}_{\boldsymbol{x} \sim p(\boldsymbol{x})}[\mathrm{KL}[p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta}) \parallel q(\boldsymbol{z})]]\) 常常被称为 rate,也就是率失真理论里的率。Rate 这一项是可以用 mini-batch 梯度下降来优化的,具体来说,从训练集中采样一批样本 \(\boldsymbol{x}_1, \ldots, \boldsymbol{x}_N\) ,最小化每个 \(\boldsymbol{x}_i\) \(\mathrm{KL}[p(\boldsymbol{z}|\boldsymbol{x}_i; \boldsymbol{\theta}) \parallel q(\boldsymbol{z})]\) 即可。由于两个分布都是高斯分布,因此这里的 KL 有解析解:

详细的推导过程可参考 从极大似然估计到变分自编码器 - VAE 公式推导 中“KL 散度的解析解”这一节。相比原来的形式,“写得好看”之后的好处在于:函数对 \(\boldsymbol{z}\) 的积分可以解析地求解,这样一来,用 MC 估计 \(R(Z, X; \boldsymbol{\theta})\) 时,只需要从 \(p(\boldsymbol{x})\) 中采样 \(\boldsymbol{x}\) ,无需再从 \(p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})\) 中采样 \(\boldsymbol{z}\) ,减少了采样带来的误差。

对于第二项,注意到期望方括号中的函数跟 \(\boldsymbol{x}\) 没关系,因此:

如果要详细证明一下的话就是:

因此这一项就是要凑的那个 KL 散度。由于得不到 \(p(\boldsymbol{z})\) 的解析式,KL 散度这一项无法被直接优化,它放在这里只是为了证明变分近似的合理性,详见下文。

综上所述,互信息 \(I(Z, X)\) 可以拆成两部分:

由 KL 散度的非负性可知,rate \(R\) 是互信息 \(I(Z, X; \boldsymbol{\theta})\) 的上界:

这正合我们意愿,因为我们想要最小化互信息 \(I(Z, X; \boldsymbol{\theta})\) ,所以我们可以通过 最小化它的上界 \(R(Z, X; \boldsymbol{\theta})\) 来间接地实现互信息的最小化,实现“曲线救国”。

最大化 I(Z, Y)

使用如下形式的互信息 \(I(Z, X)\)

标签的分布 \(p(\boldsymbol{y})\) 可能是无法知道的:如果 \(\boldsymbol{y}\) 是类别标签,那么离散型分布 \(p(\boldsymbol{y})\) 是比较容易求的;但如果 \(\boldsymbol{y}\) 是数值,连续型分布 \(p(\boldsymbol{y})\) 是比较难求的。不过难求的 \(p(\boldsymbol{y})\) 并不影响优化过程,因为

其中 \(\mathrm{H}(Y)\) 表示随机变量 \(Y\) 的信息熵 (entropy)。由于标签 \(Y\) 来自于数据集,不属于优化变量,因此 \(\mathrm{H}(Y)\) 是一个定值,不影响优化过程。

接下来要解决的是 \(p(\boldsymbol{y}|\boldsymbol{z}; \boldsymbol{\theta})\) 难求的问题。这里需要与前一节 \(p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})\) 的情况相区分, \(X\) 是数据集中的数据, \(Z\) 是可优化的特征,因此对于 \(X \mapsto Z\) 这个过程,我们可以任意指定 \(p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})\) 的形式, \(p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})\) 不是难求的。而 \(Y\) 是数据集中的数据,对于 \(Z \mapsto Y\) 这个过程, \(p(\boldsymbol{y}|\boldsymbol{z}; \boldsymbol{\theta})\) 的形式是客观上确定的,我们不能随意指定, \(p(\boldsymbol{y}|\boldsymbol{z})\) 是难求的。

可以用一个形式已知的分布 \(q(\boldsymbol{y}|\boldsymbol{z}; \boldsymbol{\phi})\) 来近似 \(p(\boldsymbol{y}|\boldsymbol{z}; \boldsymbol{\theta})\)

利用 KL 散度的非负性,可以得到互信息 \(I(Z, Y; \boldsymbol{\theta})\) 的一个下界 \(I_{\text{BA}}\) ,它被称为互信息的 Barber & Agakov 下界。

\(p(\boldsymbol{y}, \boldsymbol{z}) = \int_x p(\boldsymbol{x}, \boldsymbol{y}, \boldsymbol{z}) \mathrm{d}\boldsymbol{x}\) 可得

\(Y\) 是连续型数据(回归问题),则选择高斯分布模型作为近似分布 \(q(\boldsymbol{y}|\boldsymbol{z}; \boldsymbol{\phi})\) ,最大化 \(\log q(\boldsymbol{y}|\boldsymbol{z}; \boldsymbol{\phi})\) 对应最小化 MSE 损失。若 \(Y\) 是离散型数据(分类问题),则选择伯努利分布(二分类模型)或类别分布(多分类模型)模型作为近似分布 \(q(\boldsymbol{y}|\boldsymbol{z}; \boldsymbol{\phi})\) ,最大化 \(\log q(\boldsymbol{y}|\boldsymbol{z}; \boldsymbol{\phi})\) 对应最小化交叉熵损失。详细的推导过程可参考 从极大似然估计到变分自编码器 - VAE 公式推导 中“重构损失”这一节。

\(N\) 的意思是从数据集中采样 \(N\) 个训练数据 \((\boldsymbol{x}_1, \boldsymbol{y}_1), \ldots, (\boldsymbol{x}_N, \boldsymbol{y}_N)\) \(M\) 的意思是对于每个样本 \(\boldsymbol{x}_i\) ,从分布 \(p(\boldsymbol{z}|\boldsymbol{x}_i; \boldsymbol{\theta})\) 中采样 \(M\) 个特征 \(\boldsymbol{z}\) 来计算 \(M\) 次 MSE/交叉熵损失。

一些理解

总的来说,最大化 \(I(Z, Y)\) 对应最小化交叉熵损失,最小化 \(I(Z, X)\) 对应最小化 KL 散度正则项(即 rate \(R\) )。

变分信息瓶颈与普通判别模型的区别:

  1. 将普通判别模型中的确定性编码器 (deterministic encoder)改成了概率编码器 (probabilistic encoder),给定 \(\boldsymbol{x}\) ,普通判别模型会给出唯一的 \(\boldsymbol{z}\) ,而 VIB 的 \(\boldsymbol{z}\) 是从某个分布中采样得到的,是一个随机变量。
  2. 加入了一个 KL 散度正则项(即 rate \(R\) ),希望特征的后验分布 \(p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})\) 尽量接近标准高斯分布。

从这两点改进来看,变分信息瓶颈与 VAE 非常相似。

为什么最小化 KL 散度能作为正则项?为什么鼓励接近标准高斯分布是一种正则化效果?如果 KL 正则项为 0,则 \(p(\boldsymbol{z}|\boldsymbol{x}; \boldsymbol{\theta})\) 完全就是标准高斯分布,不包含任何关于样本 \(\boldsymbol{x}\) 的信息,即完全遗忘了 \(\boldsymbol{x}\) 的信息。当然了,这样的特征是不具备任何判别能力的,所以需要通过调节权重系数 \(\beta\) 以在遗忘和预测能力之间取得平衡。

此外,注意到

因此在最小化正则项 \(R(Z, X; \boldsymbol{\theta})\) 时,不仅是在最小化互信息 \(I(Z, X; \boldsymbol{\theta})\) ,而且在最小化 \(\mathrm{KL}[p(\boldsymbol{z}) \parallel N(\boldsymbol{0}, \boldsymbol{I})]\) ,使得特征 \(Z\) 的分布 \(p(\boldsymbol{z})\) 逐渐趋近于标准高斯分布。标准高斯分布有很多优良的性质,例如,它的各个维度是相互独立的,这就是在鼓励特征 \(Z\) 的各维度解耦。

参考资料

论文原文:Deep Variational Information Bottleneck

  • OpenReview (ICLR 2017)
  • arXiv

从变分编码、信息瓶颈到正态分布:论遗忘的重要性 - 科学空间

变分信息瓶颈(Variational Information Bottleneck) - Sphinx Garden

迁移学习:互信息的变分上下界 - orion-orion - 博客园 ; 迁移学习:互信息的变分上下界 - 猎户座的文章 - 知乎