Nx:具有多阶段编译的多维张量Elixir lib(CPU / GPU)

2021-02-18 00:46:17

Nx是Elixir的多维张量库,具有对CPU / GPU的多阶段编译。它的高级功能是:

类型化的多维张量,其中的张量可以是无符号整数(8、16、32、64的大小),有符号整数(8、16、32、64的大小),浮点数(32、64的大小)和脑浮点数(16的大小) );

命名张量,使开发人员可以为每个维度命名,从而使代码库更具可读性,并且不易出错。

自动分化,也称为autograd。梯度函数提供反向模式微分,可用于仿真,训练概率模型等;

张量后端,使主要的Nx API可以用于处理二进制张量,GPU支持的张量,稀疏矩阵等;

数值定义(称为defn)为多个目标(例如高度专业的CPU代码或GPU)提供张量操作的多阶段编译。使用您选择的编译器,编译可以提前进行(AOT)或及时进行(JIT)。

您可以在问题跟踪器中找到计划的增强功能。如果您需要一项特定功能来向前迈进,请随时告知我们并提供反馈。

Nx的吉祥物是Numbat,它是原产于澳大利亚南部的有袋动物。不幸的是,Numbat濒临灭绝,估计还不到1000。如果您喜欢这个项目,可以考虑捐赠Numbat保护工作,例如Numbat项目和澳大利亚野生动物保护组织。

对于Python开发人员而言,Nx当前主要从Numpy和JAX汲取灵感,但打包到一个统一的库中。

在Elixir的任何社区空间(例如Elixir论坛和elixir-lang.org侧边栏中列出的其他频道)中,也欢迎对Nx进行讨论。

为了使用Nx,您将需要安装Elixir。然后通过混合构建工具创建一个Elixir项目:

然后,您可以在mix.exs中添加Nx作为依赖项。目前,在我们开发第一个版本时,您将不得不使用Git依赖项:

iex> t = Nx。张量([[1,2],[3,4]] iex> Nx。除法(Nx.exp(t),Nx.sum(Nx.exp(t)))#Nx.Tensor< f64 [2] [2] [[0.03205860328008499,0.08714431874203257],[0.23688281808991013,0.6439142598879722]]>

默认情况下,Nx使用纯Elixir代码。由于Elixir是一种功能性且不变的语言,因此上述每个操作都会复制张量,这是非常低效的。

但是,Nx还带有称为defn的数字定义,这是为数字计算量身定制的Elixir的子集。例如,它会覆盖Elixir的默认运算符,因此它们可感知张量:

defn支持多个编译器后端,这些后端可以编译所述函数以在CPU或GPU中运行。例如,使用EXLA编译器,该编译器提供了对Google XLA的绑定:

调用softmax后,Nx.Defn将调用EXLA来生成针对张量类型和形状量身定制的即时且高度专业的代码编译版本。通过传递client::cuda或client::rocm,可以为GPU编译代码。作为参考,以下是使用一百万个随机浮点值的张量调用时上述函数的一些基准:

名称ips平均偏差中位数99th%xla gpu f32保持15308.14 0.0653 ms±29.01%0.0638 ms 0.0758 msxla gpu f64保持4550.59 0.22 ms±7.54%0.22 ms 0.33 msxla cpu f32 434.21 2.30 ms±7.54%2.26 ms 2.69 msxla gpu f32 398.45 2.51 ms±2.28%2.50 ms 2.69 msxla gpu f64 190.27 5.26 ms±2.16%5.23 ms 5.56 msxla cpu f64 168.25 5.94 ms±5.64%5.88 ms 7.35 mselixir f32 3.22 311.01 ms±1.88%309.69 ms 340.27 mselixir f64 3.11 321.70 ms±1.44% 322.10 ms 328.98 ms比较:xla gpu f32保持15308.14xla gpu f64保持4550.59-慢3.36x +0.154 msxla cpu f32 434.21-慢35.26x +2.24 msxla gpu f32 398.45-38.42x慢+2.44 msxla gpu f32 -2.464 xxla gpu f64 46。 5.19 msxla cpu f64 168.25-慢90.98x +5.88 mselixir f32 3.22-4760.93x慢+310.94 mselixir f64 3.11-4924.56x慢+321.63 ms

defn依靠一种称为多阶段编程的技术,该技术建立在Elixir功能和元编程功能的基础上:我们将Elixir代码转换为发出AST的代码,然后将其转换为在CPU / GPU上运行。最终,defn编译器是可插入的,这意味着开发人员可以为不同的张量编译器技术实现绑定并选择最合适的技术。

defn中支持Elixir的许多功能,例如管道运算符,别名,条件,模式匹配等。路线图上还包含其他功能,例如循环和就地更新。 defn还支持转换,该转换允许在运行时转换数字定义。通过grad函数自动进行区分是转换的一个示例。

根据Apache许可证2.0版(" License")获得许可,除非遵守该许可,否则不得使用此文件。您可以从http://www.apache获得该许可的副本。 org / licenses / LICENSE-2.0

除非适用法律要求或书面同意,否则根据许可协议分发的软件将按“原样”分发。 没有任何形式的保证或条件的基础,无论明示或暗示。请参见许可以了解用于管理许可下的许可和限制的特定语言。