ToB企服应用市场:ToB评测及商务社交产业平台

标题: Mamba干系环境安装通用教程(causal_conv1d及mamba-ssm) [打印本页]

作者: 曹旭辉    时间: 2024-11-26 23:31
标题: Mamba干系环境安装通用教程(causal_conv1d及mamba-ssm)
Ubuntu下安装Mamba干系模子环境,起首按照你要运行的代码创建环境,并安装相应的torch库
报错通常发生在 pip install causal-conv1d和pip install mamba-ssm时,特别是进入causal_conv1d或mamba-ssm目录运行python setup.py develop时,可能包括
1. cuda版本错误

起首 nvcc -V 看下环境中的cuda版本是否精确,需要在11.6以上,可以根据torch的后缀+cu118安装相应版本,在环境中运行以下命令:
  1. conda install -c "nvidia/label/cuda-11.8.0" cuda-nvcc
复制代码
假如需要安装其他版本的可以参考文末的链接
2. 编译报错如ERROR: Could not build wheels for mamba-ssm, which is required to install pyproject.toml-based projects

mamba-ssm可以换成别的如selective_scan/causal_conv1d,总之就是编译出错,而且在这一步要好久,由于是联网下载所以速率可能很慢
可以去下面两个网站中下载所需的whl版本并手动安装。
causal-conv1d:https://github.com/Dao-AILab/causal-conv1d/releases
mamba-ssm:https://github.com/state-spaces/mamba/releases
如何知道需要哪个版本的mamba-ssm和causal_conv1d呢?
下载完成后,进入下载完成的whl文件目录,再进入你的conda环境,pip install ***.whl文件名安装
3. bimamba等import错误:ImportError: cannot import name ‘bimamba_inner_fn’ from ‘mamba_ssm.ops.selective_scan_interface’

以及在from mamba_ssm.ops.selective_scan_interface import bimamba_inner_fn, bimamba_inner_ref出现红线
这是由于论文作者修改了mamba-ssm导致与原始的(pip install的大概从官网release里下载whl后安装的)不一样了
需要用作者自己写的mamba-ssm库更换原来中安装的:
复制作者的mamba-ssm文件夹(留意是子文件夹,假如有父文件夹也叫mamba-ssm,就复制里面的子文件夹mamba-ssm),进入目录anaconda3/envs/mamba(你的环境名)/lib/python3.9(你的py版本)/site-packages/粘贴更换
causal-conv1d库不一定要更换,假如test_causal_conv1d.py能运行正常就可以
4. causal_conv1d和Mamba-ssm版本不匹配:TypeError: causal_conv1d_fwd(): incompatible function arguments. The following argument types are supported: 1. (arg0: torch.Tensor, arg1: torch.Tensor, arg2: Optional[torch.Tensor], arg3: Optional[torch.Tensor], arg4: bool) -> torch.Tensor

找到作者给出的版本重新安装,重新安装前最好先pip uninstall原来的版本,再pip install新版本 :
  1. pip uninstall causal_conv1d
  2. pip uninsntall mamba-ssm
复制代码
  1. // 版本号请按照你要复现的代码来
  2. pip install causal_conv1d==1.0.0
  3. pip uninsntall mamba-ssm==1.0.1
复制代码
5. 其他衍生错误

衍生错误还可能包括以下这些,请根据以上4点问题举行排查:

验证是否安装成功

如何知道自己causal_conv1d和mamba-ssm版本是否安装成功,是否匹配?
  1. 举例:在vivim.py里加入
  2. if __name__ == "__main__" :
  3.         model = Vivim().cuda()
  4.         x = torch.randn(4,5,3,256,256).cuda()
  5.         y = model(x)
  6.         print(y.size())
复制代码
总结

参考链接


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。




欢迎光临 ToB企服应用市场:ToB评测及商务社交产业平台 (https://dis.qidao123.com/) Powered by Discuz! X3.4