最近在项目里用到了 LassoNet 网络用来建模,读了一下论文原文,查了一些资料,发现网上的相关内容并不多,写一篇文章记录一下我的理解,希望能帮助到一些还不是很明白的朋友。

为什么要设计 LassoNet?

神经网络的不可知性

深度学习算法训练出来的模型很像一个只有输入输出的黑箱,人的思维很难理解。

例如训练一个识别猫猫图片的神经网络,研究者们可以把大量有猫图片和没有猫的图片打上标签,输入网络里训练,最终出来的模型就是可以识别图片里是否有猫了。

要知道一个神经网络里有成千上万个“参数”,这些参数可能有些代表了图片里的一些特征,比如猫的胡子或者猫的耳朵,但是哪些参数究竟代表了哪些特征?哪些参数对图片的识别率影响最大?这些问题很难解答。

特征选择

LassoNet 可以解决这个问题。LassoNet 本质上来说就是在训练时引导神经网络尽量保留那些比较“有用”的神经元节点,其他不太有用的节点一律忽略,这样我们就能知道哪些“参数”比较有用,哪些参数没什么用,没用的参数剔除就好了。这其实就是一种 “特征选择” 的机制。意思就是“选择”了一些有用特征出来。

LassoNet 的特性 —— 跳跃连接

如何选择那些“有用”的节点呢?LassoNet 借鉴了 ResNet 的跳跃连接(Skip Connection)机制来实现。

原理很简单,就是在任意一种神经网络前加一层“输入网络”,这个输入网络的每一个节点和输出的神经元节点相连,这样就可以直接评估输入特征和输出之间的关系了。

在训练网络时,对这条“通路”,损失函数会施加一个数学上的“惩罚”,让算法在训练时尽量移除那些对结果不太有用的神经元节点。一旦移除这个节点,这个“特征”就整个被移除了。这也就是“特征选择”。

图中的箭头就表示“跳跃连接”

概念比较:ResNet 的跳跃连接

但是和 ResNet 的跳跃连接相比,还是有些不同的。ResNet 的跳跃连接是将残差加到目标节点上,但是 LassoNet 的跳跃连接是把网络中的第一层节点参数加到了损失函数上。有一些细微的不同。

为什么叫做 LassoNet?

从损失函数的公式看,非常像 Lasso 回归也就是 L1 正则化。估计也是这个原因,作者才会将这个网络命名为 LassoNet。

LassoNet 的损失函数:

\[\min_{\theta, W} L(\theta, W) + \lambda || \theta ||_1\]

其中 $\theta$ 表示就是与跳跃连接相关的参数,可以看到,损失函数希望这个参数越小越好。

附一下 L1 正则化的公式

\[\min_{w} \sum_{i=1}^{n}(y_i - w^Tx_i)^2 + \mid \mid w \mid \mid _1\]

是不是看起来很像?

使用效果的图示

在使用 LassoNet 之前的神经网络

使用 LassoNet 之后的神经网络,注意使用了跳跃连接之后,一些与输入特征相关的神经元被移除了。

结语

LassoNet 作为一个特征选择方法,可以应用在任何一个神经网络上,只需要在神经网络之前加一层跳跃连接的网络就可以了。

LassoNet 是一个非常有意思的网络,它借用了跳跃连接和 Lasso 正则化的概念,达到了不错的特征选择的效果,从选择出来的特征就可以看出,哪些特征对输出的结果影响更大。大道至简。

参考

LassoNet 的官方网站,上面可以下载到 Paper 原文,网站上还有一个 2 分钟的说明视频,非常好,建议去看一看:LassoNet: Neural Networks with Feature Sparsity