当前位置: 首页 > news >正文

windows安装jax和jaxlib的教程(cuda)成功安装

本文你将解决3个问题:1、jaxlib没有安装的问题;2、python3.9以上(不可忽略)、cuda12.1(可忽略)以上配置要求不满足的问题;3、numpy版本太高的问题。

1、问题描述

当你直接pip install jax或者conda install jax后,执行以下代码检查是否错误:

import jax
print(jax.devices())  # 应输出类似 [gpu(id=0)]

总是会报错:ModuleNotFoundError: jax requires jaxlib to be installed. See https://github.com/google/jax#installation for installation instructions.

在这里插入图片描述

出现该问题的原因是没有安装jaxlib。jaxlib只支持python3.9以上版本,且需要手动安装(直接用pip install jaxlib会报错)

ERROR: Could not find a version that satisfies the requirement jaxlib (from versions: none)
ERROR: No matching distribution found for jaxlib

2、解决办法

下面有2种情况,按照你的Windows电脑是否需要cuda来选择对应的教程。

  • 情况1,你不需要GPU加速,即不用显卡cuda,那么只需要执行以下2步:

1、在虚拟环境中,在python3.9及以上的版本安装jax库,如 pip install jax 或者conda install jax,可以指定版本,这些就和一般的安装库那样。
2、下载jaxlib的文件,并手动安装。在https://storage.googleapis.com/jax-releases/jax_releases.html 地址中,键盘快捷键"ctrl + F"搜索"win" 找到对应python版本的jaxlib文件,jaxlib的版本自行测试吧。将其下载在本地任意文件夹中,然后像一般安装那样,在你的虚拟环境中安装此文件。

在这里插入图片描述

  • 情况2,你需要GPU加速,并且有自己的显卡cuda,而且已经配置了一个cuda11(或者以下的版本;如果你是cuda12及以上的版本,同样按照下面第2个步骤执行),那么只需要执行以下2步:

1、先安装cuda12(12.1以上的版本,必要的操作,不能跳过;无需卸载之前的cuda版本,多个版本的cuda可以共存),具体教程见以下两个教程(如果链接失效,请到我的csdn主页查找同名教程):
a. cuda 安装两个版本 https://blog.csdn.net/AdamCY888/article/details/147516608
b. 驱动支持的最高CUDA版本与实际安装的Runtime版本 https://blog.csdn.net/AdamCY888/article/details/147516543


在这里插入图片描述


(截图来自jax教程:https://jax.net.cn/en/latest/installation.html#installation)

2、上面步骤1确保你已经有一个12.1以上版本的cuda。

a. 下载jax:pip install -U "jax[cuda12]", 注意,引号不能省略,且建议不指定其jax版本。
b. 接下来同前面情况1的步骤2一样,下载jaxlibwhl文件。自行对应相应的版本。

在这里插入图片描述

3、测试jax对应jaxlib的版本

由于并没有找到jax对应jaxlib的版本,于是就安装一个最低版本的jaxlib 0.4.13,按照其报错提示,来得到满足的版本。正确的对应关系是:jax 0.4.21 对应的 jaxlib 0.4.19;如果安装的其它版本,也可以通过这个方法来解决。

RuntimeError: jaxlib is version 0.4.13, but this version of jax requires version >= 0.4.19.

在这里插入图片描述
于是,重新在 https://storage.googleapis.com/jax-releases/jax_releases.html 下载"jaxlib 0.4.19",并安装。

在这里插入图片描述

接下来进一步测试以下程序:

import jax.numpy as jnp
def selu(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)x = jnp.arange(5.0)
print(selu(x))

报错:

A module that was compiled using NumPy 1.x cannot be run in
NumPy 2.2.5 as it may crash. To support both 1.x and 2.x
versions of NumPy, modules must be compiled with NumPy 2.0.
Some module may need to rebuild instead e.g. with 'pybind11>=2.12'.If you are a user of the module, the easiest solution will be to
downgrade to 'numpy<2' or try to upgrade the affected module.
We expect that some modules will need time to support NumPy 2.Traceback (most recent call last):  File "d:\Anaconda\envs\jax_cuda12\lib\runpy.py", line 196, in _run_module_as_mainreturn _run_code(code, main_globals, None,File "d:\Anaconda\envs\jax_cuda12\lib\runpy.py", line 86, in _run_codeexec(code, run_globals)...

报错的原因是NumPy版本太高,需要降低版本。执行以下代码即可解决:

# 在虚拟环境中执行
conda activate jax_cuda12
pip uninstall numpy -y
pip install numpy==1.24.4  # 选择广泛兼容的1.x版本

4、安装成功!

import jax
print(jax.devices())  # 应输出类似 [gpu(id=0)]import jax.numpy as jnp

在这里插入图片描述

那么,接下来,请享受你的加速计算吧。

import jax.numpy as jnp
def selu(x, alpha=1.67, lmbda=1.05):return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)x = jnp.arange(5.0)
print(selu(x))

在这里插入图片描述

联系我

如果你在Windows系统下安装jax过程中,有任何困难,请留言或者私信,我将定期回复。

  • jax备忘录 https://blog.csdn.net/AdamCY888/article/details/147402803
http://www.xdnf.cn/news/147853.html

相关文章:

  • C++进阶----多态
  • 这些项目可以在以后年度结转扣除!
  • 从 0 开始认识 WebSocket:前端实时通信的利器!
  • 腾讯云系统盘占满
  • Node.js 应用场景
  • AIGC实战之如何构建出更好的大模型RAG系统
  • B站C语言课程笔记2
  • SD-WAN:企业网络架构的智能化革命
  • 蓝牙GATT协议
  • OAuth2AuthorizationEndpointFilter类介绍、应用场景和示例代码
  • 【北京迅为】iTOP-4412精英版使用手册-第二章 开发板初体验
  • 非序列实现MEMS聚焦功能
  • 【软件设计师】模拟题三
  • 如何将 Apache Hudi 接入 Ambari?完整部署与验证指南
  • 《深入理解计算机系统》阅读笔记之第十一章 网络编程
  • 100个用户的聊天系统:轮询 vs WebSocket 综合对比
  • Android项目升级插件到kotlin 2.1.0后混淆网络请求异常
  • “IAmMusicFont.com“:将音乐变成视觉
  • 内联函数(c++)
  • 信奥赛之c++基础(计算机存储+数据类型转换)
  • Android中的多线程
  • java.lang.ArrayIndexOutOfBoundsException: 11
  • BFD会话
  • 【蓝桥杯】P12165 [蓝桥杯 2025 省 C/Java A] 最短距离
  • 【2025 最新前沿 MCP 教程 01】模型上下文协议:AI 领域的 USB-C
  • 数据库证书可以选OCP认证吗?
  • Redis的主从模式和哨兵模式
  • 文档驱动:“提纲挈领”视角下的项目管理中枢构建
  • 《深入理解计算机系统》阅读笔记之第四章 处理器体系结构
  • 乐视系列玩机------乐视系列机型mtk芯片 乐视x620 x600 x501 pro3 双摄x650等改写参数 步骤解析