使用JAX加速研究

2020-12-05 08:05:56

DeepMind工程师通过构建工具,扩展算法并创建具有挑战性的虚拟和物理世界来训练和测试人工智能(AI)系统,从而加快了我们的研究速度。作为这项工作的一部分,我们不断评估新的机器学习库和框架。

最近,我们发现Google研究团队开发的机器学习框架JAX为越来越多的项目提供了很好的服务。 JAX在我们的工程哲学中引起了共鸣,并且在过去的一年中被我们的研究团体广泛采用。在这里,我们分享了使用JAX的经验,概述了为什么发现它对我们的AI研究有用,并概述了我们为支持各地研究人员而建立的生态系统。

JAX是为高性能数值计算(尤其是机器学习研究)而设计的Python库。其用于数字函数的API基于NumPy,它是科学计算中使用的函数的集合。 Python和NumPy都被广泛使用和熟悉,这使JAX变得简单,灵活且易于采用。

除了其NumPy API外,JAX还包括可扩展功能转换系统,可转换功能转换,有助于支持机器学习研究,其中包括:

差异化:基于梯度的优化是ML的基础。通过诸如grad,hessian,jacfwd和jacrev之类的函数转换,JAX本身支持任意数值函数的正向和反向模式自动微分。

向量化:在机器学习研究中,我们通常将单个功能应用于大量数据,例如计算批次中的损失或评估差异化私人学习的每个示例的梯度。 JAX通过vmap转换提供自动矢量化,简化了这种形式的编程。例如,研究人员在实施新算法时无需推理批处理。 JAX还通过相关的pmap转换支持大规模数据并行性,从而优雅地分发了对于单个加速器的内存而言太大的数据。

JIT编译:XLA用于在GPU和Cloud TPU加速器上实时(JIT)编译和执行JAX程序。 JIT编译与JAX的NumPy一致的API结合在一起,使以前在高性能计算方面没有经验的研究人员可以轻松地扩展到一个或多个加速器。

我们发现JAX使得能够对新颖的算法和体系结构进行快速实验,并且它现在已经成为我们许多近期出版物的基础。要了解更多信息,请考虑在格林尼治标准时间12月9日星期三7:00 pm在NeurIPS虚拟会议上加入我们的JAX圆桌会议。

支持最新的AI研究意味着在快速原型制作和快速迭代与以传统上与生产系统相关的规模部署实验的能力之间取得平衡。使这类项目特别具有挑战性的是,研究领域发展迅速且难以预测。在任何时候,一项新的研究突破都可能并且有规律地改变整个团队的发展轨迹和要求。在这个瞬息万变的环境中,我们的工程团队的核心职责是确保在一个研究项目中可以有效地重用所学到的经验教训和编写的代码。

一种被证明是成功的方法是模块化:我们将在每个研究项目中开发的最重要和最关键的构建基块提取到经过良好测试和有效的组件中。这使研究人员能够专注于他们的研究,同时还受益于代码重用,错误修复和我们的核心库所实现的算法成分的性能改进。我们还发现,重要的是要确保每个库都有明确定义的范围,并确保它们可互操作但独立。增量购买,即能够选择和选择功能而不被其他功能锁定的能力,对于为研究人员提供最大的灵活性并始终支持他们选择正确的工作工具至关重要。

开发JAX生态系统时需要考虑的其他因素包括确保与现有TensorFlow库(例如Sonnet和TRFL)的设计保持一致(如果可能)。我们还旨在构建(尽可能相关)与其基础数学尽可能匹配的组件,以进行自我描述并最大程度地减少从纸本到代码的脑力劳动。最后,我们选择开放我们的图书馆资源,以促进研究成果的共享,并鼓励广大社区探索JAX生态系统。

可组合函数转换的JAX编程模型会使处理有状态对象变得复杂,例如具有可训练参数的神经网络。 Haiku是一个神经网络库,允许用户使用熟悉的面向对象的编程模型,同时利用JAX的纯功能范例的强大功能和简单性。

Haiku已被DeepMind和Google的数百名研究人员积极使用,并且已经在多个外部项目(例如Coax,DeepChem,NumPyro)中得到采用。它建立在Sonnet API的基础上,Sonnet是我们在TensorFlow中基于模块的神经网络编程模型,我们的目标是使从Sonnet到Haiku的移植尽可能简单。

基于梯度的优化是ML的基础。 Optax提供了一个梯度转换库以及合成运算符(例如链),该运算符允许在单行代码中实现许多标准优化器(例如RMSProp或Adam)。

Optax的成分性质自然支持在定制优化器中重组相同的基本成分。此外,它还提供了许多用于随机梯度估计和二阶优化的工具。

许多Optax用户已采用Haiku,但根据我们的增量购买原则,支持将参数表示为JAX树结构的任何库(例如Elegy,Flax和Stax)。请参阅此处以获取有关这个丰富的JAX库生态系统的更多信息。

我们许多最成功的项目都位于深度学习与强化学习(RL)(也称为深度强化学习)的交集处。 RLax是一个库,为构建RL代理提供有用的构建基块。

RLax中的组件涵盖了广泛的算法和思想:TD学习,策略梯度,参与者批评,MAP,近端策略优化,非线性价值转换,通用价值函数和多种探索方法。

尽管提供了一些介绍性的示例代理,但RLax并不旨在用作构建和部署完整RL代理系统的框架。 Acme是基于RLax组件构建的功能齐全的代理框架的一个示例。

测试对于软件可靠性至关重要,研究代码也不例外。从研究实验中得出科学结论需要对代码的正确性充满信心。 Chex是测试工具的集合,图书馆作者使用这些工具来验证通用构件是否正确且健壮,并由最终用户检查其实验代码。

Chex提供了各种实用程序,包括可识别JAX的单元测试,JAX数据类型的属性声明,模拟和伪造以及多设备测试环境。 Chex用于DeepMind的JAX生态系统以及外部项目(例如Coax和MineRL)。

图神经网络(GNN)是令人兴奋的研究领域,具有许多有希望的应用程序。例如,请参阅我们最近在Google Maps中进行交通预测的工作以及在物理模拟方面的工作。 Jraph(发音为" giraffe")是一个轻量级的库,支持在JAX中使用GNN。

Jraph提供了用于图形的标准化数据结构,用于处理图形的一组实用程序以及一个' zoo&#39 ;;易于分叉和可扩展的图神经网络模型。其他主要功能包括:有效利用硬件加速器的GraphTuples批处理,通过填充和掩码对可变形状图的JIT编译支持以及在输入分区上定义的损失。像Optax和我们的其他库一样,Jraph对用户选择神经网络库没有任何限制。

我们的JAX生态系统正在不断发展,我们鼓励ML研究社区探索我们的图书馆以及JAX加速自身研究的潜力。

如果您发现DeepMind JAX生态系统对您的工作有用,请使用此引用(托管在GitHub上)。