ESMC-600M蛋白质语言模子本地部署攻略

打印 上一主题 下一主题

主题 983|帖子 983|积分 2951

马上注册,结交更多好友,享用更多功能,让你轻松玩转社区。

您需要 登录 才可以下载或查看,没有账号?立即注册

x
前言

之前介绍了ESMC-6B模子的网络接口调用方法,但申请token比较慢,有网友问能不能出一个本地部署ESMC小模子的攻略,遂有本文。
其实本地部署并不复杂,官方github上面也比较清楚了。
操作过程

情况设置:CUDA 12.1、torch 2.2.1+cu121、esm 3.1.1
完整的情况包列表:(因为做了些其他任务,这个里面其实不是全部都会用到,可以先把上面三个安装好,差哪些库再补哪些库)
  1. Package                  Version
  2. ------------------------ ------------
  3. asttokens                3.0.0
  4. attrs                    24.3.0
  5. biopython                1.84
  6. biotite                  0.41.2
  7. Brotli                   1.1.0
  8. certifi                  2024.12.14
  9. charset-normalizer       3.4.0
  10. cloudpathlib             0.20.0
  11. decorator                5.1.1
  12. einops                   0.8.0
  13. esm                      3.1.1
  14. executing                2.1.0
  15. filelock                 3.13.1
  16. fsspec                   2024.2.0
  17. huggingface-hub          0.27.0
  18. idna                     3.10
  19. ipython                  8.30.0
  20. jedi                     0.19.2
  21. Jinja2                   3.1.3
  22. joblib                   1.4.2
  23. MarkupSafe               2.1.5
  24. matplotlib-inline        0.1.7
  25. mpmath                   1.3.0
  26. msgpack                  1.1.0
  27. msgpack-numpy            0.4.8
  28. networkx                 3.2.1
  29. numpy                    1.26.3
  30. nvidia-cublas-cu12       12.1.3.1
  31. nvidia-cuda-cupti-cu12   12.1.105
  32. nvidia-cuda-nvrtc-cu12   12.1.105
  33. nvidia-cuda-runtime-cu12 12.1.105
  34. nvidia-cudnn-cu12        8.9.2.26
  35. nvidia-cufft-cu12        11.0.2.54
  36. nvidia-curand-cu12       10.3.2.106
  37. nvidia-cusolver-cu12     11.4.5.107
  38. nvidia-cusparse-cu12     12.1.0.106
  39. nvidia-nccl-cu12         2.19.3
  40. nvidia-nvjitlink-cu12    12.1.105
  41. nvidia-nvtx-cu12         12.1.105
  42. packaging                24.2
  43. pandas                   2.2.3
  44. parso                    0.8.4
  45. pexpect                  4.9.0
  46. pillow                   10.2.0
  47. pip                      24.2
  48. prompt_toolkit           3.0.48
  49. ptyprocess               0.7.0
  50. pure_eval                0.2.3
  51. Pygments                 2.18.0
  52. python-dateutil          2.9.0.post0
  53. pytz                     2024.2
  54. PyYAML                   6.0.2
  55. regex                    2024.11.6
  56. requests                 2.32.3
  57. safetensors              0.4.5
  58. scikit-learn             1.6.0
  59. scipy                    1.14.1
  60. setuptools               75.1.0
  61. six                      1.17.0
  62. stack-data               0.6.3
  63. sympy                    1.13.1
  64. tenacity                 9.0.0
  65. threadpoolctl            3.5.0
  66. tokenizers               0.20.3
  67. torch                    2.2.1+cu121
  68. torchdata                0.7.1
  69. torchtext                0.17.1
  70. torchvision              0.17.1+cu121
  71. tqdm                     4.67.1
  72. traitlets                5.14.3
  73. transformers             4.46.3
  74. triton                   2.2.0
  75. typing_extensions        4.9.0
  76. tzdata                   2024.2
  77. urllib3                  2.2.3
  78. wcwidth                  0.2.13
  79. wheel                    0.44.0
复制代码
下载ESMC-600m的权重:
EvolutionaryScale/esmc-600m-2024-12 at main
下载之后把权重放在工作目录下的这个地址:data/weights

代码

和官方github上给出的例子比较雷同,不过加了些修改。

  1. from esm.models.esmc import ESMC
  2. from esm.sdk.api import *
  3. import torch
  4. import os
  5. import pickle
  6. from esm.tokenization import EsmSequenceTokenizer
  7. # 使用预下载的参数
  8. os.environ["INFRA_PROVIDER"] = "True"
  9. device = torch.device("cuda:0")
  10. client = ESMC.from_pretrained("esmc_600m",device=device)
  11. # 读取蛋白质序列,这里需要根据自己的数据格式进行调整
  12. def read_seq(seqfilepath):
  13.     with open(seqfilepath,"r") as f:
  14.         line = f.readline()
  15.         seq = f.readline()
  16.     return seq
  17. # 这里沿用了上一次逆向出来的编码格式,可以替换为ESM自带的编码格式
  18. all_amino_acid_number = {'A':5, 'C':23,'D':13,'E':9, 'F':18,
  19.                          'G':6, 'H':21,'I':12,'K':15,'L':4,
  20.                          'M':20,'N':17,'P':14,'Q':16,'R':10,
  21.                          'S':8, 'T':11,'V':7, 'W':22,'Y':19,
  22.                          '_':32}
  23. def esm_encoder_seq(seq, pad_len):
  24.     s = [all_amino_acid_number[x] for x in seq]
  25.     while len(s)<pad_len:
  26.         s.append(1)
  27.     s.insert(0,0)
  28.     s.append(2)
  29.     return torch.tensor(s)
  30. def get_esm_embedding(seq):
  31.     protein_tensor = ESMProteinTensor(sequence=esm_encoder_seq(seq,len(seq)).to(device))
  32.     logits_output = client.logits(protein_tensor, LogitsConfig(sequence=True, return_embeddings=True))
  33.     esm_embedding = logits_output.embeddings
  34.     assert isinstance(esm_embedding,torch.Tensor)
  35.     return esm_embedding
  36. # 这个路径设置并不重要,可以自行调整
  37. seq_path = "seq.fasta"
  38. seq = read_seq(seq_path)
  39. print(seq)
  40. # 获取序列embedding
  41. seq_list = [seq]
  42. emb = get_esm_embedding(seq)
  43. with open("seq_emb.pkl","wb") as f:
  44.     pickle.dump(emb,f)
  45. print(emb.shape)
复制代码
 任意用了一个序列,得到的运行结果,tensor形状是[1,序列长度+2,1152]


免责声明:如果侵犯了您的权益,请联系站长,我们会及时删除侵权内容,谢谢合作!更多信息从访问主页:qidao123.com:ToB企服之家,中国第一个企服评测及商务社交产业平台。
回复

使用道具 举报

0 个回复

倒序浏览

快速回复

您需要登录后才可以回帖 登录 or 立即注册

本版积分规则

祗疼妳一个

金牌会员
这个人很懒什么都没写!
快速回复 返回顶部 返回列表