MLXJAX框架学习
今天在阅读资料时碰到MLX和JAX两个框架,对此进行简单的总结学习。
(1)MLX
MLX
是 Apple 推出的一个 机器学习框架,全称是 Machine Learning eXchange。它是为 Apple 硬件(特别是 Apple Silicon 芯片,如 M1/M2/M3)优化的深度学习框架,具有高性能和高效能的特点。但是根据 Apple 官方文档,MLX 只支持以下环境:
条件 | 要求 |
---|---|
✅ 设备 | Apple Silicon(M1 / M2 / M3 芯片) |
✅ 系统 | macOS 13.0+(建议 macOS 14 Ventura 及以上) |
✅ Python | 3.9 ~ 3.12(推荐 Python 3.10) |
✅ pip | 必须升级 pip 到最新版本(pip>=23.1 ) |
✅ 安装方式 | 通过 pip install mlx 安装 |
✨ MLX 的主要特点
✅ 1. 统一内存模型
MLX 支持 CPU、GPU、Neural Engine 共享内存,无需手动在设备之间拷贝数据。这个机制提升了训练和推理效率。
✅ 2. 延迟执行(Lazy execution)
MLX 会延迟执行操作直到必须运行,比如在调用 .numpy()
或 .eval()
时才真正计算,类似 JAX 的行为。
✅ 3. NumPy风格的API
开发体验很像 NumPy 或 PyTorch,例如:
import mlx.core as mx
a = mx.array([[1, 2], [3, 4]])
b = a * 2 + 1
print(b)
✅ 4. 自动微分(Autograd)
内置自动求导功能,类似 PyTorch 的 autograd
。
✅ 5. 轻量、纯Python构建模型
模型结构用 Python 编写,非常直观。
from mlx.nn import Linear, Sequential, ReLU
model = Sequential(Linear(10, 20), ReLU(), Linear(20, 1))
⚙️ 安装方式
pip install mlx
但 只能在 macOS 且是 Apple Silicon(M1/M2/M3)芯片 上运行。
📚 应用场景
-
微调和部署小模型(如 LLM、ViT)
-
快速实验
-
在 Apple 硬件上高效运行 AI 应用(尤其移动设备)
🧪 示例:线性回归
import mlx.core as mx
from mlx.nn import Linear
import mlx.nn as nn
import mlx.optimi