【AlphaFold3】网络架构篇(2)|Input Embedding 对输入进行特征嵌入
- 博主简介:努力学习的22级计算机科学与技术本科生一枚🌸
- 博主主页: @Yaoyao2024
- 往期回顾:【AlphaFold3】网络架构篇(1)|概览+预测算法
- 每日一言🌼: 去留无意,闲看庭前花开花落;宠辱不惊,漫随天外云卷云舒。——《幽窗小记》🌺
前言
在这篇讲解/笔记文章的开头,我想先聊聊Featurisation
和Input Embedding
分别是什么,有什么样子的关系和区别。
=理解Featurisation(特征化)和Input Embedding(输入嵌入)的关系与区别,可以类比成“准备食材” & “食材预处理”的过程——前者是把原始材料变成能用的“食材”,后者是把这些食材切成模型“好入口”的形状。
先明确两个概念的核心含义
-
Featurisation(特征化,支撑材料2.8节内容):
简单说,就是把模型无法直接“看懂”的原始数据(比如分子的化学结构、蛋白质序列、进化信息等),转换成模型能“识别”的结构化特征。
比如:- 对于一个分子,原始数据可能是“分子式C6H12O6”“原子坐标(x,y,z)”,特征化会把这些转换成“原子类型(碳/氢/氧)”“化学键类型”“参考构象的空间信息”等;
- 对于进化信息(MSA),原始数据是一堆同源序列,特征化会提炼出“profile(序列保守性分布)”“deletion_mean(缺失率均值)”等总结性特征。
核心:把原始数据“翻译”成模型能处理的“特征语言”,是数据从“原始状态”到“可用状态”的第一步。
-
Input Embedding(输入嵌入,3.1节内容) :
是在特征化之后,把已经提取好的结构化特征,进一步转换成模型内部能“计算”的 向量形式。
就像你已经有了“苹果(红色、圆形、50g)”“香蕉(黄色、长条形、100g)”这些特征,嵌入就是把这些特征转换成一串数字(比如苹果→[0.2, 0.5, 0.3],香蕉→[0.8, 0.1, 0.6]),让模型能通过向量运算“理解”它们的差异和关系。
核心:把结构化特征“压缩”成模型能运算的向量,是特征从“可用状态”到“可计算状态”的关键一步。
关系:“上游”与“下游”,前者是后者的前提
Input Embedding完全依赖于Featurisation的结果。
比如你提供的代码里,InputFeatureEmbedder函数处理的f*
(包括f_i^{restype}
残基类型、f_i^{profile}
进化谱、f_i^{deletion_mean}
缺失均值等),都是2.8节Featurisation步骤已经提取好的结构化特征。
没有Featurisation先把原始数据变成这些特征,Input Embedding就成了“无米之炊”——巧妇难为无米之炊,模型也无法对原始数据直接做嵌入。
区别:处理阶段和目标不同
维度 | Featurisation(特征化) | Input Embedding(输入嵌入) |
---|---|---|
处理阶段 | 数据预处理阶段(模型“之外”的准备) | 模型输入阶段(模型“之内”的转换) |
处理对象 | 原始数据(如分子结构、原始序列) | 特征化后的结构化特征(如残基类型、进化谱) |
目标 | 得到“模型能识别的特征” | 得到“模型能计算的向量” |
关键特点 | 保留原始数据的核心信息,结构化 | 把特征压缩成低维向量,加入语义关联 |
举个生活化的例子
假设你要教一个机器人“认识水果”:
- Featurisation:你先把“苹果”的原始信息(红色、圆的、咬起来脆)整理成结构化的描述(颜色=红,形状=圆,口感=脆)——这一步是特征化,把“苹果”这个原始概念变成机器人能“看懂”的属性列表。
- Input Embedding:机器人把这些属性(红、圆、脆)转换成一串数字(比如[0.1, 0.8, 0.3]),方便它通过计算比较不同水果的差异(比如和“草莓”的[0.1, 0.2, 0.9]对比)——这一步是嵌入,把属性变成机器人能“运算”的语言。
总结:特征化是“整理信息”,嵌入是“翻译信息为模型语言”,前者是后者的基础,两者共同完成“原始数据→模型可处理向量”的全过程。
一、Input Embedding
翻译:
用户提供的任何键(通过token_bonds特征)均在算法 1 (上一篇博客中讲解)中进行线性嵌入;其他用户输入的嵌入,以及 RDKit 参考构象的嵌入,将在下文描述。
讲解
输入嵌入是将特征化后的各类信息(键、残基类型、空间结构等)转换为模型可计算的向量的过程,分为两部分:
- 用户提供的键(token_bonds):直接通过线性变换嵌入(算法1),快速整合分子间的成键信息(如蛋白质与配体的结合键)。
- 其他输入+RDKit参考构象:通过更复杂的逻辑嵌入(3.1.1和3.1.2),重点处理残基类型、空间结构和位置关系。
1.1 Input Embedder
翻译:
残基类型、参考构象和MSA汇总特征(分布谱和缺失均值)的嵌入过程如算法2所示。参考构象通过AtomAttentionEncoder
(算法5)以置换不变的方式(permutation invariant way) 进行嵌入(即原子的排列顺序不影响最终嵌入结果)。
👉🏻算法2:构建初始一维嵌入
def InputFeatureEmbedder({f *}) :
# 嵌入每个原子的特征
1: {ai}, _, _, _ = AtomAttentionEncoder({f *}, ∅, ∅, ∅, catom = 128, catompair = 16, ctoken = 384)
# 拼接每个令牌的特征
2: si = concat(ai, f_i^{restype}, f_i^{profile}, f_i^{deletion_mean})
3: return {si}
讲解:
一句话:输入嵌入器——整合原子与序列特征
算法2展示了初始一维嵌入的构建过程,核心是“原子特征→令牌特征”的整合:
步骤 1:通过 AtomAttentionEncoder 生成原子级嵌入{ai}\left\{\mathbf{a}_i\right\}{ai}
- 作用:将一个令牌包含的所有原子的特征,通过注意力机制“汇总”为一个向量
ai
(每个令牌而非单个原子对应一个ai
)。-
例1:一个标准氨基酸令牌(如丙氨酸)包含Cα、C、N、O等多个原子,编码器会计算这些原子间的相互作用(如键长、角度),最终输出一个
ai
(128维),代表这个氨基酸的整体原子结构特征。 -
例2:一个配体的原子令牌(假设配体的每个原子都是独立令牌),编码器会处理该单个原子的特征(如元素类型、周围原子环境),输出一个
ai
(128维),代表这个原子的结构特征。
-
- 关联逻辑(具体看算法5,下一篇博客中):
- 先将每个原子的特征写成
catom=128
维的 “原子小卡片”; - 再计算每对原子的关系,写成
catompair=16
维的 “关系小卡片”; - 然后将一个令牌内的所有 “原子小卡片” 和 “关系小卡片”
汇总、融合
,最终浓缩成一份代表该令牌的 “原子档案”—— 即ai; - 这份 “档案” 的维度受
ctoken=384
约束,确保格式统一且高效。
- 先将每个原子的特征写成
步骤 2:拼接令牌的其他特征,生成{si}
- 拼接逻辑:将步骤1生成的令牌级原子嵌入
ai
,与该令牌的另外三类特征按“维度拼接”(类似把多段绳子接成一根):f_i^{restype}
:残基类型(如“丙氨酸”“腺嘌呤”的独热编码,32维);f_i^{profile}
:MSA进化谱(该令牌在同源序列中的保守性分布,32维);f_i^{deletion_mean}
:MSA缺失率均值(该令牌在同源序列中的缺失频率,1维)。
- 例:一个丙氨酸令牌的
si
=ai
(128维) + 丙氨酸类型(32维) + 进化保守性(32维) + 缺失率(1维) = 193维向量(维度相加)。
1.2 Relative position encoding 相对位置编码
与AlphaFold2和AlphaFold-Multimer类似,相对编码用于打破相同残基(aijrel_pos\mathbf{a}_{ij}^{\mathrm{rel}{\_\mathrm{pos}}}aijrel_pos)和链(aijrel_chain\mathbf{a}_{ij}^\text{rel\_chain}aijrel_chain)的对称性。在AlphaFold 3中,我们引入了一种相对令牌编码,适用于同一残基内的令牌(见算法3)。相对位置和令牌索引的取值范围被限制在[rmin,rmax][r_{\min},r_{\max}][rmin,rmax],其中rmax=32r_{max}=32rmax=32;相对链索引的取值范围被限制在[smin,smax][s_{\mathrm{min}},s_{\mathrm{max}}][smin,smax],其中smax=2s_{\max}=2smax=2。
🪧讲解:”对称性“
位置编码最早应该是在transformer解决NLP问题提出。比如一句话”我打你”,“你打我”,这两个句子中“你”和“我”这两个词分别的特征编码一样,但是位置不同,就产生很大的语义差别。如果不加以位置约束,则语义就不正确。在蛋白质结构领域,也是如此。一条氨基酸序列,如果同一个氨基酸,位置不太,则会影响蛋白质的折叠从而影响蛋白质的结构。所以进行相对位置编码是很有必要的。
👉🏻算法3 相对位置编码
p.s.在这段相对位置编码算法中,clip函数是一种数值截断函数,核心作用是将输入的数值限制在一个指定的范围内,避免数值过大或过小对编码结果产生干扰。具体来说:
clip
函数的数学表达
在算法中,clip(x, min_val, max_val)
的含义是:
- 如果
x
小于min_val
,则返回min_val
;- 如果
x
大于max_val
,则返回max_val
;- 如果
x
在[min_val, max_val]
范围内,则直接返回x
。
讲解:
步骤4-5:残基级相对位置编码(aijrel_posa_{i j}^{rel\_pos}aijrel_pos)
-
步骤4:计算残基索引差值 dijresidued_{i j}^{residue}dijresidue:
dijresidue={clip(firesidue_index−fjresidue_index+rmax,0,2⋅rmax)if bijsame_chain2⋅rmax+1elsed_{i j}^{residue }= \begin{cases} clip(f_{i}^{residue\_index }-f_{j}^{residue\_index }+ r_{max}, 0, 2 \cdot r_{max}) & \text{if } b_{i j}^{same\_chain } \\ 2 \cdot r_{max} + 1 & \text{else} \end{cases} dijresidue={clip(firesidue_index−fjresidue_index+rmax,0,2⋅rmax)2⋅rmax+1if bijsame_chainelse- 若
i
和j
在同一条链(bijsame_chain=Trueb_{i j}^{same\_chain }=Truebijsame_chain=True):计算两者残基索引的差值,加上r_max
后裁剪到[0, 2*r_max]
(避免差值过大,控制编码范围)。 - 若不在同一条链:用特殊值2∗rmax+12*r_{max} + 12∗rmax+1标记(表示“跨链残基”,无直接位置关系)。
- 若
-
步骤5:对dijresidued_{i j}^{residue}dijresidue做独热编码(
one_hot
):
aijrel_pos=one_hot(dijresidue,[0,...,2⋅rmax+1])a_{i j}^{rel\_pos }= one\_hot \left(d_{i j}^{residue },\left[0, ..., 2 \cdot r_{max }+1\right]\right)aijrel_pos=one_hot(dijresidue,[0,...,2⋅rmax+1])
独热向量长度为2∗r_max+22*r\_max + 22∗r_max+2(覆盖000到2∗r_max+12*r\_max + 12∗r_max+1的所有可能值),用于将离散的位置差转为模型可理解的特征。
步骤6-7:令牌级相对位置编码(aijrel_tokena_{i j}^{rel\_token}aijrel_token)
用于描述同一残基内不同token的位置差(“令牌”可理解为残基内的细分单元,如原子、特征点):
-
步骤6:计算令牌索引差值dijtokend_{i j}^{token}dijtoken:
dijtoken={clip(fitoken_index−fjtoken_index+rmax,0,2⋅rmax)if bijsame_chainand bijsame_residue2⋅rmax+1elsed_{i j}^{token }= \begin{cases} clip(f_{i}^{token\_index }-f_{j}^{token\_index }+ r_{max}, 0, 2 \cdot r_{max}) & \text{if } b_{i j}^{same\_chain } \text{ and } b_{i j}^{same\_residue } \\ 2 \cdot r_{max} + 1 & \text{else} \end{cases} dijtoken={clip(fitoken_index−fjtoken_index+rmax,0,2⋅rmax)2⋅rmax+1if bijsame_chain and bijsame_residueelse- 仅当
i
和j
在同一条链且同一残基(same_chain
且same_residue
)时,计算令牌索引差值并裁剪;否则用2*r_max + 1
标记(表示“非同一残基的令牌”)。
- 仅当
-
步骤7:对dijtokend_{i j}^{token}dijtoken做独热编码:
aijreltoken=one_hot(dijtoken,[0,...,2⋅rmax+1])a_{i j}^{rel_token }= one\_hot \left(d_{i j}^{token },\left[0, ..., 2 \cdot r_{max }+1\right]\right)aijreltoken=one_hot(dijtoken,[0,...,2⋅rmax+1])
作用同残基编码,将令牌间的相对位置转为特征向量。
步骤8-9:链级相对位置编码(aijrel_chaina_{i j}^{rel\_chain}aijrel_chain)
-
步骤8:计算链标识差值dijchaind_{i j}^{chain}dijchain:
dijchain={clip(fisym_id−fjsym_id+smax,0,2⋅smax)if not bijsame_chain2⋅smax+1elsed_{i j}^{chain }= \begin{cases} clip(f_{i}^{sym\_id }-f_{j}^{sym\_id }+ s_{max}, 0, 2 \cdot s_{max}) & \text{if not } b_{i j}^{same\_chain } \\ 2 \cdot s_{max} + 1 & \text{else} \end{cases} dijchain={clip(fisym_id−fjsym_id+smax,0,2⋅smax)2⋅smax+1if not bijsame_chainelse- 若
i
和j
不在同一条链(notbijsamechainnot b_{i j}^{same_chain}notbijsamechain):用sym_id
(链的对称标识,区分对称链)计算差值,加上smax
后裁剪到[0, 2*smax]
。 - 若在同一条链:用特殊值
2*smax + 1
标记(表示“同链”,无链间差)。
- 若
-
步骤9:对dijchaind_{i j}^{chain}dijchain做独热编码:
aijrel_chain=onehot(dijchain,[0,...,2⋅smax+1])a_{i j}^{rel\_chain }= one_hot \left(d_{i j}^{chain },\left[0, ..., 2 \cdot s_{max }+1\right]\right)aijrel_chain=onehot(dijchain,[0,...,2⋅smax+1])
编码不同链之间的相对关系(如对称链、异源链)。
步骤10-11:融合编码(pijp_{i j}pij)
将上述4类特征拼接后通过线性层映射为统一维度的编码:
- 步骤10:拼接aijrel_posa_{i j}^{rel\_pos}aijrel_pos(残基相对位置)、aijrel_tokena_{i j}^{rel\_token}aijrel_token(令牌相对位置)、bijsame_entityb_{i j}^{same\_entity}bijsame_entity(是否同实体,布尔值转为0/1向量)、aijrel_chaina_{i j}^{rel\_chain}aijrel_chain(链相对关系),再通过无偏置线性层(
LinearNoBias
)映射到cz
维度:
pij=LinearNoBias(concat([aijrelpos,aijreltoken,bijsameentity,aijrelchain])p_{i j} = LinearNoBias\left(concat\left([a_{i j}^{rel_pos }, a_{i j}^{rel_token }, b_{i j}^{same_entity }, a_{i j}^{rel_chain }\right]\right)pij=LinearNoBias(concat([aijrelpos,aijreltoken,bijsameentity,aijrelchain]) - 步骤11:返回所有pijp_{i j}pij,作为
i
与j
的综合相对位置编码。
总结:核心作用:打破对称性**
在分子/蛋白质结构中,很多元素(如同一残基的不同令牌、对称链的残基)在绝对特征上可能完全一致(即“对称”),导致模型无法区分它们的相对位置。该算法通过:
- 量化残基、令牌、链的相对差值,用独热编码标记“差异”;
- 区分“同链/跨链”“同残基/跨残基”“同实体/跨实体”等关系,为对称元素赋予独特的相对特征;
最终让模型能识别“看似相同但位置不同”的元素,提升对结构关系的建模能力。
👉🏻算法4 基于最近区间的独热编码
def one_hot(x, vbins) : # x为输入值,vbins为区间列表
1: p = 0 # 初始化独热向量(长度为区间数)
2: b = arg min(|x − vbins|) # 找到x最接近的区间索引
3: p_b = 1 # 对应索引的位置设为1,其余为0
4: return p
讲解:
独热编码是将“连续值”(如位置差3)转换为“离散向量”(如[0,0,0,1,0,…])的工具。算法4的逻辑很简单:找到输入值最接近的区间,将对应位置设为1。例如,位置差3最接近区间[3],则向量中第3位为1,其余为0。这让模型能更高效地学习离散的位置关系。
总结
输入嵌入的核心目标是:将分子的化学特征(原子类型、键)、序列特征(残基类型、MSA)和空间特征(相对位置)统一编码为向量,既保留关键信息,又让模型能通过计算处理。其中,“置换不变嵌入”确保配体等分子的原子顺序不影响结果,“相对位置编码”帮助模型建立空间感——这两步共同为后续的Pairformer和扩散模块提供了高质量的输入“语言”。