Wide&Deep
- This post mainly records the Task2 (Wide&Deep) material from the DataWhale course 深度推荐模型 (Deep Recommendation Models).
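The core idea of the Wide & Deep model is to combine the memorization ability of a linear model with the generalization ability of a DNN, and thereby improve overall model performance. According to the paper "Wide & Deep Learning for Recommender Systems", Wide & Deep has been successfully applied to app recommendation in Google Play.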
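Network structure
Wide & Deep trains the linear model (memorization) and the DNN (generalization) together, optimizing the parameters of both during training so that the predictive power of the combined model is as good as possible.
Memorization means discovering correlations between items or features from historical data.
Generalization means transferring those correlations, that is, discovering new feature combinations that appear rarely or never in the historical data.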
WideDeep = LR + DNN
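Recommender system
Example recommender-system pipeline from the paper:
- Retrieval: the query recalls about 100 relevant items, i.e. O(100) items;
- Learn a ranking model from Query + Items + User Actions;
- Use that ranking model to score the recalled items and recommend the top 10 to the user.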
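The retrieve-then-rank flow can be sketched in a few lines of plain Python. This is only an illustration: `retrieve`, `rank`, the toy `catalog`, and the "installs" score are hypothetical stand-ins, not the paper's system or model.

```python
def retrieve(query, catalog, limit=100):
    """Hypothetical retrieval step: keep items whose title contains the query."""
    return [item for item in catalog if query in item["title"]][:limit]


def rank(candidates, score_fn, top_k=10):
    """Ranking step: sort candidates by model score, highest first."""
    return sorted(candidates, key=score_fn, reverse=True)[:top_k]


# Toy catalog; "installs" stands in for the score a trained ranking model would produce.
catalog = [{"title": f"photo app {i}", "installs": i} for i in range(500)]

candidates = retrieve("photo", catalog)  # at most 100 items
top_items = rank(candidates, score_fn=lambda item: item["installs"])
print(len(candidates), [item["installs"] for item in top_items][:3])  # 100 [99, 98, 97]
```

In a real system `score_fn` would be the Wide & Deep model's predicted probability for each (user, item) pair.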
import warnings
warnings.filterwarnings("ignore")
import itertools
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import namedtuple
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from deepctr.feature_column import SparseFeat, DenseFeat, VarLenSparseFeat
# Enable GPU memory growth; otherwise TensorFlow may raise errors
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)
1 Physical GPUs, 1 Logical GPUs
1、Load Dataset
data = pd.read_csv("../代码/data/criteo_sample.txt")
columns = data.columns.values
columns
array(['label', 'I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9',
'I10', 'I11', 'I12', 'I13', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6',
'C7', 'C8', 'C9', 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16',
'C17', 'C18', 'C19', 'C20', 'C21', 'C22', 'C23', 'C24', 'C25',
'C26'], dtype=object)
# Split the columns into dense and sparse features
dense_features = [feat for feat in columns if 'I' in feat]
sparse_features = [feat for feat in columns if 'C' in feat]
print(dense_features)
print(sparse_features)
['I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9', 'I10', 'I11', 'I12', 'I13']
['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21', 'C22', 'C23', 'C24', 'C25', 'C26']
2、Data Process
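# Basic feature processing: fill missing values, transform numeric features, encode categories
def data_process(data_df, dense_features, sparse_features):
    # Dense features: fill missing values with 0, then apply log(x + 1) (values <= -1 map to -1)
    data_df[dense_features] = data_df[dense_features].fillna(0.0)
    for f in dense_features:
        data_df[f] = data_df[f].apply(lambda x: np.log(x+1) if x > -1 else -1)
    # Sparse features: fill missing values with "-1", then label-encode each column
    data_df[sparse_features] = data_df[sparse_features].fillna("-1")
    for f in sparse_features:
        lbe = LabelEncoder()
        data_df[f] = lbe.fit_transform(data_df[f])
    return data_df[dense_features + sparse_features]
# Simple data preprocessing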
train_data = data_process(data, dense_features, sparse_features)
train_data['label'] = data['label']
train_data.head()
| | I1 | I2 | I3 | I4 | I5 | I6 | I7 | I8 | I9 | I10 | ... | C18 | C19 | C20 | C21 | C22 | C23 | C24 | C25 | C26 | label |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 0.0 | 1.386294 | 5.564520 | 0.000000 | 9.779567 | 0.000000 | 0.000000 | 3.526361 | 0.000000 | 0.0 | ... | 66 | 0 | 0 | 3 | 0 | 1 | 96 | 0 | 0 | 0 |
1 | 0.0 | -1.000000 | 2.995732 | 3.583519 | 10.317318 | 5.513429 | 0.693147 | 3.583519 | 5.081404 | 0.0 | ... | 52 | 0 | 0 | 47 | 0 | 7 | 112 | 0 | 0 | 0 |
2 | 0.0 | 0.000000 | 1.098612 | 2.564949 | 7.607878 | 5.105945 | 1.945910 | 3.583519 | 6.261492 | 0.0 | ... | 49 | 0 | 0 | 25 | 0 | 6 | 53 | 0 | 0 | 0 |
3 | 0.0 | 2.639057 | 0.693147 | 1.609438 | 9.731334 | 5.303305 | 1.791759 | 1.609438 | 3.401197 | 0.0 | ... | 37 | 0 | 0 | 156 | 0 | 0 | 32 | 0 | 0 | 0 |
4 | 0.0 | 0.000000 | 4.653960 | 3.332205 | 7.596392 | 4.962845 | 1.609438 | 3.496508 | 3.637586 | 0.0 | ... | 14 | 5 | 3 | 9 | 0 | 0 | 5 | 1 | 47 | 0 |
5 rows × 40 columns
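To make the two transformations in `data_process` concrete, here is a small standard-library sketch. `log_transform` mirrors the lambda used for dense features; `label_encode` is a hypothetical stand-in that maps categories to integer ids in sorted order, which is how `LabelEncoder` assigns ids as far as I know. The sample values are made up.

```python
import math


def log_transform(x):
    """Same rule as the lambda in data_process: log(x + 1) for x > -1, else -1."""
    return math.log(x + 1) if x > -1 else -1


def label_encode(values):
    """Map each distinct category to an integer id, assigned in sorted order."""
    mapping = {v: i for i, v in enumerate(sorted(set(values)))}
    return [mapping[v] for v in values]


print(round(log_transform(3), 6))           # 1.386294
print(log_transform(-2))                    # -1
print(label_encode(["b", "a", "-1", "b"]))  # [2, 1, 0, 2]
```

A value such as 1.386294 in the table above therefore corresponds to a raw value of 3, and -1.000000 marks a raw value that was not greater than -1.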
# Same handling as in DeepCrossing, except that the features are split into a Linear part and a DNN part
linear_feature_columns = [
    SparseFeat(feat, vocabulary_size=data[feat].nunique(), embedding_dim=4)
    for feat in sparse_features
] + [
    DenseFeat(feat, 1) for feat in dense_features
]
linear_feature_columns[0],len(linear_feature_columns)
(SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x0000022F91F17850>, embedding_name='C1', group_name='default_group', trainable=True),
39)
dnn_feature_columns = [
    SparseFeat(feat, vocabulary_size=data[feat].nunique(), embedding_dim=4)
    for feat in sparse_features
] + [
    DenseFeat(feat, 1) for feat in dense_features
]
dnn_feature_columns[0],len(dnn_feature_columns)
(SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x0000022F91A1BD90>, embedding_name='C1', group_name='default_group', trainable=True),
39)
3、Model principles
3.1 The Wide Component
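The wide part is a generalized linear model. Its input consists of two kinds of features: a subset of the raw features, and cross features built from the raw features (cross-product transformations). A cross feature is defined as:
$$\phi_k(\mathbf{x}) = \prod_{i=1}^{d} x_i^{c_{ki}}, \quad c_{ki} \in \{0, 1\}$$
$c_{ki}$ is a Boolean variable that equals 1 when the $i$-th feature is part of the $k$-th feature combination and 0 otherwise, and $x_i$ is the value of the $i$-th feature. For binary features this means the new feature is 1 only when all participating features are 1 at the same time, and 0 otherwise. A cross feature between two features can be understood as their product, which is equivalent to one feature combination.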
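A minimal sketch of the cross-product transformation for binary features, computing the product of `x_i ** c_ki` over all features. The feature layout (gender and language one-hot columns) is a hypothetical example, not taken from the dataset used in this post.

```python
def cross_product(x, c_k):
    """phi_k(x): product over i of x_i ** c_ki, for binary x and a 0/1 mask c_k."""
    result = 1
    for x_i, c_ki in zip(x, c_k):
        result *= x_i ** c_ki  # features outside the combination contribute x_i ** 0 == 1
    return result


# Hypothetical one-hot layout: [gender=female, gender=male, language=en, language=zh]
x = [1, 0, 1, 0]
and_female_en = [1, 0, 1, 0]  # AND(gender=female, language=en)
and_male_en = [0, 1, 1, 0]    # AND(gender=male, language=en)

print(cross_product(x, and_female_en), cross_product(x, and_male_en))  # 1 0
```

The first combination fires because both of its features are 1; the second is 0 because `gender=male` is 0.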
3.2 The Deep Component
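The deep part is a DNN. Its input features fall into two groups: numeric features, which can be fed into the DNN directly, and categorical features, which must pass through an Embedding layer first. Each hidden layer of the deep part has the form:
$$a^{(l+1)} = f\left(W^{(l)} a^{(l)} + b^{(l)}\right)$$
where $l$ is the layer index, $f$ is the activation function (ReLU in the paper), and $a^{(l)}$, $W^{(l)}$, $b^{(l)}$ are the activations, weights, and bias of layer $l$.
As a DNN gets deeper, its intermediate features tend to become more abstract, which helps the model generalize. For the deep part, the authors use AdaGrad, an optimizer commonly used in deep learning, with the aim of reaching a more accurate solution.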
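One hidden layer of the deep part computes an affine map followed by an activation, `f(W a + b)`. The sketch below does this in plain Python with ReLU as `f`; the weights and input are toy values chosen only for illustration.

```python
def relu(v):
    return [max(0.0, z) for z in v]


def dense_layer(a, W, b):
    """One hidden layer: relu(W a + b), with W stored as a list of rows."""
    pre_activation = [
        sum(w_ij * a_j for w_ij, a_j in zip(row, a)) + b_i
        for row, b_i in zip(W, b)
    ]
    return relu(pre_activation)


a0 = [0.5, -1.0, 2.0]                    # toy input (dense features and flattened embeddings)
W = [[1.0, 0.0, 0.5], [-1.0, 1.0, 0.0]]  # toy 2 x 3 weight matrix
b = [0.0, 0.5]

a1 = dense_layer(a0, W, b)
print(a1)  # [1.5, 0.0]
```

Stacking several such layers and ending with a single linear unit gives the deep logit.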
3.3 Joint Training
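Joint training differs from ensembling. In an ensemble, each model is trained separately and the outputs are combined only afterwards. Each independent model in an ensemble therefore has to be good enough on its own for the combination to help, so each model usually needs a larger model size. In joint training, the wide part only needs a small number of cross-product feature transformations to compensate for the weaknesses of the deep part, rather than a full-size wide model.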
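The W&D model combines the outputs of the two parts and trains them jointly: the wide and deep outputs are fed into a single logistic regression that makes the final prediction and outputs a probability. The joint model has the following form:
$$P(Y=1 \mid \mathbf{x}) = \sigma\left(\mathbf{w}_{wide}^{T}[\mathbf{x}, \phi(\mathbf{x})] + \mathbf{w}_{deep}^{T} a^{(l_f)} + b\right)$$
where $\sigma$ is the sigmoid function, $\phi(\mathbf{x})$ are the cross-product transformations of the raw features $\mathbf{x}$, $a^{(l_f)}$ is the final activation of the deep part, and $b$ is the bias. Note that because the wide-side input is high-dimensional and sparse, the authors optimize it with FTRL, while the deep side uses AdaGrad.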
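In the paper, the authors train the parameters with mini-batch stochastic optimization by back-propagating gradients, using the Follow-the-regularized-leader (FTRL) algorithm with L1 regularization for the wide part and AdaGrad for the deep part.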
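The following sketch illustrates what "joint" means in practice: the wide and deep logits are summed before the sigmoid, so a single log loss produces one gradient on the summed logit, and that same gradient flows back into both parts. The logit values are made up, and the functions are illustrative helpers, not the paper's implementation.

```python
import math


def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))


def joint_probability(wide_logit, deep_logit, bias=0.0):
    """P(Y=1 | x): sigmoid of the summed wide and deep logits plus a bias."""
    return sigmoid(wide_logit + deep_logit + bias)


def logit_gradient(p, y):
    """Derivative of the log loss with respect to the summed logit: p - y."""
    return p - y


p = joint_probability(wide_logit=0.4, deep_logit=-0.4)
print(p, logit_gradient(p, y=1))  # 0.5 -0.5
```

In an ensemble, by contrast, each model would have its own loss and the two would only be combined at inference time.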
4、Wide&Deep model build
4.1 Input Layers
def build_input_layers(feature_columns):
    # Build the Input layers and return them as two dicts, one for dense and one for sparse features
    dense_input_dict, sparse_input_dict = {}, {}
    for fc in feature_columns:
        if isinstance(fc, SparseFeat):
            sparse_input_dict[fc.name] = Input(shape=(1, ), name=fc.name)
        elif isinstance(fc, DenseFeat):
            dense_input_dict[fc.name] = Input(shape=(fc.dimension, ), name=fc.name)
    return dense_input_dict, sparse_input_dict
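# Use dict-based inputs again, which makes building the model easier later on
# Why does passing linear + dnn give the same result as passing either list alone?
# Both lists hold the same features and the dicts are keyed by feature name, so duplicates overwrite each other.
# As a consequence, the Linear and DNN parts share the same Input layers here instead of having separate ones.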
dense_input_dict, sparse_input_dict = build_input_layers(linear_feature_columns + dnn_feature_columns)
print(dense_input_dict["I1"],len(dense_input_dict),type(dense_input_dict))
print(sparse_input_dict["C1"],len(sparse_input_dict),type(sparse_input_dict))
KerasTensor(type_spec=TensorSpec(shape=(None, 1), dtype=tf.float32, name='I1'), name='I1', description="created by layer 'I1'") 13 <class 'dict'>
KerasTensor(type_spec=TensorSpec(shape=(None, 1), dtype=tf.float32, name='C1'), name='C1', description="created by layer 'C1'") 26 <class 'dict'>
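The lengths 13 and 26 printed above, despite passing 78 feature columns, follow from keying the dicts by feature name. The sketch below reproduces the effect without TensorFlow; `SparseCol`, `DenseCol`, and `build_input_specs` are hypothetical stand-ins for `SparseFeat`, `DenseFeat`, and `build_input_layers`, and shapes are stored instead of real `Input` layers.

```python
from collections import namedtuple

# Hypothetical stand-ins for deepctr's SparseFeat / DenseFeat, for illustration only
SparseCol = namedtuple("SparseCol", ["name", "vocabulary_size"])
DenseCol = namedtuple("DenseCol", ["name", "dimension"])


def build_input_specs(feature_columns):
    """Same control flow as build_input_layers, but stores shapes instead of Input layers."""
    dense_specs, sparse_specs = {}, {}
    for fc in feature_columns:
        if isinstance(fc, SparseCol):
            sparse_specs[fc.name] = (1,)
        elif isinstance(fc, DenseCol):
            dense_specs[fc.name] = (fc.dimension,)
    return dense_specs, sparse_specs


linear_cols = [SparseCol("C1", 27), DenseCol("I1", 1)]
dnn_cols = [SparseCol("C1", 27), DenseCol("I1", 1)]

dense_specs, sparse_specs = build_input_specs(linear_cols + dnn_cols)
print(len(linear_cols + dnn_cols), len(dense_specs), len(sparse_specs))  # 4 1 1
```

If the wide and deep parts were meant to use different feature subsets, the two lists would differ and the merged dicts would contain the union of their names.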
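# Build the model's input layer list: collect the Input tensors from the two dicts into one flat list to pass to Model()
# Note: actual inputs are matched to Input() layers by name, i.e. each key of the dict fed to the model is matched to the Input layer with the same name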
input_layers = list(dense_input_dict.values()) + list(sparse_input_dict.values())
input_layers[0],len(input_layers),type(input_layers)
(<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'I1')>,
39,
list)
def build_embedding_layers(feature_columns, input_layers_dict, is_linear):
    embedding_layers_dict = dict()
    sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if feature_columns else []
    if is_linear:
        for fc in sparse_feature_columns:
            embedding_layers_dict[fc.name] = Embedding(fc.vocabulary_size, 1, name="1d_emb_"+fc.name)
    else:
        for fc in sparse_feature_columns:
            embedding_layers_dict[fc.name] = Embedding(fc.vocabulary_size, fc.embedding_dim, name="kd_emb_"+fc.name)
    return embedding_layers_dict
4.2 Linear
# Select the sparse features of the linear part; they are used for the 1-dimensional embeddings below
linear_sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), linear_feature_columns))
print(linear_sparse_feature_columns[0])
print(len(linear_sparse_feature_columns))
print(type(linear_sparse_feature_columns))
SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x0000022F91F17850>, embedding_name='C1', group_name='default_group', trainable=True)
26
<class 'list'>
def get_linear_logits(dense_input_dict, sparse_input_dict, sparse_feature_columns):
    # Dense features: concatenate them and apply a single linear layer
    concat_dense_inputs = Concatenate(axis=1)(list(dense_input_dict.values()))
    dense_logits_output = Dense(1)(concat_dense_inputs)
    # Sparse features: a 1-dimensional embedding acts as a per-category linear weight
    linear_embedding_layers = build_embedding_layers(sparse_feature_columns,
                                                     sparse_input_dict, is_linear=True)
    print(dense_logits_output)
    sparse_1d_embed = []
    for fc in sparse_feature_columns:
        feat_input = sparse_input_dict[fc.name]
        embed = Flatten()(linear_embedding_layers[fc.name](feat_input))
        sparse_1d_embed.append(embed)
    sparse_logits_output = Add()(sparse_1d_embed)
    print(sparse_logits_output)
    # Linear logit = dense logit + sum of the sparse 1-d embeddings
    linear_logits = Add()([dense_logits_output, sparse_logits_output])
    return linear_logits
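# In the Wide&Deep paper the Wide part uses fairly simple features that are very sparse, so the authors optimize it with FTRL (FTRL is not implemented here)
# That choice was driven by their business scenario; here all available features are fed into the Wide part, and the details can be adjusted as needed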
linear_logits = get_linear_logits(dense_input_dict, sparse_input_dict, linear_sparse_feature_columns)
linear_logits,type(linear_logits)
KerasTensor(type_spec=TensorSpec(shape=(None, 1), dtype=tf.float32, name=None), name='dense_2/BiasAdd:0', description="created by layer 'dense_2'")
KerasTensor(type_spec=TensorSpec(shape=(None, 1), dtype=tf.float32, name=None), name='add_4/add_24:0', description="created by layer 'add_4'")
(<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'add_5')>,
tensorflow.python.keras.engine.keras_tensor.KerasTensor)
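The linear part uses `Embedding(vocabulary_size, 1)` for each sparse feature. One way to read this: looking up a 1-dimensional embedding by category id gives the same number as multiplying a one-hot vector by a weight vector, so the sum of these lookups is a linear model over one-hot features. The weights below are made-up values for a hypothetical feature with 3 categories.

```python
def one_hot(index, size):
    return [1.0 if i == index else 0.0 for i in range(size)]


weights = [0.2, -0.1, 0.7]  # hypothetical 1-d embedding table (3 categories)
category_id = 2

# Embedding lookup: pick the weight at the category's index
lookup = weights[category_id]

# Linear model on one-hot input: dot product of the weights with the one-hot vector
dot = sum(w * x for w, x in zip(weights, one_hot(category_id, len(weights))))

print(lookup, dot)  # 0.7 0.7
```

The lookup avoids materializing the one-hot vector, which matters when the vocabulary is large.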
4.3 DNN
embedding_layers = build_embedding_layers(dnn_feature_columns, sparse_input_dict, is_linear=False)
embedding_layers["C1"],len(embedding_layers)
(<tensorflow.python.keras.layers.embeddings.Embedding at 0x23000d75c40>, 26)
# Collect the embeddings of all sparse features so they can be concatenated
def concat_embedding_list(feature_columns, input_layer_dict, embedding_layer_dict, flatten=False):
    # Select the sparse features
    sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), feature_columns))
    embedding_list = []
    for fc in sparse_feature_columns:
        _input = input_layer_dict[fc.name]  # Input layer, shape B x 1
        _embed = embedding_layer_dict[fc.name]  # the matching Embedding layer
        embed = _embed(_input)  # B x 1 x dim: pass the input through the embedding layer
        # Flatten to B x dim if the embeddings go straight into a Dense layer; otherwise keep B x 1 x dim
        if flatten:
            embed = Flatten()(embed)
        embedding_list.append(embed)
    return embedding_list
def get_dnn_logits(dense_input_dict, sparse_input_dict, sparse_feature_columns, dnn_embedding_layers):
    concat_dense_inputs = Concatenate(axis=1)(list(dense_input_dict.values()))
    sparse_kd_embed = concat_embedding_list(sparse_feature_columns, sparse_input_dict, dnn_embedding_layers, flatten=True)
    concat_sparse_kd_embed = Concatenate(axis=1)(sparse_kd_embed)
    dnn_input = Concatenate(axis=1)([concat_dense_inputs, concat_sparse_kd_embed])
    dnn_out = Dropout(0.5)(Dense(1024, activation="relu")(dnn_input))
    dnn_out = Dropout(0.3)(Dense(512, activation="relu")(dnn_out))
    dnn_out = Dropout(0.1)(Dense(256, activation="relu")(dnn_out))
    dnn_logits = Dense(1)(dnn_out)
    return dnn_logits
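As a sanity check on the shapes in `get_dnn_logits`, the DNN input width and the parameter counts of the first Dense layers can be worked out by hand. The arithmetic below assumes the 13 dense features, 26 sparse features, and `embedding_dim=4` used in this post; the results match the `(None, 117)`, `120832`, and `524800` entries in the model summary in section 4.4.

```python
n_dense, n_sparse, embedding_dim = 13, 26, 4

# 13 dense values plus 26 flattened 4-d embeddings
dnn_input_dim = n_dense + n_sparse * embedding_dim


def dense_params(n_in, n_out):
    """Parameters of a fully connected layer: weight matrix plus bias vector."""
    return n_in * n_out + n_out


print(dnn_input_dim)                      # 117
print(dense_params(dnn_input_dim, 1024))  # 120832
print(dense_params(1024, 512))            # 524800
```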
dnn_sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns))
dnn_sparse_feature_columns[0]
SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x0000022F91A1BD90>, embedding_name='C1', group_name='default_group', trainable=True)
# In the Wide&Deep model, the deep part takes the dense features concatenated with the embedded sparse features as its DNN input
dnn_logits = get_dnn_logits(dense_input_dict, sparse_input_dict, dnn_sparse_feature_columns, embedding_layers)
dnn_logits
<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'dense_10')>
4.4 Linear + DNN
output_logits = Add()([linear_logits, dnn_logits])
output_layer = Activation("sigmoid")(output_logits)
model = Model(input_layers, output_layer)
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
C1 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C2 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C3 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C4 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C5 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C6 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C7 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C8 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C9 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C10 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C11 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C12 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C13 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C14 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C15 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C16 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C17 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C18 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C19 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C20 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C21 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C22 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C23 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C24 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C25 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C26 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
kd_emb_C1 (Embedding) (None, 1, 4) 108 C1[0][0]
__________________________________________________________________________________________________
kd_emb_C2 (Embedding) (None, 1, 4) 368 C2[0][0]
__________________________________________________________________________________________________
kd_emb_C3 (Embedding) (None, 1, 4) 688 C3[0][0]
__________________________________________________________________________________________________
kd_emb_C4 (Embedding) (None, 1, 4) 628 C4[0][0]
__________________________________________________________________________________________________
kd_emb_C5 (Embedding) (None, 1, 4) 48 C5[0][0]
__________________________________________________________________________________________________
kd_emb_C6 (Embedding) (None, 1, 4) 28 C6[0][0]
__________________________________________________________________________________________________
kd_emb_C7 (Embedding) (None, 1, 4) 732 C7[0][0]
__________________________________________________________________________________________________
kd_emb_C8 (Embedding) (None, 1, 4) 76 C8[0][0]
__________________________________________________________________________________________________
kd_emb_C9 (Embedding) (None, 1, 4) 8 C9[0][0]
__________________________________________________________________________________________________
kd_emb_C10 (Embedding) (None, 1, 4) 568 C10[0][0]
__________________________________________________________________________________________________
kd_emb_C11 (Embedding) (None, 1, 4) 692 C11[0][0]
__________________________________________________________________________________________________
kd_emb_C12 (Embedding) (None, 1, 4) 680 C12[0][0]
__________________________________________________________________________________________________
kd_emb_C13 (Embedding) (None, 1, 4) 664 C13[0][0]
__________________________________________________________________________________________________
kd_emb_C14 (Embedding) (None, 1, 4) 56 C14[0][0]
__________________________________________________________________________________________________
kd_emb_C15 (Embedding) (None, 1, 4) 680 C15[0][0]
__________________________________________________________________________________________________
kd_emb_C16 (Embedding) (None, 1, 4) 672 C16[0][0]
__________________________________________________________________________________________________
kd_emb_C17 (Embedding) (None, 1, 4) 36 C17[0][0]
__________________________________________________________________________________________________
kd_emb_C18 (Embedding) (None, 1, 4) 508 C18[0][0]
__________________________________________________________________________________________________
kd_emb_C19 (Embedding) (None, 1, 4) 176 C19[0][0]
__________________________________________________________________________________________________
kd_emb_C20 (Embedding) (None, 1, 4) 16 C20[0][0]
__________________________________________________________________________________________________
kd_emb_C21 (Embedding) (None, 1, 4) 676 C21[0][0]
__________________________________________________________________________________________________
kd_emb_C22 (Embedding) (None, 1, 4) 24 C22[0][0]
__________________________________________________________________________________________________
kd_emb_C23 (Embedding) (None, 1, 4) 40 C23[0][0]
__________________________________________________________________________________________________
kd_emb_C24 (Embedding) (None, 1, 4) 500 C24[0][0]
__________________________________________________________________________________________________
kd_emb_C25 (Embedding) (None, 1, 4) 80 C25[0][0]
__________________________________________________________________________________________________
kd_emb_C26 (Embedding) (None, 1, 4) 360 C26[0][0]
__________________________________________________________________________________________________
I1 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I2 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I3 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I4 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I5 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I6 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I7 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I8 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I9 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I10 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I11 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I12 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I13 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
flatten_104 (Flatten) (None, 4) 0 kd_emb_C1[1][0]
__________________________________________________________________________________________________
flatten_105 (Flatten) (None, 4) 0 kd_emb_C2[1][0]
__________________________________________________________________________________________________
flatten_106 (Flatten) (None, 4) 0 kd_emb_C3[1][0]
__________________________________________________________________________________________________
flatten_107 (Flatten) (None, 4) 0 kd_emb_C4[1][0]
__________________________________________________________________________________________________
flatten_108 (Flatten) (None, 4) 0 kd_emb_C5[1][0]
__________________________________________________________________________________________________
flatten_109 (Flatten) (None, 4) 0 kd_emb_C6[1][0]
__________________________________________________________________________________________________
flatten_110 (Flatten) (None, 4) 0 kd_emb_C7[1][0]
__________________________________________________________________________________________________
flatten_111 (Flatten) (None, 4) 0 kd_emb_C8[1][0]
__________________________________________________________________________________________________
flatten_112 (Flatten) (None, 4) 0 kd_emb_C9[1][0]
__________________________________________________________________________________________________
flatten_113 (Flatten) (None, 4) 0 kd_emb_C10[1][0]
__________________________________________________________________________________________________
flatten_114 (Flatten) (None, 4) 0 kd_emb_C11[1][0]
__________________________________________________________________________________________________
flatten_115 (Flatten) (None, 4) 0 kd_emb_C12[1][0]
__________________________________________________________________________________________________
flatten_116 (Flatten) (None, 4) 0 kd_emb_C13[1][0]
__________________________________________________________________________________________________
flatten_117 (Flatten) (None, 4) 0 kd_emb_C14[1][0]
__________________________________________________________________________________________________
flatten_118 (Flatten) (None, 4) 0 kd_emb_C15[1][0]
__________________________________________________________________________________________________
flatten_119 (Flatten) (None, 4) 0 kd_emb_C16[1][0]
__________________________________________________________________________________________________
flatten_120 (Flatten) (None, 4) 0 kd_emb_C17[1][0]
__________________________________________________________________________________________________
flatten_121 (Flatten) (None, 4) 0 kd_emb_C18[1][0]
__________________________________________________________________________________________________
flatten_122 (Flatten) (None, 4) 0 kd_emb_C19[1][0]
__________________________________________________________________________________________________
flatten_123 (Flatten) (None, 4) 0 kd_emb_C20[1][0]
__________________________________________________________________________________________________
flatten_124 (Flatten) (None, 4) 0 kd_emb_C21[1][0]
__________________________________________________________________________________________________
flatten_125 (Flatten) (None, 4) 0 kd_emb_C22[1][0]
__________________________________________________________________________________________________
flatten_126 (Flatten) (None, 4) 0 kd_emb_C23[1][0]
__________________________________________________________________________________________________
flatten_127 (Flatten) (None, 4) 0 kd_emb_C24[1][0]
__________________________________________________________________________________________________
flatten_128 (Flatten) (None, 4) 0 kd_emb_C25[1][0]
__________________________________________________________________________________________________
flatten_129 (Flatten) (None, 4) 0 kd_emb_C26[1][0]
__________________________________________________________________________________________________
concatenate_8 (Concatenate) (None, 13) 0 I1[0][0]
I2[0][0]
I3[0][0]
I4[0][0]
I5[0][0]
I6[0][0]
I7[0][0]
I8[0][0]
I9[0][0]
I10[0][0]
I11[0][0]
I12[0][0]
I13[0][0]
__________________________________________________________________________________________________
concatenate_9 (Concatenate) (None, 104) 0 flatten_104[0][0]
flatten_105[0][0]
flatten_106[0][0]
flatten_107[0][0]
flatten_108[0][0]
flatten_109[0][0]
flatten_110[0][0]
flatten_111[0][0]
flatten_112[0][0]
flatten_113[0][0]
flatten_114[0][0]
flatten_115[0][0]
flatten_116[0][0]
flatten_117[0][0]
flatten_118[0][0]
flatten_119[0][0]
flatten_120[0][0]
flatten_121[0][0]
flatten_122[0][0]
flatten_123[0][0]
flatten_124[0][0]
flatten_125[0][0]
flatten_126[0][0]
flatten_127[0][0]
flatten_128[0][0]
flatten_129[0][0]
__________________________________________________________________________________________________
concatenate_10 (Concatenate) (None, 117) 0 concatenate_8[0][0]
concatenate_9[0][0]
__________________________________________________________________________________________________
dense_7 (Dense) (None, 1024) 120832 concatenate_10[0][0]
__________________________________________________________________________________________________
dropout_3 (Dropout) (None, 1024) 0 dense_7[0][0]
__________________________________________________________________________________________________
dense_8 (Dense) (None, 512) 524800 dropout_3[0][0]
__________________________________________________________________________________________________
1d_emb_C1 (Embedding) (None, 1, 1) 27 C1[0][0]
__________________________________________________________________________________________________
1d_emb_C2 (Embedding) (None, 1, 1) 92 C2[0][0]
__________________________________________________________________________________________________
1d_emb_C3 (Embedding) (None, 1, 1) 172 C3[0][0]
__________________________________________________________________________________________________
1d_emb_C4 (Embedding) (None, 1, 1) 157 C4[0][0]
__________________________________________________________________________________________________
1d_emb_C5 (Embedding) (None, 1, 1) 12 C5[0][0]
__________________________________________________________________________________________________
1d_emb_C6 (Embedding) (None, 1, 1) 7 C6[0][0]
__________________________________________________________________________________________________
1d_emb_C7 (Embedding) (None, 1, 1) 183 C7[0][0]
__________________________________________________________________________________________________
1d_emb_C8 (Embedding) (None, 1, 1) 19 C8[0][0]
__________________________________________________________________________________________________
1d_emb_C9 (Embedding) (None, 1, 1) 2 C9[0][0]
__________________________________________________________________________________________________
1d_emb_C10 (Embedding) (None, 1, 1) 142 C10[0][0]
__________________________________________________________________________________________________
1d_emb_C11 (Embedding) (None, 1, 1) 173 C11[0][0]
__________________________________________________________________________________________________
1d_emb_C12 (Embedding) (None, 1, 1) 170 C12[0][0]
__________________________________________________________________________________________________
1d_emb_C13 (Embedding) (None, 1, 1) 166 C13[0][0]
__________________________________________________________________________________________________
1d_emb_C14 (Embedding) (None, 1, 1) 14 C14[0][0]
__________________________________________________________________________________________________
1d_emb_C15 (Embedding) (None, 1, 1) 170 C15[0][0]
__________________________________________________________________________________________________
1d_emb_C16 (Embedding) (None, 1, 1) 168 C16[0][0]
__________________________________________________________________________________________________
1d_emb_C17 (Embedding) (None, 1, 1) 9 C17[0][0]
__________________________________________________________________________________________________
1d_emb_C18 (Embedding) (None, 1, 1) 127 C18[0][0]
__________________________________________________________________________________________________
1d_emb_C19 (Embedding) (None, 1, 1) 44 C19[0][0]
__________________________________________________________________________________________________
1d_emb_C20 (Embedding) (None, 1, 1) 4 C20[0][0]
__________________________________________________________________________________________________
1d_emb_C21 (Embedding) (None, 1, 1) 169 C21[0][0]
__________________________________________________________________________________________________
1d_emb_C22 (Embedding) (None, 1, 1) 6 C22[0][0]
__________________________________________________________________________________________________
1d_emb_C23 (Embedding) (None, 1, 1) 10 C23[0][0]
__________________________________________________________________________________________________
1d_emb_C24 (Embedding) (None, 1, 1) 125 C24[0][0]
__________________________________________________________________________________________________
1d_emb_C25 (Embedding) (None, 1, 1) 20 C25[0][0]
__________________________________________________________________________________________________
1d_emb_C26 (Embedding) (None, 1, 1) 90 C26[0][0]
__________________________________________________________________________________________________
dropout_4 (Dropout) (None, 512) 0 dense_8[0][0]
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 13) 0 I1[0][0]
I2[0][0]
I3[0][0]
I4[0][0]
I5[0][0]
I6[0][0]
I7[0][0]
I8[0][0]
I9[0][0]
I10[0][0]
I11[0][0]
I12[0][0]
I13[0][0]
__________________________________________________________________________________________________
flatten_52 (Flatten) (None, 1) 0 1d_emb_C1[0][0]
__________________________________________________________________________________________________
flatten_53 (Flatten) (None, 1) 0 1d_emb_C2[0][0]
__________________________________________________________________________________________________
flatten_54 (Flatten) (None, 1) 0 1d_emb_C3[0][0]
__________________________________________________________________________________________________
flatten_55 (Flatten) (None, 1) 0 1d_emb_C4[0][0]
__________________________________________________________________________________________________
flatten_56 (Flatten) (None, 1) 0 1d_emb_C5[0][0]
__________________________________________________________________________________________________
flatten_57 (Flatten) (None, 1) 0 1d_emb_C6[0][0]
__________________________________________________________________________________________________
flatten_58 (Flatten) (None, 1) 0 1d_emb_C7[0][0]
__________________________________________________________________________________________________
flatten_59 (Flatten) (None, 1) 0 1d_emb_C8[0][0]
__________________________________________________________________________________________________
flatten_60 (Flatten) (None, 1) 0 1d_emb_C9[0][0]
__________________________________________________________________________________________________
flatten_61 (Flatten) (None, 1) 0 1d_emb_C10[0][0]
__________________________________________________________________________________________________
flatten_62 (Flatten) (None, 1) 0 1d_emb_C11[0][0]
__________________________________________________________________________________________________
flatten_63 (Flatten) (None, 1) 0 1d_emb_C12[0][0]
__________________________________________________________________________________________________
flatten_64 (Flatten) (None, 1) 0 1d_emb_C13[0][0]
__________________________________________________________________________________________________
flatten_65 (Flatten) (None, 1) 0 1d_emb_C14[0][0]
__________________________________________________________________________________________________
flatten_66 (Flatten) (None, 1) 0 1d_emb_C15[0][0]
__________________________________________________________________________________________________
flatten_67 (Flatten) (None, 1) 0 1d_emb_C16[0][0]
__________________________________________________________________________________________________
flatten_68 (Flatten) (None, 1) 0 1d_emb_C17[0][0]
__________________________________________________________________________________________________
flatten_69 (Flatten) (None, 1) 0 1d_emb_C18[0][0]
__________________________________________________________________________________________________
flatten_70 (Flatten) (None, 1) 0 1d_emb_C19[0][0]
__________________________________________________________________________________________________
flatten_71 (Flatten) (None, 1) 0 1d_emb_C20[0][0]
__________________________________________________________________________________________________
flatten_72 (Flatten) (None, 1) 0 1d_emb_C21[0][0]
__________________________________________________________________________________________________
flatten_73 (Flatten) (None, 1) 0 1d_emb_C22[0][0]
__________________________________________________________________________________________________
flatten_74 (Flatten) (None, 1) 0 1d_emb_C23[0][0]
__________________________________________________________________________________________________
flatten_75 (Flatten) (None, 1) 0 1d_emb_C24[0][0]
__________________________________________________________________________________________________
flatten_76 (Flatten) (None, 1) 0 1d_emb_C25[0][0]
__________________________________________________________________________________________________
flatten_77 (Flatten) (None, 1) 0 1d_emb_C26[0][0]
__________________________________________________________________________________________________
dense_9 (Dense) (None, 256) 131328 dropout_4[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 1) 14 concatenate_4[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 1) 0 flatten_52[0][0]
flatten_53[0][0]
flatten_54[0][0]
flatten_55[0][0]
flatten_56[0][0]
flatten_57[0][0]
flatten_58[0][0]
flatten_59[0][0]
flatten_60[0][0]
flatten_61[0][0]
flatten_62[0][0]
flatten_63[0][0]
flatten_64[0][0]
flatten_65[0][0]
flatten_66[0][0]
flatten_67[0][0]
flatten_68[0][0]
flatten_69[0][0]
flatten_70[0][0]
flatten_71[0][0]
flatten_72[0][0]
flatten_73[0][0]
flatten_74[0][0]
flatten_75[0][0]
flatten_76[0][0]
flatten_77[0][0]
__________________________________________________________________________________________________
dropout_5 (Dropout) (None, 256) 0 dense_9[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 1) 0 dense_2[0][0]
add_4[0][0]
__________________________________________________________________________________________________
dense_10 (Dense) (None, 1) 257 dropout_5[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 1) 0 add_5[0][0]
dense_10[0][0]
__________________________________________________________________________________________________
activation (Activation) (None, 1) 0 add_6[0][0]
==================================================================================================
Total params: 788,621
Trainable params: 788,621
Non-trainable params: 0
__________________________________________________________________________________________________
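The parameter counts in the summary above can be checked by hand. The sketch below is only a sanity check, assuming the standard Keras conventions (a Dense layer has one weight per input-output pair plus one bias per output; an Embedding layer has one row per vocabulary index). The helper names `dense_params` and `embedding_params` are made up for illustration.

```python
def dense_params(n_in, n_out, use_bias=True):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + (n_out if use_bias else 0)


def embedding_params(input_dim, output_dim):
    """One output_dim-sized row per vocabulary index."""
    return input_dim * output_dim


# Values read from the summary above
dense_9 = dense_params(512, 256)       # deep hidden layer: 512 -> 256
dense_2 = dense_params(13, 1)          # wide linear layer over the 13 dense features
dense_10 = dense_params(256, 1)        # deep output layer: 256 -> 1
emb_c23 = embedding_params(10, 1)      # 1-d wide embedding for C23

print(dense_9, dense_2, dense_10, emb_c23)
```

The results match the `Param #` column for `dense_9`, `dense_2`, `dense_10` and `1d_emb_C23`.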
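# Compile with Adam and log loss; track log loss and AUC as metrics
model.compile(
    optimizer="adam",
    loss="binary_crossentropy",
    metrics=["binary_crossentropy", tf.keras.metrics.AUC(name="auc")],
)
# Build the model input dict: one entry per feature name
train_model_input = {name: data[name] for name in dense_features + sparse_features}
# validation_split=0.3 holds out the last 30% of the samples for validation
model.fit(train_model_input, train_data["label"].values, batch_size=64, epochs=20, validation_split=0.3)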
Epoch 1/20
3/3 [==============================] - 1s 283ms/step - loss: 0.5477 - binary_crossentropy: 0.5477 - auc: 0.6436 - val_loss: 0.7410 - val_binary_crossentropy: 0.7410 - val_auc: 0.6200
Epoch 2/20
3/3 [==============================] - 0s 49ms/step - loss: 0.5829 - binary_crossentropy: 0.5829 - auc: 0.6614 - val_loss: 0.6583 - val_binary_crossentropy: 0.6583 - val_auc: 0.6290
Epoch 3/20
3/3 [==============================] - 0s 49ms/step - loss: 0.4333 - binary_crossentropy: 0.4333 - auc: 0.7712 - val_loss: 0.6934 - val_binary_crossentropy: 0.6934 - val_auc: 0.5841
Epoch 4/20
3/3 [==============================] - 0s 48ms/step - loss: 0.5629 - binary_crossentropy: 0.5629 - auc: 0.6908 - val_loss: 0.6786 - val_binary_crossentropy: 0.6786 - val_auc: 0.5931
Epoch 5/20
3/3 [==============================] - 0s 49ms/step - loss: 0.5090 - binary_crossentropy: 0.5090 - auc: 0.7138 - val_loss: 0.6061 - val_binary_crossentropy: 0.6061 - val_auc: 0.6271
Epoch 6/20
3/3 [==============================] - 0s 52ms/step - loss: 0.4698 - binary_crossentropy: 0.4698 - auc: 0.7412 - val_loss: 0.6045 - val_binary_crossentropy: 0.6045 - val_auc: 0.6117
Epoch 7/20
3/3 [==============================] - 0s 51ms/step - loss: 0.3880 - binary_crossentropy: 0.3880 - auc: 0.8280 - val_loss: 0.5973 - val_binary_crossentropy: 0.5973 - val_auc: 0.6284
Epoch 8/20
3/3 [==============================] - 0s 49ms/step - loss: 0.3896 - binary_crossentropy: 0.3896 - auc: 0.8258 - val_loss: 0.6379 - val_binary_crossentropy: 0.6379 - val_auc: 0.6752
Epoch 9/20
3/3 [==============================] - 0s 52ms/step - loss: 0.4310 - binary_crossentropy: 0.4310 - auc: 0.8127 - val_loss: 0.6128 - val_binary_crossentropy: 0.6128 - val_auc: 0.6476
Epoch 10/20
3/3 [==============================] - 0s 52ms/step - loss: 0.3449 - binary_crossentropy: 0.3449 - auc: 0.8792 - val_loss: 0.6169 - val_binary_crossentropy: 0.6169 - val_auc: 0.6239
Epoch 11/20
3/3 [==============================] - 0s 49ms/step - loss: 0.3839 - binary_crossentropy: 0.3839 - auc: 0.8647 - val_loss: 0.6150 - val_binary_crossentropy: 0.6150 - val_auc: 0.6380
Epoch 12/20
3/3 [==============================] - 0s 48ms/step - loss: 0.3433 - binary_crossentropy: 0.3433 - auc: 0.8788 - val_loss: 0.6719 - val_binary_crossentropy: 0.6719 - val_auc: 0.6797
Epoch 13/20
3/3 [==============================] - 0s 52ms/step - loss: 0.3312 - binary_crossentropy: 0.3312 - auc: 0.8942 - val_loss: 0.7839 - val_binary_crossentropy: 0.7839 - val_auc: 0.6874
Epoch 14/20
3/3 [==============================] - 0s 50ms/step - loss: 0.3214 - binary_crossentropy: 0.3214 - auc: 0.9198 - val_loss: 0.6589 - val_binary_crossentropy: 0.6589 - val_auc: 0.6694
Epoch 15/20
3/3 [==============================] - 0s 49ms/step - loss: 0.2330 - binary_crossentropy: 0.2330 - auc: 0.9506 - val_loss: 0.6497 - val_binary_crossentropy: 0.6497 - val_auc: 0.6040
Epoch 16/20
3/3 [==============================] - 0s 50ms/step - loss: 0.2512 - binary_crossentropy: 0.2512 - auc: 0.9612 - val_loss: 0.6793 - val_binary_crossentropy: 0.6793 - val_auc: 0.5879
Epoch 17/20
3/3 [==============================] - 0s 53ms/step - loss: 0.2180 - binary_crossentropy: 0.2180 - auc: 0.9765 - val_loss: 0.6687 - val_binary_crossentropy: 0.6687 - val_auc: 0.6297
Epoch 18/20
3/3 [==============================] - 0s 50ms/step - loss: 0.1580 - binary_crossentropy: 0.1580 - auc: 0.9880 - val_loss: 0.8166 - val_binary_crossentropy: 0.8166 - val_auc: 0.6386
Epoch 19/20
3/3 [==============================] - 0s 51ms/step - loss: 0.1155 - binary_crossentropy: 0.1155 - auc: 0.9997 - val_loss: 0.7452 - val_binary_crossentropy: 0.7452 - val_auc: 0.6187
Epoch 20/20
3/3 [==============================] - 0s 52ms/step - loss: 0.1096 - binary_crossentropy: 0.1096 - auc: 0.9976 - val_loss: 0.6991 - val_binary_crossentropy: 0.6991 - val_auc: 0.5911
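In the log above, the training loss keeps falling (0.5477 to 0.1096) while val_loss reaches its lowest value at epoch 7 and then drifts upward, which suggests overfitting on this small sample. A minimal sketch of patience-based early stopping, replayed on the logged val_loss values, is shown below. The function `early_stop_epoch` and the patience of 3 are assumptions for illustration, not part of the original notebook; in Keras the same idea is available as the `tf.keras.callbacks.EarlyStopping` callback.

```python
# val_loss per epoch, copied from the training log above
val_losses = [
    0.7410, 0.6583, 0.6934, 0.6786, 0.6061,
    0.6045, 0.5973, 0.6379, 0.6128, 0.6169,
    0.6150, 0.6719, 0.7839, 0.6589, 0.6497,
    0.6793, 0.6687, 0.8166, 0.7452, 0.6991,
]


def early_stop_epoch(losses, patience=3):
    """Return (best_epoch, stop_epoch), both 1-based.

    Training stops once `patience` consecutive epochs pass
    without a new lowest validation loss.
    """
    best_loss = float("inf")
    best_epoch = 0
    for epoch, loss in enumerate(losses, start=1):
        if loss < best_loss:
            best_loss, best_epoch = loss, epoch
        elif epoch - best_epoch >= patience:
            return best_epoch, epoch
    return best_epoch, len(losses)


best_epoch, stop_epoch = early_stop_epoch(val_losses, patience=3)
print(best_epoch, stop_epoch)
```

With a patience of 3, this replay would stop after epoch 10 and keep the epoch 7 weights. A smaller patience would stop earlier, at a worse checkpoint, so the value is worth tuning.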
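5、Discussion
In your application, which features belong on the Wide side and which on the Deep side, and why?
The Wide side handles memorization, so it suits crossed (combined) features: it memorizes feature combinations that have already appeared in the historical data. The Deep side handles generalization, so it suits the raw, uncrossed features, both categorical and continuous, and generalizes to feature combinations that are rare or have never appeared.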
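As a rough illustration of what a crossed feature for the Wide side looks like, the sketch below crosses two multi-valued ID features and hashes each crossed name to an index in the wide weight vector. The function names, the example app names and the bucket count of 1000 are assumptions for illustration only.

```python
import zlib
from itertools import product


def cross_product_features(installed_apps, impression_apps):
    """Return one crossed feature name per (installed, impression) pair."""
    return [
        f"installed={a}&impression={b}"
        for a, b in product(sorted(installed_apps), sorted(impression_apps))
    ]


def hash_bucket(feature_name, num_buckets=1000):
    """Map a crossed feature name to a stable index in the wide weight vector."""
    # crc32 is used because Python's built-in hash() of str varies between runs
    return zlib.crc32(feature_name.encode("utf-8")) % num_buckets


crossed = cross_product_features({"netflix", "spotify"}, {"pandora"})
wide_indices = [hash_bucket(name) for name in crossed]
print(crossed)
```

Each crossed name fires only when both of its parts are present, which is why the Wide side can memorize specific co-occurrences but cannot say anything about pairs it has never seen.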
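Why is the Wide part trained with L1 FTRL?
W&D uses FTRL with L1 regularization to make the Wide part sparser. The choice also reflects Google's business needs at the time. In the paper, the Wide part is fed the cross of two ID-type features, User Installed App and Impression App. Crossing them makes the dimensionality explode and makes the already very sparse multi-hot vectors even sparser, so the Wide part ends up with an enormous number of weights. To avoid shipping all of those weights online for model serving, L1 FTRL drives the weights of rarely seen features to exactly zero so they can be dropped, which is sound engineering practice.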
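To see how L1 FTRL produces exact zeros, here is a minimal pure-Python sketch of the per-coordinate FTRL-Proximal update for logistic regression over binary features. The class name, the hyperparameter values and the toy feature names are assumptions for illustration; in Keras the corresponding optimizer is `tf.keras.optimizers.Ftrl` with its `l1_regularization_strength` argument.

```python
import math


class FTRLProximal:
    """Per-coordinate FTRL-Proximal for logistic regression on binary features."""

    def __init__(self, alpha=0.1, beta=1.0, l1=1.0, l2=0.0):
        self.alpha, self.beta, self.l1, self.l2 = alpha, beta, l1, l2
        self.z = {}  # accumulated (adjusted) gradients per feature
        self.n = {}  # accumulated squared gradients per feature

    def weight(self, i):
        z = self.z.get(i, 0.0)
        if abs(z) <= self.l1:
            return 0.0  # L1 threshold: the weight is exactly zero
        n = self.n.get(i, 0.0)
        sign = 1.0 if z > 0 else -1.0
        return -(z - sign * self.l1) / ((self.beta + math.sqrt(n)) / self.alpha + self.l2)

    def predict(self, active):
        score = sum(self.weight(i) for i in active)
        return 1.0 / (1.0 + math.exp(-score))

    def update(self, active, y):
        """One step on an example given as a list of distinct active feature ids."""
        g = self.predict(active) - y  # log-loss gradient for each active feature (x_i = 1)
        for i in active:
            n = self.n.get(i, 0.0)
            sigma = (math.sqrt(n + g * g) - math.sqrt(n)) / self.alpha
            self.z[i] = self.z.get(i, 0.0) + g - sigma * self.weight(i)
            self.n[i] = n + g * g


model = FTRLProximal(alpha=0.1, beta=1.0, l1=1.0)
for _ in range(50):  # a crossed feature that shows up often
    model.update(["installed=netflix&impression=pandora"], 1)
model.update(["installed=rare_app&impression=pandora"], 1)  # seen only once

frequent_w = model.weight("installed=netflix&impression=pandora")
rare_w = model.weight("installed=rare_app&impression=pandora")
print(frequent_w, rare_w)
```

The feature seen once never accumulates enough gradient to cross the L1 threshold, so its weight stays at exactly zero and does not need to be stored for serving, while the frequent feature gets a nonzero weight.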
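Why doesn't the Deep part need special handling for sparsity?
The Deep part takes Continuous Features and Categorical Features that have already passed through an Embedding layer, so its inputs are dense vectors and it does not suffer from severe feature sparsity.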