TensorFlow 深度学习 | 使用子类 API 实现 Wide Deep 模型
💖亲爱的技术爱好者们,热烈欢迎来到 Kant2048 的博客!我是 Thomas Kant,很开心能在CSDN上与你们相遇~💖
本博客的精华专栏:
【自动化测试】 【测试经验】 【人工智能】 【Python】
TensorFlow 深度学习 | 使用子类 API 实现 Wide & Deep 模型
在推荐系统、CTR 预估等场景中,Wide & Deep 模型 是一种经典架构,结合了 Wide 线性部分 与 Deep 深度部分,既能捕捉记忆性特征(memorization),又能学习到泛化能力(generalization)。
本文将带你使用 TensorFlow 子类 API 来实现一个简单的 Wide & Deep 模型,并讲解核心思路与实现细节。
🔹 一、Wide & Deep 模型简介
Wide & Deep 模型最早由 Google 提出,广泛应用在推荐与广告点击率预测中。
-
Wide 部分(线性模型)
直接对输入特征做线性组合,适合捕捉稀疏特征与特征交叉。 -
Deep 部分(深度神经网络)
通过多层非线性网络提取特征组合与高阶特征关系,具备强泛化能力。
二者结合后,模型既能“记忆”已有规律,也能“泛化”出新的特征关系。
🔸 二、数据准备
为了演示,我们使用 模拟数据 来构建一个二分类问题(如 CTR 预测)。在实际业务中,可以替换为 广告点击数据集、推荐数据集 等。
import tensorflow as tf
import numpy as np# 模拟数据
num_samples = 1000
num_features = 10X = np.random.rand(num_samples, num_features).astype(np.float32)
y = np.random.randint(0, 2, size=(num_samples, 1)).astype(np.float32)train_ds = tf.data.Dataset.from_tensor_slices((X, y)).batch(32).