互联网资讯 / 人工智能 · 2023年12月18日

JAX与TensorFlow、PyTorch的对比分析

JAX与TensorFlow、PyTorch的对比分析

在机器学习的领域中,TensorFlow 和 PyTorch 已经成为了大家耳熟能详的框架。然而,除了这两个老牌框架之外,谷歌推出的 JAX 也逐渐受到关注,许多研究者寄希望于它能够取代 TensorFlow 等其他机器学习框架。

JAX 项目最初由谷歌大脑团队的 Matt Johnson、Roy Frostig、Dougal Maclaurin 和 Chris Leary 等人共同创建。

目前,JAX 在 GitHub 上已经获得了 13.7K 的星标。

Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch

项目地址:https://Github.com/Google/jax

JAX 的迅速发展

JAX 最早是基于 Autograd 开发的,它结合了更新版本的 Autograd 和 XLA,能够对 Python 程序及 NumPy 运算进行自动微分,支持循环、分支、递归和闭包函数的求导,甚至可以计算三阶导数。依赖于 XLA,JAX 可以在 GPU 和 TPU 上编译和运行 NumPy 程序,通过 grad 函数,它支持自动的反向传播和正向传播,且两者可任意组合。

Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch

开发 JAX 的初衷是什么?这要从 NumPy 说起。NumPy 是 Python 中一个基础的数值计算库,虽然广泛使用,但不支持 GPU 或其他硬件加速,并且没有内置的反向传播支持。此外,Python 本身的速度限制使得很少有研究者在生产环境中直接使用 NumPy 进行深度学习模型的训练或部署。

因此,许多深度学习框架应运而生,如 PyTorch 和 TensorFlow 等。但 NumPy 具有灵活性、易调试和稳定的 API 等独特优势。JAX 的目标就是将这些优势与硬件加速结合。

目前,基于 JAX 的开源项目也不断涌现。例如,谷歌的神经网络库团队开发了 Haiku,这是一个专为 JAX 设计的深度学习代码库,用户可以在 JAX 上进行面向对象的开发。此外,还有 RLax,这是一个基于 JAX 的强化学习库,用户可以使用 RLax 构建和训练 Q-learning 模型。还有 JAXNet,一个基于 JAX 的深度学习库,用户只需一行代码即可定义计算图,并实现 GPU 加速。可以说,在过去几年中,JAX 促使深度学习研究迅速发展。

JAX 的安装

使用 JAX 的第一步是在 Python 环境或 Google Colab 中安装它,使用 pip 进行安装:

$ pip install –upgrade jax jaxlib

需要注意的是,上述安装方式仅支持在 CPU 上运行。如果希望在 GPU 上运行程序,首先需要安装 CUDA 和 cuDNN,然后运行以下命令(确保将 jaxlib 版本映射到 CUDA 版本):

$ pip install –upgrade jax jaxlib==0.1.61+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

现在可以将 JAX 和 NumPy 一起导入:

import jax import jax.numpy as jnp import numpy as np

JAX 的一些特性

使用 grad() 函数进行自动微分,这对于深度学习应用非常重要,方便进行反向传播。以下是一个简单的二次函数在点 1.0 上求导的示例:

from jax import grad def f(x): return 3*x**2 + 2*x + 5 def f_prime(x): return 6*x + 2 grad(f)(1.0) # DeviceArray(8., dtype=float32) f_prime(1.0) # 8.0

JIT(Just in Time):为了充分利用 XLA 的强大能力,代码需要编译到 XLA 内核中。这就是 JIT 的作用。用户可以使用 jax.jit() 函数或 @jax.jit 注释来实现。

from jax import jit x = np.random.rand(1000, 1000) y = jnp.array(x) def f(x): for _ in range(10): x = 0.5*x + 0.1*jnp.sin(x) return x g = jit(f) %timeit -n 5 -r 5 f(y).block_until_ready() # 5 loops, best of 5: 10.8 ms per loop %timeit -n 5 -r 5 g(y).block_until_ready() # 5 loops, best of 5: 341 µs per loop

pmap:自动将计算分配到所有可用设备,并处理它们之间的通信。通过 pmap,JAX 支持大规模数据并行,使单个处理器无法处理的大数据得以处理。要检查可用设备,可以运行 jax.devices():

from jax import pmap def f(x): return jnp.sin(x) + x**2 f(np.arange(4)) # DeviceArray([0. , 1.841471, 4.9092975, 9.14112 ], dtype=float32) pmap(f)(np.arange(4)) # ShardedDeviceArray([0. , 1.841471, 4.9092975, 9.14112 ], dtype=float32)

vmap:是一种函数转换,JAX 通过 vmap 提供了自动矢量化算法,大大简化了此类计算,研究人员在处理新算法时无需再考虑批量化的问题。示例如下:

from jax import vmap def f(x): return jnp.square(x) f(jnp.arange(10)) # DeviceArray([0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32) vmap(f)(jnp.arange(10)) # DeviceArray([0, 1, 4, 9, 16, 25, 36, 49, 64, 81], dtype=int32)

TensoRFlow vs PyTorch vs JAX

在深度学习领域,有几家大型公司推出的框架被广泛使用,如谷歌的 TensorFlow、Facebook 的 PyTorch、微软的 CNTK 和亚马逊 AWS 的 MXNet 等。

每种框架都有其优缺点,选择时需要根据自己的需求进行判断。

Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch

我们将以 Python 中的三大深度学习框架——TensorFlow、PyTorch 和 JAX 进行比较。尽管这些框架各有不同,但它们有两个共同点:

它们都是开源的,这意味着如果库中存在错误,用户可以在 GitHub 上提交问题(并进行修复),同时也可以在库中添加自己的功能;由于全局解释器锁,Python 在内部运行较慢,因此这些框架使用 C/C++ 作为后端处理计算和并行过程。

那么它们之间的不同之处在哪呢?以下是 TensorFlow、PyTorch 和 JAX 三个框架的比较表。

Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch

TensoRFlow

TensoRFlow 由谷歌开发,最初版本可追溯到 2015 年开源的 TensorFlow 0.1,经过多年的发展,已稳定下来,拥有强大的用户基础,成为最受欢迎的深度学习框架。但在使用过程中,用户也发现了 TensorFlow 的一些缺点,例如 API 稳定性不足和静态计算图编程复杂等问题。因此,在 TensorFlow 2.0 版本中,谷歌将 Keras 纳入其中,成为 tf.keras。

目前,TensorFlow 的主要特点包括:

这是一个友好的框架,高级 API Keras 提供了简便的模型层定义、损失函数和模型创建方式; TensorFlow 2.0 引入了 Eager Execution(动态图机制),使得库更加用户友好,且是对之前版本的重大升级; Keras 这种高级接口的缺点是,TensorFlow 抽象了许多底层机制(仅为了方便最终用户),使得研究人员在处理模型时自由度减少; TensorFlow 提供了 TensorBoard,这是 TensorFlow 的可视化工具包,可以帮助研究者可视化损失函数、模型图和模型分析等。

PyTorch

PyTorch(Python-Torch)是 Facebook 提供的机器学习库。在一年前,研究者在选择 TensoRFlow 还是 PyTorch 时,几乎没有争议,大多数人会选择 TensorFlow。然而,现如今,使用 PyTorch 的研究者数量日益增加。PyTorch 的一些重要特性包括:

Github1.3万星,迅猛发展的JAX对比TensorFlow、PyTorch
与 TensorFlow 不同,PyTorch 使用动态计算图,这意味着图是在运行时创建的,允许用户随时修改和查看图的内部结构; 除了提供用户友好的高级 API,PyTorch 还包括精心设计的低级 API,使得对机器学习模型的控制变得更加灵活。用户可以在训练期间检查和修改模型的前向和后向传播输出,这对于梯度裁剪和神经风格迁移非常有帮助; PyTorch 允许用户扩展代码,轻松添加新的损失函数和用户定义的层。PyTorch 的 Autograd 模块自动实现了深度学习算法中的反向传播求导,所有在 Tensor 上的操作,Autograd 都能自动提供微分,简化了手动计算导数的复杂性; PyTorch 对数据并行和 GPU 的使用有广泛支持; 相较于 TensorFlow,PyTorch 更加 Python 化,适合 Python 生态系统,用户可以使用 Python 类调试器工具来调试 PyTorch 代码。

JAX

JAX 是来自谷歌的一个相对较新的机器学习库,更像是一个 Autograd 库,能够区分原生的 Python 和 NumPy 代码。JAX 的一些特性主要包括:

如官方网站所述,JAX 能够执行 Python 与 NumPy 程序的可组合转换:包括向量化、JIT 编译为 GPU/TPU 等; 与 PyTorch 不同,JAX 在计算梯度时的方式也有所不同。在 PyTorch 中,图是在前向传递期间创建的,而梯度则在后向传递期间计算;而在 JAX 中,计算是通过静态计算图的方式进行的,这使得 JAX 在性能上具有一定的优势。