1 5行代码打造无限宽神经网络模型-德赢Vwin官网 网
0
  • 聊天消息
  • 系统消息
  • 评论与回复
登录后你可以
  • 下载海量资料
  • 学习在线课程
  • 观看技术视频
  • 写文章/发帖/加入社区
会员中心
创作中心

完善资料让更多小伙伴认识你,还能领取20积分哦,立即完善>

3天内不再提示

5行代码打造无限宽神经网络模型

倩倩 来源:量子位 2020-03-27 15:47 次阅读

只要网络足够宽,深度学习动态就能大大简化,并且更易于理解。

最近的许多研究结果表明,无限宽度的DNN会收敛成一类更为简单的模型,称为高斯过程(Gaussian processes)。

于是,复杂的现象可以被归结为简单的线性代数方程,以了解AI到底是怎样工作的。

所谓的无限宽度(infinite width),指的是完全连接层中的隐藏单元数,或卷积层中的通道数量有无穷多。

但是,问题来了:推导有限网络的无限宽度限制需要大量的数学知识,并且必须针对不同研究的体系结构分别进行计算。对工程技术水平的要求也很高。

谷歌最新开源的Neural Tangents,旨在解决这个问题,让研究人员能够轻松建立、训练无限宽神经网络

甚至只需要5行代码,就能够打造一个无限宽神经网络模型。

这一研究成果已经中了ICLR 2020。戳进文末Colab链接,即可在线试玩。

开箱即用,5行代码打造无限宽神经网络模型

Neural Tangents 是一个高级神经网络 API,可用于指定复杂、分层的神经网络,在 CPU/GPU/TPU 上开箱即用。

该库用 JAX编写,既可以构建有限宽度神经网络,亦可轻松创建和训练无限宽度神经网络。

有什么用呢?举个例子,你需要训练一个完全连接神经网络。通常,神经网络是随机初始化的,然后采用梯度下降进行训练。

研究人员通过对一组神经网络中不同成员的预测取均值,来提升模型的性能。另外,每个成员预测中的方差可以用来估计不确定性。

如此一来,就需要大量的计算预算。

但当神经网络变得无限宽时,网络集合就可以用高斯过程来描述,其均值和方差可以在整个训练过程中进行计算。

而使用 Neural Tangents ,仅需5行代码,就能完成对无限宽网络集合的构造和训练。

from neural_tangents import predict, staxinit_fn, apply_fn, kernel_fn = stax.serial( stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(), stax.Dense(2048, W_std=1.5, b_std=0.05), stax.Erf(), stax.Dense(1, W_std=1.5, b_std=0.05))y_mean, y_var = predict.gp_inference(kernel_fn, x_train, y_train, x_test, ‘ntk’, diag_reg=1e-4, compute_cov=True)

上图中,左图为训练过程中输出(f)随输入数据(x)的变化;右图为训练过程中的不确定性训练、测试损失。

将有限神经网络的集合训练和相同体系结构的无限宽度神经网络集合进行比较,研究人员发现,使用无限宽模型的精确推理,与使用梯度下降训练整体模型的结果之间,具有良好的一致性。

这说明了无限宽神经网络捕捉训练动态的能力。

不仅如此,常规神经网络可以解决的问题,Neural Tangents 构建的网络亦不在话下。

研究人员在 CIFAR-10 数据集的图像识别任务上比较了 3 种不同架构的无限宽神经网络。

可以看到,无限宽网络vwin 有限神经网络,遵循相似的性能层次结构,其全连接网络的性能比卷积网络差,而卷积网络的性能又比宽残余网络差。

但是,与常规训练不同,这些模型的学习动力在封闭形式下是易于控制的,也就是说,可以用前所未有的视角去观察其行为。

对于深入理解机器学习机制来说,该研究也提供了一种新思路。谷歌表示,这将有助于“打开机器学习的黑匣子”。

声明:本文内容及配图由入驻作者撰写或者入驻合作网站授权转载。文章观点仅代表作者本人,不代表德赢Vwin官网 网立场。文章及其配图仅供工程师学习之用,如有内容侵权或者其他违规问题,请联系本站处理。 举报投诉
  • 神经网络
    +关注

    关注

    42

    文章

    4771

    浏览量

    100708
  • 代码
    +关注

    关注

    30

    文章

    4779

    浏览量

    68517
  • 深度学习
    +关注

    关注

    73

    文章

    5500

    浏览量

    121109
收藏 人收藏

    评论

    相关推荐

    神经网络教程(李亚非)

    源程序  4.3 旅行商问题(TSP)的HNN求解  Hopfield模型求解TSP源程序  第5章 随机型神经网络  5.1 模拟退火算法  5.2 Boltzmann机  Boltzmann机
    发表于 03-20 11:32

    非局部神经网络打造未来神经网络基本组件

    最高的精度。由此表明非局部模块可以作为一种比较通用的基本组件,在设计深度神经网络时使用。实验及结果在这一节我们简单介绍论文中描述的实验及结果。 视频的基线模型是 ResNet-50 C2D。三维输出映射
    发表于 11-12 14:52

    如何构建神经网络

    原文链接:http://tecdat.cn/?p=5725 神经网络是一种基于现有数据创建预测的计算系统。如何构建神经网络神经网络包括:输入层:根据现有数据获取输入的层隐藏层:使用反向传播优化输入变量权重的层,以提高
    发表于 07-12 08:02

    卷积神经网络模型发展及应用

    的概率。Top-5 识别率指的是 CNN 模型预测出最大概率的前 5 个分 类里有正确类别的概率。2012 年,由 Alex Krizhevshy 提出的 AlexNet给卷 积神经网络
    发表于 08-02 10:39

    神经网络模型原理

    神经网络模型原理介绍说明。
    发表于 04-21 09:40 7次下载

    卷积神经网络模型有哪些?卷积神经网络包括哪几层内容?

    卷积神经网络模型有哪些?卷积神经网络包括哪几层内容? 卷积神经网络(Convolutional Neural Networks,CNN)是深度学习领域中最广泛应用的
    的头像 发表于 08-21 16:41 1916次阅读

    卷积神经网络模型原理 卷积神经网络模型结构

    卷积神经网络模型原理 卷积神经网络模型结构  卷积神经网络是一种深度学习神经网络,是在图像、语音
    的头像 发表于 08-21 16:41 1013次阅读

    卷积神经网络算法代码matlab

    卷积神经网络算法代码matlab 卷积神经网络(Convolutional Neural Network,CNN)是一种深度学习网络模型,其
    的头像 发表于 08-21 16:50 1203次阅读

    常见的卷积神经网络模型 典型的卷积神经网络模型

    常见的卷积神经网络模型 典型的卷积神经网络模型 卷积神经网络(Convolutional Neural Network, CNN)是深度学习
    的头像 发表于 08-21 17:11 2837次阅读

    cnn卷积神经网络模型 卷积神经网络预测模型 生成卷积神经网络模型

    cnn卷积神经网络模型 卷积神经网络预测模型 生成卷积神经网络模型  卷积
    的头像 发表于 08-21 17:11 1231次阅读

    卷积神经网络模型搭建

    卷积神经网络模型搭建 卷积神经网络模型是一种深度学习算法。它已经成为了计算机视觉和自然语言处理等各种领域的主流算法,具有很大的应用前景。本篇文章将详细介绍卷积
    的头像 发表于 08-21 17:11 951次阅读

    卷积神经网络模型的优缺点

    卷积神经网络模型的优缺点  卷积神经网络(Convolutional Neural Network,CNN)是一种从图像、视频、声音和一系列多维信号中进行学习的深度学习模型。它在计算机
    的头像 发表于 08-21 17:15 4412次阅读

    构建神经网络模型的常用方法 神经网络模型的常用算法介绍

    神经网络模型是一种通过模拟生物神经元间相互作用的方式实现信息处理和学习的计算机模型。它能够对输入数据进行分类、回归、预测和聚类等任务,已经广泛应用于计算机视觉、自然语言处理、语音处理等
    发表于 08-28 18:25 1025次阅读

    rnn是什么神经网络模型

    RNN(Recurrent Neural Network,循环神经网络)是一种具有循环结构的神经网络模型,它能够处理序列数据,并对序列中的元素进行建模。RNN在自然语言处理、语音识别、时间序列预测等
    的头像 发表于 07-05 09:50 589次阅读

    神经网络预测模型的构建方法

    神经网络模型作为一种强大的预测工具,广泛应用于各种领域,如金融、医疗、交通等。本文将详细介绍神经网络预测模型的构建方法,包括模型设计、数据集
    的头像 发表于 07-05 17:41 641次阅读