如何做表征对齐?
一、点对点软对齐(Micro-level)
核心思路:
使用相似度函数(如 cosine similarity)对 source 和 target 分布中的样本建立“点对点”匹配,通过 soft matching 实现特征对齐。
方法:
1. Token-wise Similarity Soft Alignment(REPA)
- 给定源域特征分布 F ( s ) = { f 1 ( s ) , … , f N ( s ) } F^{(s)} = \{f^{(s)}_1, \dots, f^{(s)}_N\} F(s)={f1(s),…,fN(s)} 与目标域特征分布 F ( t ) = { f 1 ( t ) , … , f N ( t ) } F^{(t)} = \{f^{(t)}_1, \dots, f^{(t)}_N\} F(t)={f1(t),…,fN(t)},每个 token 特征 f n ( s ) , f n ( t ) ∈ R d f^{(s)}_n, f^{(t)}_n \in \mathbb{R}^d fn(s),fn(t)∈Rd。
- 使用投影函数 h ϕ ( ⋅ ) h_\phi(\cdot) hϕ(⋅) 将 target 特征映射到 source 的特征空间以实现对齐。
- 定义 token 级的 soft alignment 损失如下:
L REPA ( ϕ ) = − E x ∼ D [ 1 N ∑ n = 1 N sim ( f n ( s ) , h ϕ ( f n ( t ) ) ) ] \mathcal{L}_{\text{REPA}}(\phi) = - \mathbb{E}_{x \sim \mathcal{D}} \left[ \frac{1}{N} \sum_{n=1}^{N} \text{sim} \left( f^{(s)}_n,\; h_\phi(f^{(t)}_n) \right) \right] LREPA(ϕ)=−Ex∼D[N1n=1∑Nsim(fn(s),hϕ(fn(t)))]
其中:
- f n ( s ) f^{(s)}_n fn(s) 表示第 n n n 个源域 token 特征;
- f n ( t ) f^{(t)}_n fn(t) 表示第 n n n 个目标域 token 特征;
- h ϕ h_\phi hϕ 为可学习的特征对齐函数;
- sim ( ⋅ , ⋅ ) \text{sim}(\cdot, \cdot) sim(⋅,⋅) 为相似度函数(如 cosine similarity);
- D \mathcal{D} D 表示样本采样分布(如训练数据、时间步等)。
作用: 该方法实现源域与目标域在 token 层级的“一对一”特征对齐,适用于同构分布或 token 数量一致的情况。
变种: Marginal Cosine Similarity Loss ( L mcos \mathcal{L}_{\text{mcos}} Lmcos)
- 对于 flatten 后的特征 x s , x t ∈ R N × d x^s, x^t \in \mathbb{R}^{N \times d} xs,xt∈RN×d( N = h × w N = h \times w N=h×w):
- 计算每个位置 ( i , j ) (i, j) (i,j) 处的余弦相似度并加 margin:
L mcos = 1 h × w ∑ i = 1 h ∑ j = 1 w ReLU ( 1 − m 1 − x i j s ⋅ x i j t ∥ x i j s ∥ ∥ x i j t ∥ ) \mathcal{L}_{\text{mcos}} = \frac{1}{h \times w} \sum_{i=1}^h \sum_{j=1}^w \text{ReLU}\left(1 - m_1 - \frac{x^s_{ij} \cdot x^t_{ij}}{\|x^s_{ij}\| \|x^t_{ij}\|} \right) Lmcos=h×w1i=1∑hj=1∑wReLU(1−m1−∥xijs∥∥xijt∥xijs⋅xijt)
作用:只惩罚相似度低于 margin( m 1 m_1 m1) 的点对,使低相似度点对更加对齐。
2. Soft Nearest Neighbor Matching
- 构建 source 到 target 的相似度矩阵 S i j = sim ( x i s , x j t ) S_{ij} = \text{sim}(x_i^s, x_j^t) Sij=sim(xis,xjt)。
- 使用 softmax 对相似度矩阵按行归一化,构建 soft correspondence。
- 对应的损失函数可表示为:
L p t p = ∑ i KL ( SoftSim ( x i s , X t ) ∥ SoftSim ( x i t , X s ) ) \mathcal{L}_{ptp} = \sum_{i} \text{KL}(\text{SoftSim}(x_i^s, X^t) \,\|\, \text{SoftSim}(x_i^t, X^s)) Lptp=i∑KL(SoftSim(xis,Xt)∥SoftSim(xit,Xs))
其中:
- SoftSim ( x i s , X t ) = softmax ( sim ( x i s , X t ) ) \text{SoftSim}(x_i^s, X^t) = \text{softmax}(\text{sim}(x_i^s, X^t)) SoftSim(xis,Xt)=softmax(sim(xis,Xt))
- sim \text{sim} sim 可为 cosine similarity 或 dot-product similarity
3. Contrastive / Triplet Loss
-
用于拉近相似点对,拉远不相似点对,适用于有监督或伪标签构造下的无监督场景。
-
Contrastive loss:
L contrastive = y ⋅ D 2 + ( 1 − y ) ⋅ max ( 0 , m − D ) 2 \mathcal{L}_{\text{contrastive}} = y \cdot D^2 + (1 - y) \cdot \max(0, m - D)^2 Lcontrastive=y⋅D2+(1−y)⋅max(0,m−D)2
- Triplet loss:
L triplet = max ( 0 , D ( x s , x t + ) − D ( x s , x t − ) + m ) \mathcal{L}_{\text{triplet}} = \max(0, D(x^s, x^{t+}) - D(x^s, x^{t-}) + m) Ltriplet=max(0,D(xs,xt+)−D(xs,xt−)+m)
二、结构一致性对齐(Macro-level)
核心思路:
对 source 和 target 特征内部的结构进行建模(如相似度图、manifold 结构),保持两者的一致性,从而实现“分布结构”的对齐。
方法:
1. Manifold Similarity Alignment
- 分别构造 source 和 target 的特征相似度矩阵 S s S^s Ss 与 S t S^t St
- 最小化它们的差异:
L s t r u c t u r e = ∥ S s − S t ∥ F 2 \mathcal{L}_{structure} = \| S^s - S^t \|_F^2 Lstructure=∥Ss−St∥F2
其中 S i j s = sim ( x i s , x j s ) S_{ij}^s = \text{sim}(x_i^s, x_j^s) Sijs=sim(xis,xjs), S i j t = sim ( x i t , x j t ) S_{ij}^t = \text{sim}(x_i^t, x_j^t) Sijt=sim(xit,xjt), ∥ ⋅ ∥ F \|\cdot\|_F ∥⋅∥F 表示 Frobenius 范数。
变种: Marginal Distance Matrix Similarity Loss ( L mdms \mathcal{L}_{\text{mdms}} Lmdms)
- 对于 flatten 后的特征 x s , x t ∈ R N × d x^s, x^t \in \mathbb{R}^{N \times d} xs,xt∈RN×d( N = h × w N = h \times w N=h×w):
- 对所有特征对 ( i , j ) (i, j) (i,j),对比其余弦相似度的差异:
L mdms = 1 N 2 ∑ i = 1 N ∑ j = 1 N ReLU ( ∣ x i s ⋅ x j s ∥ x i s ∥ ∥ x j s ∥ − x i t ⋅ x j t ∥ x i t ∥ ∥ x j t ∥ ∣ − m 2 ) \mathcal{L}_{\text{mdms}} = \frac{1}{N^2} \sum_{i=1}^N \sum_{j=1}^N \text{ReLU} \left( \left| \frac{x_i^s \cdot x_j^s}{\|x_i^s\| \|x_j^s\|} - \frac{x_i^t \cdot x_j^t}{\|x_i^t\| \|x_j^t\|} \right| - m_2 \right) Lmdms=N21i=1∑Nj=1∑NReLU( ∥xis∥∥xjs∥xis⋅xjs−∥xit∥∥xjt∥xit⋅xjt −m2)
作用:保持 S o u r c e Source Source 和 T a r g e t Target Target 的内部结构(相对分布)一致,关注结构差异大于 margin( m 2 m_2 m2) 的特征对。
2. Graph Matching / Laplacian Alignment
- 将特征看作图中的节点,定义拉普拉斯矩阵 L L L,做图对齐:
L g r a p h = Tr ( X T L X ) \mathcal{L}_{graph} = \text{Tr}(X^T L X) Lgraph=Tr(XTLX)
3. Centered Kernel Alignment (CKA)
- 计算 source 和 target 的核对齐度量(相当于高级版本的结构相似度):
CKA ( K s , K t ) = HSIC ( K s , K t ) HSIC ( K s , K s ) ⋅ HSIC ( K t , K t ) \text{CKA}(K^s, K^t) = \frac{\text{HSIC}(K^s, K^t)}{\sqrt{\text{HSIC}(K^s, K^s)\cdot \text{HSIC}(K^t, K^t)}} CKA(Ks,Kt)=HSIC(Ks,Ks)⋅HSIC(Kt,Kt)HSIC(Ks,Kt)
三、联合对齐损失(Joint Alignment Loss)
将微观的点对点对齐与宏观的结构一致性对齐结合:
L t o t a l = λ 1 ⋅ L p t p + λ 2 ⋅ L s t r u c t u r e \mathcal{L}_{total} = \lambda_1 \cdot \mathcal{L}_{ptp} + \lambda_2 \cdot \mathcal{L}_{structure} Ltotal=λ1⋅Lptp+λ2⋅Lstructure
其中:
- λ 1 \lambda_1 λ1 和 λ 2 \lambda_2 λ2 为权重超参数
- 可进一步添加自监督目标(如 cluster consistency、domain adversarial loss 等)