当前位置: 首页 > news >正文

【AI模型学习】ESM2

文章目录

  • 1. 版本
  • 2. 开始
    • 2.1 安装
    • 2.2 使用预训练模型
      • 2.2.1 代码
      • 2.2.2 讲解
    • 2.2 结构预测
  • 3. 任务类型总结
      • 1. 蛋白质结构预测(ESMfold)
      • 2. 特征嵌入提取(esm-extract)
      • 3. 零镜头变体预测(ESM-1v/ESM-2)
      • 4. 逆向折叠(ESM-IF1)
      • 5. 宏基因组图谱数据(ESM Atlas)
      • 6. 多序列比对分析(ESM-MSA-1b)
      • 7. 生成式蛋白质设计(ESM-2)

1. 版本

ESM-2 一共有多个版本,主要区别在于:
层数(depth)参数量(size)推理速度和精度权衡
这些版本都遵循相同的 Transformer 编码器架构,只是在大小和计算能力上有差异。

版本一览

模型名称(Hugging Face 名)层数参数量说明
esm2_t6_8M_UR50D68M极小模型,适合快速原型
esm2_t12_35M_UR50D1235M中等小型,推荐用于入门任务
esm2_t30_150M_UR50D30150M中等模型,效果与效率平衡
esm2_t33_650M_UR50D33650M较大模型,适合更复杂任务
esm2_t36_3B_UR50D363B超大模型,强大建模能力
esm2_t48_15B_UR50D4815B最大模型,性能最强但最重

模型名一般遵循这个格式:

esm2_t<层数>_<参数量>_UR50D
  • t<层数>:比如 t12 表示有 12 层 Transformer 编码器。
  • <参数量>:大概的参数数量,比如 35M, 3B 等。
  • UR50D:代表训练用的数据集(Uniref50),是去冗余后的蛋白质数据库。
需求场景推荐模型
快速测试 / 教学 / CPU 调用esm2_t6_8Mt12_35M
标准下游任务建模esm2_t30_150M
结构预测 / 蛋白功能预测esm2_t33_650M 及以上
追求最强性能esm2_t48_15B

注意:3B 和 15B 模型通常需要 A100 等大显存 GPU,或者分布式推理支持。


2. 开始

2.1 安装

pip install fair-esm  # latest release, OR:
pip install git+https://github.com/facebookresearch/esm.git  # bleeding edge, current repo main branch

2.2 使用预训练模型

2.2.1 代码

我们直接拿官网的作为示例

import torch
import esm# 加载 ESM-2 模型
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval()  # 设置为评估模式(关闭 dropout,确保结果可复现)# 准备数据(来自 ESMStructuralSplitDataset 数据集的前两个蛋白质序列)
data = [("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),("protein3",  "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)# 提取每个残基的表示(使用 CPU 推理)
with torch.no_grad():results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]# 通过对每个序列的残基取平均,生成序列级别的表示
# 注意:token 0 是序列起始符号 <cls>,第一个氨基酸是 token 1
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))# 查看模型中无监督注意力图生成的接触预测图
import matplotlib.pyplot as plt
for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):plt.matshow(attention_contacts[: tokens_len, : tokens_len])plt.title(seq)plt.show()

2.2.2 讲解

results 的结构总览:

results.keys()
# dict_keys(['logits', 'representations', 'attentions', 'contacts'])

其中包含了 4 个主要部分,分别是:

  1. logitstorch.Size([4, 73, 33])

含义:每个位置的分类 logits,用于预测氨基酸(或 <mask> 的掩码预测)。

  • [4, 73, 33]

    • 4:表示 batch 中有 4 个序列
    • 73:是 padding 后的最大 token 长度
    • 33:是氨基酸 vocabulary 的大小(包括 <mask><pad> 等)

用途

  • 如果你用 <mask>,这个张量可以用于做“突变打分”或“掩码填空”
  • 可以通过 torch.nn.functional.softmax(logits, dim=-1) 得到每个位置的预测概率分布
  1. representationsdict,键是层号,值是 embedding

你指定了 repr_layers=[33],所以返回了:

representations = {33: torch.Size([4, 73, 1280])
}

含义:第 33 层(即最后一层)输出的每个 token 的 embedding 表示。

  • 维度 [4, 73, 1280]

    • 4:batch 大小
    • 73:token 序列长度(含 <cls><eos>
    • 1280:embedding 的维度(模型隐藏层大小)

用途

  • 提取每个残基的表示用于下游任务(分类、聚类、结构预测)
  • 可以对 1~L 之间的向量取平均,生成序列级表示
  1. attentionstorch.Size([4, 33, 20, 73, 73])

含义:每一层、每个头的 self-attention 权重

  • [4, 33, 20, 73, 73]

    • 4:batch 中 4 条序列
    • 33:ESM2 的 transformer 层数
    • 20:每层的 attention head 数量
    • 73 x 73:每个 head 的注意力矩阵

用途

  • 可视化每层每个头的注意力
  • 为接触图预测提供基础(即下一个)
  1. contactstorch.Size([4, 71, 71])

含义:预测的残基接触图(非监督 attention 平均生成的)

  • 维度 [4, 71, 71]:对应于每个序列的残基之间的接触概率

为什么是 71 而不是 73?
因为 73 包含了 <cls><eos>,它们会被自动排除,真正的残基只有 71 个。

用途

  • 可以直接作为结构接触预测图的初步结果
  • 在结构预测任务中可用于辅助建模残基之间的关系

总结(表格版)

类型维度含义
logitsTensor[B, L, V]每个 token 的分类 logits,用于 <mask> 推断
representationsdict{层号: Tensor[B, L, D]}每层 token 表示,通常用最后一层
attentionsTensor[B, L, H, T, T]每层每个 head 的注意力矩阵
contactsTensor[B, L', L']每个序列的残基接触预测图(不含 /)

2.2 结构预测

import torch
import esm# 这里中间会有几个库需要pip一下,看报错信息补好即可
# 不过好像。。这里面的问题有点小多。。。然后就是如果想做这一块的话,python版本放到3.9及以下
model = esm.pretrained.esmfold_v1()
model = model.eval().cuda()# 可选:取消注释以下语句以设置轴向注意力的块大小。这可以帮助降低显存占用。
# 块越小,显存占用越低,但计算速度可能会变慢。
# model.set_chunk_size(128)sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
# 多聚体预测时,可以用 ':' 分隔不同链with torch.no_grad():output = model.infer_pdb(sequence)with open("result.pdb", "w") as f:f.write(output)import biotite.structure.io as bsio
struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])
print(struct.b_factor.mean())  # 这将输出 pLDDT 分数的平均值
# 88.3

3. 任务类型总结

1. 蛋白质结构预测(ESMfold)

输入

  • 单序列:氨基酸序列字符串(如 "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG")。
  • 多聚体:用 : 分隔链的序列(如 "chainA:chainB")。
  • 批量输入:FASTA 文件(含多条序列)。

输出

  • 单序列结果:PDB 格式的蛋白质结构文件,包含原子坐标和 pLDDT 评分(预测置信度)。
    # 示例代码输出  
    output = model.infer_pdb(sequence)  # PDB 文本字符串  
    
  • 批量结果:指定目录下的多个 PDB 文件,文件名对应 FASTA 序列 ID。

2. 特征嵌入提取(esm-extract)

输入

  • 模型:预训练模型名称(如 esm2_t33_650M_UR50D)或本地模型路径。
  • 序列数据:FASTA 文件(含一条或多条序列)。
  • 参数:指定提取的层(--repr_layers)和输出类型(--include mean/per_tok/contacts)。

输出

  • ** per_tok 嵌入**:每个残基的特征向量(形状为 [seq_len, hidden_dim])。
  • ** mean 嵌入**:序列全局平均特征(形状为 [hidden_dim])。
  • 接触预测:注意力图导出的接触概率矩阵(形状为 [seq_len, seq_len])。
  • 文件格式:每个序列对应一个 .pt 文件,存储为 PyTorch 张量。

3. 零镜头变体预测(ESM-1v/ESM-2)

输入

  • 野生型序列:氨基酸序列字符串。
  • 突变位点:如 "A123G"(第123位丙氨酸突变为甘氨酸)。
  • 批量输入:CSV 或 FASTA 文件,包含多组野生型-突变序列对。

输出

  • 功能影响评分:突变对蛋白质功能的预测效应(如稳定性、活性变化)。
  • 示例输出:对数概率或相对效应值,用于排序突变的有害性。

4. 逆向折叠(ESM-IF1)

4.1 序列设计(给定结构采样)

  • 输入
    • 结构文件:PDB 或 mmCIF 文件(含主链坐标,如 5YH2.pdb)。
    • 链选择:指定目标链(如 --chain C)。
    • 参数:采样温度(--temperature,控制序列多样性)。
  • 输出:FASTA 文件,包含生成的氨基酸序列(如 sampled_sequences.fasta)。

4.2 序列评分(给定结构评估)

  • 输入
    • 结构文件:PDB/mmCIF 文件。
    • 序列文件:FASTA 文件(含待评分序列)。
  • 输出:CSV 文件,包含每条序列的平均对数似然值(如 5YH2_mutated_seqs_scores.csv)。

5. 宏基因组图谱数据(ESM Atlas)

输入

  • 查询方式
    • 序列搜索:FASTA 序列,通过 API 或 Foldseek 搜索相似结构。
    • 结构搜索:PDB/mmCIF 文件或结构 ID,检索同源结构。
  • 批量下载:通过官网提供的链接下载全量结构数据(如.tar.gz压缩包)。

输出

  • 结构数据:PDB 文件或预计算的 ESM-2 嵌入(.npy 或 .pt 文件)。
  • 搜索结果:匹配的结构列表,包含相似性分数和功能注释。

6. 多序列比对分析(ESM-MSA-1b)

输入

  • MSA 数据:A3M 格式的多序列比对文件(含同源序列)。
  • 模型输入:通过 esm.pretrained.esm_msa1b_t12_100M_UR50S() 加载模型。

输出

  • MSA 特征:从比对中提取的进化保守性嵌入,用于增强结构预测或功能分析。
  • 接触预测:结合 MSA 信息的残基接触图,精度高于单序列模型。

7. 生成式蛋白质设计(ESM-2)

输入

  • 设计约束:自然语言描述(如“设计一个具有 ATP 结合位点的螺旋结构”)或编程指令(如示例中的蛋白质编程语言)。

输出

  • 全新序列:满足特定功能或结构约束的氨基酸序列,可通过 ESMfold 验证结构合理性。

数据格式总结表

任务输入格式输出格式关键工具/模型
结构预测单/多聚体序列、FASTAPDB 文件、pLDDT 评分ESMfold、esm-fold
特征提取FASTA、模型名称.pt 张量文件esm-extract
变体预测野生型/突变序列功能影响评分ESM-1v、ESM-2
逆向折叠PDB/mmCIF、FASTAFASTA、CSV 评分文件ESM-IF1、sample_sequences.py
宏基因组搜索FASTA、结构文件匹配结构列表、嵌入数据ESM Atlas API
MSA 分析A3M 比对文件接触图、进化特征ESM-MSA-1b
生成式设计自然语言/编程指令氨基酸序列ESM-2
http://www.xdnf.cn/news/576307.html

相关文章:

  • 部署rsync远程同步+inotify监控
  • 前端学习(6)—— WebAPI部分案例
  • 前端面经-WebGL/threeJS
  • 《Saliency Attack: Towards Imperceptible Black-box Adversarial Attack》论文分享(侵删)
  • Spring AI 1.0 快速入门
  • NVIDIA GPU 性能调优与诊断完全指南
  • ConcurrentHashMap导致的死锁事故
  • 环境搭建
  • 根据Spring官方文档,三分钟完成Springboot项目集成Spring AI
  • sqli-labs第十七关——POST注入点
  • Spring Boot整合Redis
  • RestTemplate 发送的字段第二个大写字母变成小写的问题探究
  • 9-码蹄集600题基础python篇
  • leetcode 螺旋矩阵 java
  • 5-码蹄集600题基础python篇
  • 如何设计智慧工地系统的数据库?
  • 系统程序变更管理:确保IT环境稳定性和安全性的关键
  • Entity-Relationship Model(实体-关系模型)
  • FlashAttention:传统自注意力( Self-Attention)优化加速实现
  • 用户刷题记录日历——签到表功能实现
  • 基于 Guns v5.1 框架的分页教程
  • SseEmitter是什么
  • 卷积神经网络基础(十)
  • chrono类 根据duration 类的周期类型得到对应的周期名称
  • 预警功能深度测评:如何用系统降低设备突发故障率?
  • JavaScript常用事件
  • 第P10周:Pytorch实现车牌识别
  • 如何解决测试覆盖率与迭代速度的冲突问题?
  • 手搓四人麻将程序
  • 正大模型视角下的高频交易因子构建策略研究