当前位置: 首页 > ops >正文

深度学习中的三种Embedding技术详解

提纲

    • 背景介绍
    • 特征类型与Embedding方法
    • 1. ID类特征的Embedding处理
      • 1.1 标准Embedding方法
      • 1.2 IdHashEmbedding方法
    • 2. 数值型特征的Embedding处理
      • 2.1 RawEmbedding方法
    • 三种Embedding方法对比总结
    • 实践建议
    • 总结

背景介绍

在深度学习领域,Embedding(嵌入)技术是一种将高维稀疏数据转换为低维稠密向量表示的核心方法。它在推荐系统、自然语言处理、图像识别等多个领域中发挥着重要作用。

Embedding的主要目的是:

  1. 将离散的类别特征(如用户ID、商品类别)转换为连续的向量表示
  2. 在低维空间中捕获特征间的语义关系
  3. 提高模型的表达能力和泛化性能

正确使用Embedding技术对于模型性能至关重要。本文将详细解析三种不同的Embedding方法及其应用场景。

特征类型与Embedding方法

在实际项目中,我们通常会遇到两种主要的特征类型:

  1. ID类特征:如用户ID、骑手ID、商品类别等离散的标识符
  2. 数值型特征:如用户年龄、订单金额等连续数值

针对不同类型的特征,我们需要采用不同的Embedding策略来处理。

1. ID类特征的Embedding处理

1.1 标准Embedding方法

适用场景:适用于中低基数类别特征,如商品类别、用户性别、服务类型、地区编码等。

工作原理

  • 为每个唯一的特征值维护一个独立的向量表示
  • 直接使用特征值作为索引查找Embedding table中对应的Embedding向量
  • Embedding表的大小(行数)需要大于等于特征最大的特征值

Example
假设我们有一个item_category特征,有50个不同的类别,我们希望将其映射到16维的向量空间:

import torch
import torch.nn as nn
# 特征配置
feature_config = {'item_category': {'hash_bucket_size': 50, 'embedding_dimension': 16}
}# 创建Embedding层(table),维度为50X16的矩阵
embedding_layer = nn.Embedding(num_embeddings=50,     # 类别数量embedding_dim=16       # 目标维度
)# 输入样本:两个订单的类别分别为5和23
input_tensor = torch.tensor([5, 23])  # shape: [2]# Embedding过程就时vlookup过程,从embedding_layer 矩阵中选取第5行和第23行数据向量
embedded = embedding_layer(input_tensor)  # shape: [2, 16]print(f"输入形状: {input_tensor.shape}")    # [2]
print(f"输出形状: {embedded.shape}")       # [2, 16]

初始化方式

- PyTorch中nn.Embedding的默认初始化
- 权重矩阵形状为[50, 16],每个元素从标准正态分布N(0,1)中采样,weight[i, j] ~ N(0, 1) for all i in [0, 49] and j in [0, 15]

优缺点分析

  • 优点:无哈希冲突,每个特征值有独立表示,易于解释
  • 缺点:对于高基数特征内存消耗巨大,无法处理训练时未见过的特征值大于embedding table的情况,item_category例中,假设有一个item_category的数值为51,则运行时会报突破边界错,因为从embedding table中找不到第51行。

1.2 IdHashEmbedding方法

适用场景:适用于高基数类别特征,如用户ID、商品ID、骑手ID等具有大量唯一值的特征。

为什么需要IdHashEmbedding?
当面对高基数特征时,标准Embedding方法会遇到以下问题:

  1. 内存爆炸:如果用户ID有千万/亿级别,直接Embedding需要巨大的内存
  2. 未见特征值:遇到特征值大于embedding table行数时报错
  3. 训练效率低:大量参数需要更新,训练速度慢

工作原理

  • 通过哈希函数将原始特征值映射到固定大小的桶中
  • 使用哈希后的值作为索引从Embedding表中查找向量
  • 内存使用固定,不受原始特征基数影响

Example

class IdHashEmbedding(nn.Module):...def _build_embeddings(self):for feat_name, config in self.feature_config.items():self.embeddings[feat_name] = nn.Embedding(num_embeddings=config['hash_bucket_size'],embedding_dim=config['embedding_dimension'])def forward(self, features):embed_list = []for feat_name, config in self.feature_config.items():id_x = features[feat_name]hash_x = id_x.cpu().apply_(lambda x: hash_bucket(x, config['hash_bucket_size'])).long().to(self.device)     embed_list.append(self.embeddings[feat_name](hash_x))embeds = torch.cat(embed_list, dim=1)return embeds# 用户ID特征,原始ID可能达到千万级别
user_ids = torch.tensor([123456, 789012])# 使用IdHashEmbedding压缩到1000个桶中
id_hash_embedding = IdHashEmbedding({'user_id': {'hash_bucket_size': 1000, 'embedding_dimension': 32}
}, device)# 计算Id特征对应的IdHash值 
hash_bucket(123456, 1000) = 123456 % 1000 = 456, 
hash_bucket(789012, 1000) = 789012 % 1000 = 12
# 实际查找的是索引456和12,而非原始ID
embedded = id_hash_embedding({'user_id': user_ids})

优缺点分析

  • 优点:内存高效,可以处理任意基数的特征,能处理未见过的特征值
  • 缺点:可能产生哈希冲突,不同特征值可能映射到同一向量,比如对用户223456,其IdHash值也是456,和用户ID123456对应的Id Hash值相同,在模型训练时会当作相同的特征

2. 数值型特征的Embedding处理

2.1 RawEmbedding方法

适用场景:适用于连续数值特征,如用户年龄、注册时长、订单金额等。

工作原理

  • 使用线性变换将原始数值映射到目标维度
  • 相比查找表方式,更适合处理连续值

Example

class RawEmbedding(nn.Module):def __init__(self, input_dim, output_dim):super().__init__()self.input_dim = input_dimself.linear = nn.Linear(input_dim, output_dim)def forward(self, x):return self.linear(x)# 用户年龄特征
user_ages = torch.tensor([[25.0], [35.0]])  # shape: [2, 1]# 使用RawEmbedding将1维年龄映射到8维向量
raw_embedding = RawEmbedding(input_dim=1, output_dim=8)
embedded = raw_embedding(user_ages)  # shape: [2, 8]print(f"输入形状: {user_ages.shape}")   # [2, 1]
print(f"输出形状: {embedded.shape}")     # [2, 8]

优缺点分析

  • 优点:适合处理连续特征,计算简单高效,参数量少
  • 缺点:不适合处理未编码的类别特征,表达能力有限

三种Embedding方法对比总结

特征类型Embedding方法适用场景内存消耗处理未见特征哈希冲突表达能力计算复杂度
ID类特征Embedding中低基数类别特征(商品类别、服务类型等)
ID类特征IdHashEmbedding高基数类别特征(用户ID、骑手ID等)
数值特征RawEmbedding连续数值特征(年龄、服务时长等)

实践建议

在骑手等级转换预测项目中,我们建议:

  1. 高基数ID特征(如骑手ID、用户ID)使用IdHashEmbedding
  2. 中低基数ID特征(如服务类型、地区编码)使用Embedding
  3. 连续数值特征(如年龄、注册时长)使用RawEmbedding

合理选择和使用Embedding方法,可以显著提升模型性能,同时控制内存消耗。

总结

Embedding技术是深度学习模型中不可或缺的组成部分。通过合理选择IdHashEmbedding、Embedding和RawEmbedding,我们可以有效处理不同类型的特征,构建高性能的预测模型。理解每种方法的原理和适用场景,对于构建成功的机器学习系统具有重要意义。

在实际应用中,我们需要根据特征的类型、基数大小、内存限制以及业务需求来选择合适的Embedding方法,这样才能在模型效果和计算效率之间取得最佳平衡。

http://www.xdnf.cn/news/16926.html

相关文章:

  • 【内容规范】关于标题中【】标记的使用说明
  • 02.Redis 安装
  • 浅窥Claude-Prompting for Agents的Talk
  • Thread 类的基本用法
  • 位运算在权限授权中的应用及Vue3实践
  • 深度学习(鱼书)day10--与学习相关的技巧(后两节)
  • 【Python练习】075. 编写一个函数,实现简单的语音识别功能
  • MySQL Undo Log
  • 从零开始设计一个分布式KV存储:基于Raft的协程化实现
  • golang 函数选项模式
  • 手机(电脑)与音响的蓝牙通信
  • Python 实例属性与方法命名冲突:一次隐藏的Bug引发的思考
  • 抽奖系统中 Logback 的日志配置文件说明
  • Easy系列PLC相对运动指令实现定长输送(ST源代码)
  • 长文:Java入门教程
  • 求定积分常用技巧
  • 前端工程化:npmvite
  • 小红书开源dots.ocr:单一视觉语言模型中的多语言文档布局解析
  • CUDA杂记--nvcc使用介绍
  • k8s黑马教程笔记
  • MySQL 索引失效的场景与原因
  • 第二章 矩阵
  • Apple基础(Xcode④-Flutter-Platform Channels)
  • OpenCV轻松入门_面向python(第一章OpenCV入门)
  • 【PDF + ZIP 合并器:把ZIP文件打包至PDF文件中】
  • RabbitMQ面试精讲 Day 8:死信队列与延迟队列实现
  • 反向代理+网关部署架构
  • Flask ORM 模型(轻松版)
  • 如何在不停机的情况下,将MySQL单库的数据迁移到分库分表的架构上?
  • Unity_数据持久化_IXmlSerializable接口