【AI模型学习】ESM2
文章目录
- 1. 版本
- 2. 开始
- 2.1 安装
- 2.2 使用预训练模型
- 2.2.1 代码
- 2.2.2 讲解
- 2.2 结构预测
- 3. 任务类型总结
- 1. 蛋白质结构预测(ESMfold)
- 2. 特征嵌入提取(esm-extract)
- 3. 零镜头变体预测(ESM-1v/ESM-2)
- 4. 逆向折叠(ESM-IF1)
- 5. 宏基因组图谱数据(ESM Atlas)
- 6. 多序列比对分析(ESM-MSA-1b)
- 7. 生成式蛋白质设计(ESM-2)
1. 版本
ESM-2 一共有多个版本,主要区别在于:
层数(depth)、参数量(size)、推理速度和精度权衡。
这些版本都遵循相同的 Transformer 编码器架构,只是在大小和计算能力上有差异。
版本一览
模型名称(Hugging Face 名) | 层数 | 参数量 | 说明 |
---|---|---|---|
esm2_t6_8M_UR50D | 6 | 8M | 极小模型,适合快速原型 |
esm2_t12_35M_UR50D | 12 | 35M | 中等小型,推荐用于入门任务 |
esm2_t30_150M_UR50D | 30 | 150M | 中等模型,效果与效率平衡 |
esm2_t33_650M_UR50D | 33 | 650M | 较大模型,适合更复杂任务 |
esm2_t36_3B_UR50D | 36 | 3B | 超大模型,强大建模能力 |
esm2_t48_15B_UR50D | 48 | 15B | 最大模型,性能最强但最重 |
模型名一般遵循这个格式:
esm2_t<层数>_<参数量>_UR50D
t<层数>
:比如t12
表示有 12 层 Transformer 编码器。<参数量>
:大概的参数数量,比如35M
,3B
等。UR50D
:代表训练用的数据集(Uniref50),是去冗余后的蛋白质数据库。
需求场景 | 推荐模型 |
---|---|
快速测试 / 教学 / CPU 调用 | esm2_t6_8M 或 t12_35M |
标准下游任务建模 | esm2_t30_150M |
结构预测 / 蛋白功能预测 | esm2_t33_650M 及以上 |
追求最强性能 | esm2_t48_15B |
注意:3B 和 15B 模型通常需要 A100 等大显存 GPU,或者分布式推理支持。
2. 开始
2.1 安装
pip install fair-esm # latest release, OR:
pip install git+https://github.com/facebookresearch/esm.git # bleeding edge, current repo main branch
2.2 使用预训练模型
2.2.1 代码
我们直接拿官网的作为示例
import torch
import esm# 加载 ESM-2 模型
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
batch_converter = alphabet.get_batch_converter()
model.eval() # 设置为评估模式(关闭 dropout,确保结果可复现)# 准备数据(来自 ESMStructuralSplitDataset 数据集的前两个蛋白质序列)
data = [("protein1", "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"),("protein2", "KALTARQQEVFDLIRDHISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),("protein2 with mask","KALTARQQEVFDLIRD<mask>ISQTGMPPTRAEIAQRLGFRSPNAAEEHLKALARKGVIEIVSGASRGIRLLQEE"),("protein3", "K A <mask> I S Q"),
]
batch_labels, batch_strs, batch_tokens = batch_converter(data)
batch_lens = (batch_tokens != alphabet.padding_idx).sum(1)# 提取每个残基的表示(使用 CPU 推理)
with torch.no_grad():results = model(batch_tokens, repr_layers=[33], return_contacts=True)
token_representations = results["representations"][33]# 通过对每个序列的残基取平均,生成序列级别的表示
# 注意:token 0 是序列起始符号 <cls>,第一个氨基酸是 token 1
sequence_representations = []
for i, tokens_len in enumerate(batch_lens):sequence_representations.append(token_representations[i, 1 : tokens_len - 1].mean(0))# 查看模型中无监督注意力图生成的接触预测图
import matplotlib.pyplot as plt
for (_, seq), tokens_len, attention_contacts in zip(data, batch_lens, results["contacts"]):plt.matshow(attention_contacts[: tokens_len, : tokens_len])plt.title(seq)plt.show()
2.2.2 讲解
results
的结构总览:
results.keys()
# dict_keys(['logits', 'representations', 'attentions', 'contacts'])
其中包含了 4 个主要部分,分别是:
- logits:
torch.Size([4, 73, 33])
含义:每个位置的分类 logits,用于预测氨基酸(或 <mask>
的掩码预测)。
-
[4, 73, 33]
- 4:表示 batch 中有 4 个序列
- 73:是 padding 后的最大 token 长度
- 33:是氨基酸 vocabulary 的大小(包括
<mask>
、<pad>
等)
用途:
- 如果你用
<mask>
,这个张量可以用于做“突变打分”或“掩码填空” - 可以通过
torch.nn.functional.softmax(logits, dim=-1)
得到每个位置的预测概率分布
- representations:
dict
,键是层号,值是 embedding
你指定了 repr_layers=[33]
,所以返回了:
representations = {33: torch.Size([4, 73, 1280])
}
含义:第 33 层(即最后一层)输出的每个 token 的 embedding 表示。
-
维度
[4, 73, 1280]
:- 4:batch 大小
- 73:token 序列长度(含
<cls>
和<eos>
) - 1280:embedding 的维度(模型隐藏层大小)
用途:
- 提取每个残基的表示用于下游任务(分类、聚类、结构预测)
- 可以对 1~L 之间的向量取平均,生成序列级表示
- attentions:
torch.Size([4, 33, 20, 73, 73])
含义:每一层、每个头的 self-attention 权重
-
[4, 33, 20, 73, 73]
- 4:batch 中 4 条序列
- 33:ESM2 的 transformer 层数
- 20:每层的 attention head 数量
- 73 x 73:每个 head 的注意力矩阵
用途:
- 可视化每层每个头的注意力
- 为接触图预测提供基础(即下一个)
- contacts:
torch.Size([4, 71, 71])
含义:预测的残基接触图(非监督 attention 平均生成的)
- 维度
[4, 71, 71]
:对应于每个序列的残基之间的接触概率
为什么是 71 而不是 73?
因为 73 包含了<cls>
和<eos>
,它们会被自动排除,真正的残基只有 71 个。
用途:
- 可以直接作为结构接触预测图的初步结果
- 在结构预测任务中可用于辅助建模残基之间的关系
总结(表格版)
键 | 类型 | 维度 | 含义 |
---|---|---|---|
logits | Tensor | [B, L, V] | 每个 token 的分类 logits,用于 <mask> 推断 |
representations | dict | {层号: Tensor[B, L, D]} | 每层 token 表示,通常用最后一层 |
attentions | Tensor | [B, L, H, T, T] | 每层每个 head 的注意力矩阵 |
contacts | Tensor | [B, L', L'] | 每个序列的残基接触预测图(不含 /) |
2.2 结构预测
import torch
import esm# 这里中间会有几个库需要pip一下,看报错信息补好即可
# 不过好像。。这里面的问题有点小多。。。然后就是如果想做这一块的话,python版本放到3.9及以下
model = esm.pretrained.esmfold_v1()
model = model.eval().cuda()# 可选:取消注释以下语句以设置轴向注意力的块大小。这可以帮助降低显存占用。
# 块越小,显存占用越低,但计算速度可能会变慢。
# model.set_chunk_size(128)sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
# 多聚体预测时,可以用 ':' 分隔不同链with torch.no_grad():output = model.infer_pdb(sequence)with open("result.pdb", "w") as f:f.write(output)import biotite.structure.io as bsio
struct = bsio.load_structure("result.pdb", extra_fields=["b_factor"])
print(struct.b_factor.mean()) # 这将输出 pLDDT 分数的平均值
# 88.3
3. 任务类型总结
1. 蛋白质结构预测(ESMfold)
输入
- 单序列:氨基酸序列字符串(如
"MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
)。 - 多聚体:用
:
分隔链的序列(如"chainA:chainB"
)。 - 批量输入:FASTA 文件(含多条序列)。
输出
- 单序列结果:PDB 格式的蛋白质结构文件,包含原子坐标和 pLDDT 评分(预测置信度)。
# 示例代码输出 output = model.infer_pdb(sequence) # PDB 文本字符串
- 批量结果:指定目录下的多个 PDB 文件,文件名对应 FASTA 序列 ID。
2. 特征嵌入提取(esm-extract)
输入
- 模型:预训练模型名称(如
esm2_t33_650M_UR50D
)或本地模型路径。 - 序列数据:FASTA 文件(含一条或多条序列)。
- 参数:指定提取的层(
--repr_layers
)和输出类型(--include mean/per_tok/contacts
)。
输出
- ** per_tok 嵌入**:每个残基的特征向量(形状为
[seq_len, hidden_dim]
)。 - ** mean 嵌入**:序列全局平均特征(形状为
[hidden_dim]
)。 - 接触预测:注意力图导出的接触概率矩阵(形状为
[seq_len, seq_len]
)。 - 文件格式:每个序列对应一个
.pt
文件,存储为 PyTorch 张量。
3. 零镜头变体预测(ESM-1v/ESM-2)
输入
- 野生型序列:氨基酸序列字符串。
- 突变位点:如
"A123G"
(第123位丙氨酸突变为甘氨酸)。 - 批量输入:CSV 或 FASTA 文件,包含多组野生型-突变序列对。
输出
- 功能影响评分:突变对蛋白质功能的预测效应(如稳定性、活性变化)。
- 示例输出:对数概率或相对效应值,用于排序突变的有害性。
4. 逆向折叠(ESM-IF1)
4.1 序列设计(给定结构采样)
- 输入:
- 结构文件:PDB 或 mmCIF 文件(含主链坐标,如
5YH2.pdb
)。 - 链选择:指定目标链(如
--chain C
)。 - 参数:采样温度(
--temperature
,控制序列多样性)。
- 结构文件:PDB 或 mmCIF 文件(含主链坐标,如
- 输出:FASTA 文件,包含生成的氨基酸序列(如
sampled_sequences.fasta
)。
4.2 序列评分(给定结构评估)
- 输入:
- 结构文件:PDB/mmCIF 文件。
- 序列文件:FASTA 文件(含待评分序列)。
- 输出:CSV 文件,包含每条序列的平均对数似然值(如
5YH2_mutated_seqs_scores.csv
)。
5. 宏基因组图谱数据(ESM Atlas)
输入
- 查询方式:
- 序列搜索:FASTA 序列,通过 API 或 Foldseek 搜索相似结构。
- 结构搜索:PDB/mmCIF 文件或结构 ID,检索同源结构。
- 批量下载:通过官网提供的链接下载全量结构数据(如.tar.gz压缩包)。
输出
- 结构数据:PDB 文件或预计算的 ESM-2 嵌入(.npy 或 .pt 文件)。
- 搜索结果:匹配的结构列表,包含相似性分数和功能注释。
6. 多序列比对分析(ESM-MSA-1b)
输入
- MSA 数据:A3M 格式的多序列比对文件(含同源序列)。
- 模型输入:通过
esm.pretrained.esm_msa1b_t12_100M_UR50S()
加载模型。
输出
- MSA 特征:从比对中提取的进化保守性嵌入,用于增强结构预测或功能分析。
- 接触预测:结合 MSA 信息的残基接触图,精度高于单序列模型。
7. 生成式蛋白质设计(ESM-2)
输入
- 设计约束:自然语言描述(如“设计一个具有 ATP 结合位点的螺旋结构”)或编程指令(如示例中的蛋白质编程语言)。
输出
- 全新序列:满足特定功能或结构约束的氨基酸序列,可通过 ESMfold 验证结构合理性。
数据格式总结表
任务 | 输入格式 | 输出格式 | 关键工具/模型 |
---|---|---|---|
结构预测 | 单/多聚体序列、FASTA | PDB 文件、pLDDT 评分 | ESMfold、esm-fold |
特征提取 | FASTA、模型名称 | .pt 张量文件 | esm-extract |
变体预测 | 野生型/突变序列 | 功能影响评分 | ESM-1v、ESM-2 |
逆向折叠 | PDB/mmCIF、FASTA | FASTA、CSV 评分文件 | ESM-IF1、sample_sequences.py |
宏基因组搜索 | FASTA、结构文件 | 匹配结构列表、嵌入数据 | ESM Atlas API |
MSA 分析 | A3M 比对文件 | 接触图、进化特征 | ESM-MSA-1b |
生成式设计 | 自然语言/编程指令 | 氨基酸序列 | ESM-2 |