DeepFM
- This post records the content of Task 3 (DeepFM) from the 深度推荐模型|DataWhale course.
Like the Wide&Deep model from the previous task, DeepFM jointly trains a shallow component and a deep component. The two main differences are:
- The Wide part (LR) is replaced by an FM. Since the FM learns cross features automatically, the manual feature engineering the original Wide&Deep model requires on the Wide side is no longer needed.
- The raw input features are shared: they feed both the FM part and the Deep part of DeepFM, keeping the features accurate and consistent across the two components.
Contributions claimed by the paper:
- DeepFM consists of an FM part and a DNN part: the FM extracts low-order feature interactions and the DNN extracts high-order ones, without the manual feature engineering of Wide&Deep.
- Because the inputs are raw features only and the FM and DNN share the same input vectors, DeepFM trains quickly.
- On benchmark and commercial datasets, DeepFM outperformed all existing models at the time of publication.
import warnings
warnings.filterwarnings("ignore")
import itertools
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import namedtuple
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras.layers import *
from tensorflow.keras.models import *
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import MinMaxScaler, LabelEncoder
from deepctr.feature_column import SparseFeat, DenseFeat, VarLenSparseFeat
# Configure GPU memory growth; otherwise TensorFlow may raise an error
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)
1 Physical GPUs, 1 Logical GPUs
1、Load Dataset
data = pd.read_csv("../代码/data/criteo_sample.txt")
columns = data.columns.values
columns
array(['label', 'I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9',
'I10', 'I11', 'I12', 'I13', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6',
'C7', 'C8', 'C9', 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16',
'C17', 'C18', 'C19', 'C20', 'C21', 'C22', 'C23', 'C24', 'C25',
'C26'], dtype=object)
# Split the columns into dense and sparse features
dense_features = [i for i in columns if 'I' in i]
sparse_features = [i for i in columns if 'C' in i]
print(dense_features)
print(sparse_features)
['I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9', 'I10', 'I11', 'I12', 'I13']
['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21', 'C22', 'C23', 'C24', 'C25', 'C26']
2、Data Process
# Simple feature processing: fill missing values, transform the numeric features, encode the categorical features
def data_process(data_df, dense_features, sparse_features):
    data_df[dense_features] = data_df[dense_features].fillna(0.0)
    for f in dense_features:
        data_df[f] = data_df[f].apply(lambda x: np.log(x + 1) if x > -1 else -1)
    data_df[sparse_features] = data_df[sparse_features].fillna("-1")
    for f in sparse_features:
        lbe = LabelEncoder()
        data_df[f] = lbe.fit_transform(data_df[f])
    return data_df[dense_features + sparse_features]
# Simple data preprocessing
train_data = data_process(data, dense_features, sparse_features)
train_data['label'] = data['label']
train_data.head()
| | I1 | I2 | I3 | I4 | I5 | I6 | I7 | I8 | I9 | I10 | ... | C18 | C19 | C20 | C21 | C22 | C23 | C24 | C25 | C26 | label |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.0 | 1.386294 | 5.564520 | 0.000000 | 9.779567 | 0.000000 | 0.000000 | 3.526361 | 0.000000 | 0.0 | ... | 66 | 0 | 0 | 3 | 0 | 1 | 96 | 0 | 0 | 0 |
| 1 | 0.0 | -1.000000 | 2.995732 | 3.583519 | 10.317318 | 5.513429 | 0.693147 | 3.583519 | 5.081404 | 0.0 | ... | 52 | 0 | 0 | 47 | 0 | 7 | 112 | 0 | 0 | 0 |
| 2 | 0.0 | 0.000000 | 1.098612 | 2.564949 | 7.607878 | 5.105945 | 1.945910 | 3.583519 | 6.261492 | 0.0 | ... | 49 | 0 | 0 | 25 | 0 | 6 | 53 | 0 | 0 | 0 |
| 3 | 0.0 | 2.639057 | 0.693147 | 1.609438 | 9.731334 | 5.303305 | 1.791759 | 1.609438 | 3.401197 | 0.0 | ... | 37 | 0 | 0 | 156 | 0 | 0 | 32 | 0 | 0 | 0 |
| 4 | 0.0 | 0.000000 | 4.653960 | 3.332205 | 7.596392 | 4.962845 | 1.609438 | 3.496508 | 3.637586 | 0.0 | ... | 14 | 5 | 3 | 9 | 0 | 0 | 5 | 1 | 47 | 0 |
5 rows × 40 columns
# Consistent with the Wide&Deep and DeepCrossing treatment: the feature columns are split into a Linear part and a DNN part
linear_feature_columns = [
    SparseFeat(feat, vocabulary_size=data[feat].nunique(), embedding_dim=4)
    for feat in sparse_features
] + [
    DenseFeat(feat, 1) for feat in dense_features
]
linear_feature_columns[0],len(linear_feature_columns)
(SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x0000026A0E4F3C70>, embedding_name='C1', group_name='default_group', trainable=True),
39)
dnn_feature_columns = [
    SparseFeat(feat, vocabulary_size=data[feat].nunique(), embedding_dim=4)
    for feat in sparse_features
] + [
    DenseFeat(feat, 1) for feat in dense_features
]
dnn_feature_columns[0],len(dnn_feature_columns)
(SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x0000026A0E4F7D90>, embedding_name='C1', group_name='default_group', trainable=True),
39)
3、Model Principles
FM
Before FM, the models most commonly used for CTR prediction were SVM and LR, sometimes with GBDT extracting nonlinear features for them. Both SVM and LR are linear models. SVM can in theory introduce nonlinearity through kernel functions, but on extremely high-dimensional sparse feature vectors this overfits easily, so it is rarely used this way in practice. The features themselves, however, are correlated. For example, according to the paper studied here, DeepFM: A Factorization-Machine based Neural Network for CTR Prediction, in mainstream app marketplaces the download volume of food and recipe apps rises noticeably around meal times, so there is a second-order interaction between app category and time of day. Another example: young men prefer shooter games, a third-order interaction among age, gender, and game category. How to model these second-, third-, and even higher-order nonlinear interactions is therefore key to more accurate prediction.
The discussion below draws on the article 深入FFM原理与实践.
Directly modeling the second-order interactions gives the degree-2 polynomial model:

$$y(x) = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} w_{ij} x_i x_j$$

where $n$ is the dimension of the feature vector, $x_i$ is the $i$-th feature, and $w_0$, $w_i$, $w_{ij}$ are model parameters. The problem with this approach is the large number of cross-term parameters, $\frac{n(n-1)}{2}$. Moreover, since the sample data is itself sparse, samples where both $x_i$ and $x_j$ are non-zero are even rarer. With too few training samples, the parameters $w_{ij}$ are easily estimated inaccurately, hurting the final model's performance.
How can the training problem of the second-order parameters be solved? Matrix factorization offers one approach. As explained in a talk on Factorization Machines, in model-based collaborative filtering a rating matrix can be factorized into a user matrix and an item matrix, so that every user and every item is represented by a latent vector. For example, in the figure each user and each item is represented by a two-dimensional vector, and the dot product of the two vectors gives the user's rating of the item.
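To make this concrete, here is a minimal NumPy sketch (the 3×4 rating matrix is made up for illustration, not data from the talk): plain gradient descent factorizes the observed ratings into 2-dimensional user and item latent vectors, and their dot products approximately reconstruct the ratings.

import numpy as np

# Toy 3-user x 4-item rating matrix; 0 marks an unobserved rating (illustrative data)
R = np.array([[5., 3., 0., 1.],
              [4., 0., 0., 1.],
              [1., 1., 0., 5.]])
k = 2                                     # latent dimension
rng = np.random.default_rng(0)
U = rng.normal(scale=0.1, size=(3, k))    # user latent vectors
V = rng.normal(scale=0.1, size=(4, k))    # item latent vectors
mask = R > 0
for _ in range(2000):                     # gradient descent on the observed cells only
    err = mask * (R - U @ V.T)            # reconstruction error on observed ratings
    U += 0.01 * (err @ V)
    V += 0.01 * (err.T @ U)
print(np.round(U @ V.T, 2))               # dot products approximate the observed ratings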
FM applies the same idea to feature crossing: each cross-term weight $w_{ij}$ is factorized into the inner product $\langle v_i, v_j \rangle$ of two $k$-dimensional latent vectors, giving the FM model:

$$\hat{y}(x) = w_0 + \sum_{i=1}^{n} w_i x_i + \sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle v_i, v_j \rangle x_i x_j$$

This brings the number of cross-term parameters down from $\frac{n(n-1)}{2}$ to $nk$, and the cross term can be rewritten so that it is computable in $O(kn)$ rather than $O(kn^2)$:

$$\sum_{i=1}^{n} \sum_{j=i+1}^{n} \langle v_i, v_j \rangle x_i x_j = \frac{1}{2} \sum_{f=1}^{k}\left[\left(\sum_{i=1}^{n} v_{i,f} x_i\right)^{2} - \sum_{i=1}^{n} v_{i,f}^{2} x_i^{2}\right]$$
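A quick NumPy check of this identity on random values (illustrative only); the right-hand side is exactly what the FM_Layer implemented later computes:

import numpy as np

rng = np.random.default_rng(42)
n, k = 10, 4
x = rng.random(n)            # feature vector
V = rng.normal(size=(n, k))  # latent vectors v_i

# Naive O(k * n^2) double loop over pairs i < j
naive = sum(V[i] @ V[j] * x[i] * x[j] for i in range(n) for j in range(i + 1, n))

# O(k * n) form: 0.5 * (square of the sum minus the sum of squares), per latent dim
xv = V * x[:, None]          # rows are v_i * x_i, shape (n, k)
fast = 0.5 * np.sum(xv.sum(axis=0) ** 2 - (xv ** 2).sum(axis=0))

print(np.isclose(naive, fast))  # True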
4、DeepFM model build
FM
The figure below shows the structure of the FM component. Roughly, the FM Layer concatenates the first-order features and the second-order cross features and passes them through a Sigmoid to produce the output (read this alongside the FM formula above), so the implementation handles the linear part and the FM cross-feature part separately.
The code is accordingly split into two parts:
- First-order features: Linear
- Second-order features: FM
Deep
Deep architecture diagram
The Deep Module learns high-order feature combinations. In the figure above, the Dense Embeddings are fed into the Hidden Layers through fully connected layers. The Dense Embeddings exist to avoid the parameter explosion that raw one-hot inputs would cause in the DNN, a standard trick in recommendation models.
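As a back-of-the-envelope sketch of this point (the vocabulary sizes below are hypothetical, chosen only for scale): feeding raw one-hot vectors straight into a 256-unit hidden layer costs far more parameters than first embedding each field into 4 dimensions.

# Hypothetical scale: 26 categorical fields with ~10,000 categories each
n_fields, vocab, emb_dim, hidden = 26, 10_000, 4, 256

one_hot_params = n_fields * vocab * hidden   # one-hot straight into Dense(256): 66,560,000
embed_params = n_fields * vocab * emb_dim    # embedding tables: 1,040,000
dense_params = n_fields * emb_dim * hidden   # Dense(256) on concatenated embeddings: 26,624

print(f"one-hot -> Dense(256): {one_hot_params:,}")
print(f"embed   -> Dense(256): {embed_params + dense_params:,}")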
4.1 Input Layers
def build_input_layers(feature_columns):
    # Build a dict of Input layers, returned as two dicts: one for dense and one for sparse features
    dense_input_dict, sparse_input_dict = {}, {}
    for fc in feature_columns:
        if isinstance(fc, SparseFeat):
            sparse_input_dict[fc.name] = Input(shape=(1,), name=fc.name)
        elif isinstance(fc, DenseFeat):
            dense_input_dict[fc.name] = Input(shape=(fc.dimension,), name=fc.name)
    return dense_input_dict, sparse_input_dict
# Dict-style inputs again, which makes assembling the model later easier
# Here the linear and dnn feature columns happen to be identical; in practice choose each according to the scenario, and don't confuse the two
dense_input_dict, sparse_input_dict = build_input_layers(linear_feature_columns + dnn_feature_columns)
print(dense_input_dict["I1"],len(dense_input_dict),type(dense_input_dict))
print(sparse_input_dict["C1"],len(sparse_input_dict),type(sparse_input_dict))
KerasTensor(type_spec=TensorSpec(shape=(None, 1), dtype=tf.float32, name='I1'), name='I1', description="created by layer 'I1'") 13 <class 'dict'>
KerasTensor(type_spec=TensorSpec(shape=(None, 1), dtype=tf.float32, name='C1'), name='C1', description="created by layer 'C1'") 26 <class 'dict'>
# Build the model's input layer; it cannot be a dict, so convert the dicts into a list
# Note: the actual inputs are matched to the Input() layers via the keys of the input-data dict and the name of each Input layer
input_layers = list(dense_input_dict.values()) + list(sparse_input_dict.values())
input_layers[0],len(input_layers),type(input_layers)
(<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'I1')>,
39,
list)
# The reference code passes an input_layers_dict argument, but it is never used
def build_embedding_layers(feature_columns, input_layers_dict, is_linear):
    embedding_layers_dict = dict()
    sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if feature_columns else []
    if is_linear:
        # 1-dimensional embeddings act as the per-category weights of the linear part
        for fc in sparse_feature_columns:
            embedding_layers_dict[fc.name] = Embedding(fc.vocabulary_size, 1, name="1d_emb_" + fc.name)
    else:
        # k-dimensional embeddings feed the FM cross term and the DNN
        for fc in sparse_feature_columns:
            embedding_layers_dict[fc.name] = Embedding(fc.vocabulary_size, fc.embedding_dim, name="kd_emb_" + fc.name)
    return embedding_layers_dict
4.2 Linear & DNN Sparse Features
# Filter the sparse features out of the linear part; they are given 1-dimensional embeddings below
linear_sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), linear_feature_columns))
# Filter out all sparse features that go into the dnn
dnn_sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns))
print(linear_sparse_feature_columns[0],len(linear_sparse_feature_columns))
print(dnn_sparse_feature_columns[0],len(dnn_sparse_feature_columns))
SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x0000026A0E4F3C70>, embedding_name='C1', group_name='default_group', trainable=True) 26
SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x0000026A0E4F7D90>, embedding_name='C1', group_name='default_group', trainable=True) 26
4.3 Linear Logits
- linear_logits: the linear computation, i.e. the first half of the FM formula, $w_1 x_1 + w_2 x_2 + \cdots + w_n x_n + b$. It is implemented by the get_linear_logits function below. The features of this part form a one-dimensional vector composed of the numeric features and the one-hot encodings of the categorical features; in practice, place whatever first-order features your business needs here (the dense features are not mandatory; numeric features may also be bucketized and then treated as categorical).
def get_linear_logits(dense_input_dict, sparse_input_dict, sparse_feature_columns):
    # Dense features: concatenate and apply a single Dense(1)
    concat_dense_inputs = Concatenate(axis=1)(list(dense_input_dict.values()))
    dense_logits_output = Dense(1)(concat_dense_inputs)
    # Sparse features: 1-dimensional embeddings serve as the per-category linear weights
    linear_embedding_layers = build_embedding_layers(sparse_feature_columns, sparse_input_dict, is_linear=True)
    # Flatten each 1-d embedding, then sum them all
    sparse_1d_embed = []
    for fc in sparse_feature_columns:
        feat_input = sparse_input_dict[fc.name]
        embed = Flatten()(linear_embedding_layers[fc.name](feat_input))  # B(batch) x 1
        sparse_1d_embed.append(embed)
    sparse_logits_output = Add()(sparse_1d_embed)
    # dense + sparse
    linear_logits = Add()([dense_logits_output, sparse_logits_output])
    return linear_logits
linear_logits = get_linear_logits(dense_input_dict,sparse_input_dict,linear_sparse_feature_columns)
linear_logits
<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'add_5')>
4.4 FM Logits
- fm_logits: handles the sparse features. Each feature first passes through its embedding; the FM layer then crosses every pair of feature embeddings to obtain new feature vectors, and finally computes the logits of the cross features.
# Build the k-dimensional embedding layers, returned as a dict for easier model assembly
# These embedding layers serve both the FM cross part and the DNN input
embedding_layers = build_embedding_layers(dnn_feature_columns, sparse_input_dict, is_linear=False)
embedding_layers["C1"],len(embedding_layers)
(<tensorflow.python.keras.layers.embeddings.Embedding at 0x26a0e599790>, 26)
Because only the second-order term is considered, this part directly uses the simplified formula derived above, without the first-order term:
class FM_Layer(Layer):
    def __init__(self):
        super(FM_Layer, self).__init__()

    def call(self, inputs):
        # Simplified formula: 0.5 * sum(square of the sum - sum of the squares)  =>  B x 1
        concated_embeds_value = inputs                                                               # B x n x k
        square_of_sum = tf.square(tf.reduce_sum(concated_embeds_value, axis=1, keepdims=True))       # B x 1 x k
        sum_of_square = tf.reduce_sum(concated_embeds_value * concated_embeds_value, axis=1, keepdims=True)  # B x 1 x k
        cross_term = square_of_sum - sum_of_square                                                   # B x 1 x k
        cross_term = 0.5 * tf.reduce_sum(cross_term, axis=2, keepdims=False)                         # B x 1
        return cross_term

    def compute_output_shape(self, input_shape):
        return (None, 1)
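A quick sanity check (random embeddings, not part of the original notebook) that FM_Layer matches a brute-force pairwise computation:

# Compare FM_Layer against an explicit O(n^2) loop over feature pairs
emb = tf.random.normal((2, 5, 4))   # B=2 samples, n=5 features, k=4
fast = FM_Layer()(emb)              # B x 1

brute = tf.zeros((2, 1))
for i in range(5):
    for j in range(i + 1, 5):
        # dot product <v_i, v_j> per sample
        brute += tf.reduce_sum(emb[:, i, :] * emb[:, j, :], axis=1, keepdims=True)
print(np.allclose(fast.numpy(), brute.numpy(), atol=1e-5))  # True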
def get_fm_logits(sparse_input_dict, sparse_feature_columns, dnn_embedding_layers):
    sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), sparse_feature_columns))
    sparse_kd_embed = []
    for fc in sparse_feature_columns:
        feat_input = sparse_input_dict[fc.name]
        # KerasTensor(type_spec=TensorSpec(shape=(None, 1, 4), dtype=tf.float32, name=None),
        # name='kd_emb_C1/embedding_lookup/Identity_1:0', description="created by layer 'kd_emb_C1'")
        _embed = dnn_embedding_layers[fc.name](feat_input)  # B x 1 x k
        sparse_kd_embed.append(_embed)
    # Concatenate all sparse embeddings into an (n, k) matrix per sample, where n is the number of features and k the embedding size
    concat_sparse_kd_embed = Concatenate(axis=1)(sparse_kd_embed)  # B x n x k
    fm_cross_out = FM_Layer()(concat_sparse_kd_embed)
    return fm_cross_out
# Unlike linear_logits, fm_logits only handles the sparse features
fm_logits = get_fm_logits(sparse_input_dict, dnn_sparse_feature_columns, embedding_layers)  # second-order term only
fm_logits
<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'fm__layer_2')>
4.5 DNN Logits
- dnn_logits: also handles the sparse features. Each passes through its embedding first; the embeddings are then concatenated into one vector (see the code, or the model structure diagram below) and fed through a dnn, which learns implicit feature crosses among the categorical features and outputs the logits.
def get_dnn_logits(sparse_input_dict, sparse_feature_columns, dnn_embedding_layers):
    # Filter out the sparse features
    sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), sparse_feature_columns))
    sparse_kd_embed = []
    for fc in sparse_feature_columns:
        feat_input = sparse_input_dict[fc.name]
        _embed = dnn_embedding_layers[fc.name](feat_input)
        _embed = Flatten()(_embed)
        sparse_kd_embed.append(_embed)
    concat_sparse_kd_embed = Concatenate(axis=1)(sparse_kd_embed)
    # Three-layer MLP with decreasing dropout, then a final Dense(1) for the logits
    mlp_out = Dropout(0.5)(Dense(256, activation="relu")(concat_sparse_kd_embed))
    mlp_out = Dropout(0.3)(Dense(256, activation="relu")(mlp_out))
    mlp_out = Dropout(0.1)(Dense(256, activation="relu")(mlp_out))
    dnn_out = Dense(1)(mlp_out)
    return dnn_out
dnn_logits = get_dnn_logits(sparse_input_dict, dnn_sparse_feature_columns, embedding_layers)
dnn_logits
<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'dense_6')>
4.6 Output Layers
output_logits = Add()([linear_logits, fm_logits, dnn_logits])
output_layers = Activation("sigmoid")(output_logits)
model = Model(input_layers, output_layers)
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
C1 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C2 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C3 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C4 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C5 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C6 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C7 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C8 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C9 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C10 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C11 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C12 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C13 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C14 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C15 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C16 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C17 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C18 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C19 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C20 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C21 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C22 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C23 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C24 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C25 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
C26 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
kd_emb_C1 (Embedding) (None, 1, 4) 108 C1[0][0]
C1[0][0]
__________________________________________________________________________________________________
kd_emb_C2 (Embedding) (None, 1, 4) 368 C2[0][0]
C2[0][0]
__________________________________________________________________________________________________
kd_emb_C3 (Embedding) (None, 1, 4) 688 C3[0][0]
C3[0][0]
__________________________________________________________________________________________________
kd_emb_C4 (Embedding) (None, 1, 4) 628 C4[0][0]
C4[0][0]
__________________________________________________________________________________________________
kd_emb_C5 (Embedding) (None, 1, 4) 48 C5[0][0]
C5[0][0]
__________________________________________________________________________________________________
kd_emb_C6 (Embedding) (None, 1, 4) 28 C6[0][0]
C6[0][0]
__________________________________________________________________________________________________
kd_emb_C7 (Embedding) (None, 1, 4) 732 C7[0][0]
C7[0][0]
__________________________________________________________________________________________________
kd_emb_C8 (Embedding) (None, 1, 4) 76 C8[0][0]
C8[0][0]
__________________________________________________________________________________________________
kd_emb_C9 (Embedding) (None, 1, 4) 8 C9[0][0]
C9[0][0]
__________________________________________________________________________________________________
kd_emb_C10 (Embedding) (None, 1, 4) 568 C10[0][0]
C10[0][0]
__________________________________________________________________________________________________
kd_emb_C11 (Embedding) (None, 1, 4) 692 C11[0][0]
C11[0][0]
__________________________________________________________________________________________________
kd_emb_C12 (Embedding) (None, 1, 4) 680 C12[0][0]
C12[0][0]
__________________________________________________________________________________________________
kd_emb_C13 (Embedding) (None, 1, 4) 664 C13[0][0]
C13[0][0]
__________________________________________________________________________________________________
kd_emb_C14 (Embedding) (None, 1, 4) 56 C14[0][0]
C14[0][0]
__________________________________________________________________________________________________
kd_emb_C15 (Embedding) (None, 1, 4) 680 C15[0][0]
C15[0][0]
__________________________________________________________________________________________________
kd_emb_C16 (Embedding) (None, 1, 4) 672 C16[0][0]
C16[0][0]
__________________________________________________________________________________________________
kd_emb_C17 (Embedding) (None, 1, 4) 36 C17[0][0]
C17[0][0]
__________________________________________________________________________________________________
kd_emb_C18 (Embedding) (None, 1, 4) 508 C18[0][0]
C18[0][0]
__________________________________________________________________________________________________
kd_emb_C19 (Embedding) (None, 1, 4) 176 C19[0][0]
C19[0][0]
__________________________________________________________________________________________________
kd_emb_C20 (Embedding) (None, 1, 4) 16 C20[0][0]
C20[0][0]
__________________________________________________________________________________________________
kd_emb_C21 (Embedding) (None, 1, 4) 676 C21[0][0]
C21[0][0]
__________________________________________________________________________________________________
kd_emb_C22 (Embedding) (None, 1, 4) 24 C22[0][0]
C22[0][0]
__________________________________________________________________________________________________
kd_emb_C23 (Embedding) (None, 1, 4) 40 C23[0][0]
C23[0][0]
__________________________________________________________________________________________________
kd_emb_C24 (Embedding) (None, 1, 4) 500 C24[0][0]
C24[0][0]
__________________________________________________________________________________________________
kd_emb_C25 (Embedding) (None, 1, 4) 80 C25[0][0]
C25[0][0]
__________________________________________________________________________________________________
kd_emb_C26 (Embedding) (None, 1, 4) 360 C26[0][0]
C26[0][0]
__________________________________________________________________________________________________
flatten_78 (Flatten) (None, 4) 0 kd_emb_C1[5][0]
__________________________________________________________________________________________________
flatten_79 (Flatten) (None, 4) 0 kd_emb_C2[5][0]
__________________________________________________________________________________________________
flatten_80 (Flatten) (None, 4) 0 kd_emb_C3[5][0]
__________________________________________________________________________________________________
flatten_81 (Flatten) (None, 4) 0 kd_emb_C4[5][0]
__________________________________________________________________________________________________
flatten_82 (Flatten) (None, 4) 0 kd_emb_C5[5][0]
__________________________________________________________________________________________________
flatten_83 (Flatten) (None, 4) 0 kd_emb_C6[5][0]
__________________________________________________________________________________________________
flatten_84 (Flatten) (None, 4) 0 kd_emb_C7[5][0]
__________________________________________________________________________________________________
flatten_85 (Flatten) (None, 4) 0 kd_emb_C8[5][0]
__________________________________________________________________________________________________
flatten_86 (Flatten) (None, 4) 0 kd_emb_C9[5][0]
__________________________________________________________________________________________________
flatten_87 (Flatten) (None, 4) 0 kd_emb_C10[5][0]
__________________________________________________________________________________________________
flatten_88 (Flatten) (None, 4) 0 kd_emb_C11[5][0]
__________________________________________________________________________________________________
flatten_89 (Flatten) (None, 4) 0 kd_emb_C12[5][0]
__________________________________________________________________________________________________
flatten_90 (Flatten) (None, 4) 0 kd_emb_C13[5][0]
__________________________________________________________________________________________________
flatten_91 (Flatten) (None, 4) 0 kd_emb_C14[5][0]
__________________________________________________________________________________________________
flatten_92 (Flatten) (None, 4) 0 kd_emb_C15[5][0]
__________________________________________________________________________________________________
flatten_93 (Flatten) (None, 4) 0 kd_emb_C16[5][0]
__________________________________________________________________________________________________
flatten_94 (Flatten) (None, 4) 0 kd_emb_C17[5][0]
__________________________________________________________________________________________________
flatten_95 (Flatten) (None, 4) 0 kd_emb_C18[5][0]
__________________________________________________________________________________________________
flatten_96 (Flatten) (None, 4) 0 kd_emb_C19[5][0]
__________________________________________________________________________________________________
flatten_97 (Flatten) (None, 4) 0 kd_emb_C20[5][0]
__________________________________________________________________________________________________
flatten_98 (Flatten) (None, 4) 0 kd_emb_C21[5][0]
__________________________________________________________________________________________________
flatten_99 (Flatten) (None, 4) 0 kd_emb_C22[5][0]
__________________________________________________________________________________________________
flatten_100 (Flatten) (None, 4) 0 kd_emb_C23[5][0]
__________________________________________________________________________________________________
flatten_101 (Flatten) (None, 4) 0 kd_emb_C24[5][0]
__________________________________________________________________________________________________
flatten_102 (Flatten) (None, 4) 0 kd_emb_C25[5][0]
__________________________________________________________________________________________________
flatten_103 (Flatten) (None, 4) 0 kd_emb_C26[5][0]
__________________________________________________________________________________________________
concatenate_7 (Concatenate) (None, 104) 0 flatten_78[0][0]
flatten_79[0][0]
flatten_80[0][0]
flatten_81[0][0]
flatten_82[0][0]
flatten_83[0][0]
flatten_84[0][0]
flatten_85[0][0]
flatten_86[0][0]
flatten_87[0][0]
flatten_88[0][0]
flatten_89[0][0]
flatten_90[0][0]
flatten_91[0][0]
flatten_92[0][0]
flatten_93[0][0]
flatten_94[0][0]
flatten_95[0][0]
flatten_96[0][0]
flatten_97[0][0]
flatten_98[0][0]
flatten_99[0][0]
flatten_100[0][0]
flatten_101[0][0]
flatten_102[0][0]
flatten_103[0][0]
__________________________________________________________________________________________________
dense_3 (Dense) (None, 256) 26880 concatenate_7[0][0]
__________________________________________________________________________________________________
dropout (Dropout) (None, 256) 0 dense_3[0][0]
__________________________________________________________________________________________________
dense_4 (Dense) (None, 256) 65792 dropout[0][0]
__________________________________________________________________________________________________
I1 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I2 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I3 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I4 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I5 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I6 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I7 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I8 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I9 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I10 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I11 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I12 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
I13 (InputLayer) [(None, 1)] 0
__________________________________________________________________________________________________
1d_emb_C1 (Embedding) (None, 1, 1) 27 C1[0][0]
__________________________________________________________________________________________________
1d_emb_C2 (Embedding) (None, 1, 1) 92 C2[0][0]
__________________________________________________________________________________________________
1d_emb_C3 (Embedding) (None, 1, 1) 172 C3[0][0]
__________________________________________________________________________________________________
1d_emb_C4 (Embedding) (None, 1, 1) 157 C4[0][0]
__________________________________________________________________________________________________
1d_emb_C5 (Embedding) (None, 1, 1) 12 C5[0][0]
__________________________________________________________________________________________________
1d_emb_C6 (Embedding) (None, 1, 1) 7 C6[0][0]
__________________________________________________________________________________________________
1d_emb_C7 (Embedding) (None, 1, 1) 183 C7[0][0]
__________________________________________________________________________________________________
1d_emb_C8 (Embedding) (None, 1, 1) 19 C8[0][0]
__________________________________________________________________________________________________
1d_emb_C9 (Embedding) (None, 1, 1) 2 C9[0][0]
__________________________________________________________________________________________________
1d_emb_C10 (Embedding) (None, 1, 1) 142 C10[0][0]
__________________________________________________________________________________________________
1d_emb_C11 (Embedding) (None, 1, 1) 173 C11[0][0]
__________________________________________________________________________________________________
1d_emb_C12 (Embedding) (None, 1, 1) 170 C12[0][0]
__________________________________________________________________________________________________
1d_emb_C13 (Embedding) (None, 1, 1) 166 C13[0][0]
__________________________________________________________________________________________________
1d_emb_C14 (Embedding) (None, 1, 1) 14 C14[0][0]
__________________________________________________________________________________________________
1d_emb_C15 (Embedding) (None, 1, 1) 170 C15[0][0]
__________________________________________________________________________________________________
1d_emb_C16 (Embedding) (None, 1, 1) 168 C16[0][0]
__________________________________________________________________________________________________
1d_emb_C17 (Embedding) (None, 1, 1) 9 C17[0][0]
__________________________________________________________________________________________________
1d_emb_C18 (Embedding) (None, 1, 1) 127 C18[0][0]
__________________________________________________________________________________________________
1d_emb_C19 (Embedding) (None, 1, 1) 44 C19[0][0]
__________________________________________________________________________________________________
1d_emb_C20 (Embedding) (None, 1, 1) 4 C20[0][0]
__________________________________________________________________________________________________
1d_emb_C21 (Embedding) (None, 1, 1) 169 C21[0][0]
__________________________________________________________________________________________________
1d_emb_C22 (Embedding) (None, 1, 1) 6 C22[0][0]
__________________________________________________________________________________________________
1d_emb_C23 (Embedding) (None, 1, 1) 10 C23[0][0]
__________________________________________________________________________________________________
1d_emb_C24 (Embedding) (None, 1, 1) 125 C24[0][0]
__________________________________________________________________________________________________
1d_emb_C25 (Embedding) (None, 1, 1) 20 C25[0][0]
__________________________________________________________________________________________________
1d_emb_C26 (Embedding) (None, 1, 1) 90 C26[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout) (None, 256) 0 dense_4[0][0]
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 13) 0 I1[0][0]
I2[0][0]
I3[0][0]
I4[0][0]
I5[0][0]
I6[0][0]
I7[0][0]
I8[0][0]
I9[0][0]
I10[0][0]
I11[0][0]
I12[0][0]
I13[0][0]
__________________________________________________________________________________________________
flatten_52 (Flatten) (None, 1) 0 1d_emb_C1[0][0]
__________________________________________________________________________________________________
flatten_53 (Flatten) (None, 1) 0 1d_emb_C2[0][0]
__________________________________________________________________________________________________
flatten_54 (Flatten) (None, 1) 0 1d_emb_C3[0][0]
__________________________________________________________________________________________________
flatten_55 (Flatten) (None, 1) 0 1d_emb_C4[0][0]
__________________________________________________________________________________________________
flatten_56 (Flatten) (None, 1) 0 1d_emb_C5[0][0]
__________________________________________________________________________________________________
flatten_57 (Flatten) (None, 1) 0 1d_emb_C6[0][0]
__________________________________________________________________________________________________
flatten_58 (Flatten) (None, 1) 0 1d_emb_C7[0][0]
__________________________________________________________________________________________________
flatten_59 (Flatten) (None, 1) 0 1d_emb_C8[0][0]
__________________________________________________________________________________________________
flatten_60 (Flatten) (None, 1) 0 1d_emb_C9[0][0]
__________________________________________________________________________________________________
flatten_61 (Flatten) (None, 1) 0 1d_emb_C10[0][0]
__________________________________________________________________________________________________
flatten_62 (Flatten) (None, 1) 0 1d_emb_C11[0][0]
__________________________________________________________________________________________________
flatten_63 (Flatten) (None, 1) 0 1d_emb_C12[0][0]
__________________________________________________________________________________________________
flatten_64 (Flatten) (None, 1) 0 1d_emb_C13[0][0]
__________________________________________________________________________________________________
flatten_65 (Flatten) (None, 1) 0 1d_emb_C14[0][0]
__________________________________________________________________________________________________
flatten_66 (Flatten) (None, 1) 0 1d_emb_C15[0][0]
__________________________________________________________________________________________________
flatten_67 (Flatten) (None, 1) 0 1d_emb_C16[0][0]
__________________________________________________________________________________________________
flatten_68 (Flatten) (None, 1) 0 1d_emb_C17[0][0]
__________________________________________________________________________________________________
flatten_69 (Flatten) (None, 1) 0 1d_emb_C18[0][0]
__________________________________________________________________________________________________
flatten_70 (Flatten) (None, 1) 0 1d_emb_C19[0][0]
__________________________________________________________________________________________________
flatten_71 (Flatten) (None, 1) 0 1d_emb_C20[0][0]
__________________________________________________________________________________________________
flatten_72 (Flatten) (None, 1) 0 1d_emb_C21[0][0]
__________________________________________________________________________________________________
flatten_73 (Flatten) (None, 1) 0 1d_emb_C22[0][0]
__________________________________________________________________________________________________
flatten_74 (Flatten) (None, 1) 0 1d_emb_C23[0][0]
__________________________________________________________________________________________________
flatten_75 (Flatten) (None, 1) 0 1d_emb_C24[0][0]
__________________________________________________________________________________________________
flatten_76 (Flatten) (None, 1) 0 1d_emb_C25[0][0]
__________________________________________________________________________________________________
flatten_77 (Flatten) (None, 1) 0 1d_emb_C26[0][0]
__________________________________________________________________________________________________
dense_5 (Dense) (None, 256) 65792 dropout_1[0][0]
__________________________________________________________________________________________________
dense_2 (Dense) (None, 1) 14 concatenate_2[0][0]
__________________________________________________________________________________________________
add_4 (Add) (None, 1) 0 flatten_52[0][0]
flatten_53[0][0]
flatten_54[0][0]
flatten_55[0][0]
flatten_56[0][0]
flatten_57[0][0]
flatten_58[0][0]
flatten_59[0][0]
flatten_60[0][0]
flatten_61[0][0]
flatten_62[0][0]
flatten_63[0][0]
flatten_64[0][0]
flatten_65[0][0]
flatten_66[0][0]
flatten_67[0][0]
flatten_68[0][0]
flatten_69[0][0]
flatten_70[0][0]
flatten_71[0][0]
flatten_72[0][0]
flatten_73[0][0]
flatten_74[0][0]
flatten_75[0][0]
flatten_76[0][0]
flatten_77[0][0]
__________________________________________________________________________________________________
concatenate_6 (Concatenate) (None, 26, 4) 0 kd_emb_C1[4][0]
kd_emb_C2[4][0]
kd_emb_C3[4][0]
kd_emb_C4[4][0]
kd_emb_C5[4][0]
kd_emb_C6[4][0]
kd_emb_C7[4][0]
kd_emb_C8[4][0]
kd_emb_C9[4][0]
kd_emb_C10[4][0]
kd_emb_C11[4][0]
kd_emb_C12[4][0]
kd_emb_C13[4][0]
kd_emb_C14[4][0]
kd_emb_C15[4][0]
kd_emb_C16[4][0]
kd_emb_C17[4][0]
kd_emb_C18[4][0]
kd_emb_C19[4][0]
kd_emb_C20[4][0]
kd_emb_C21[4][0]
kd_emb_C22[4][0]
kd_emb_C23[4][0]
kd_emb_C24[4][0]
kd_emb_C25[4][0]
kd_emb_C26[4][0]
__________________________________________________________________________________________________
dropout_2 (Dropout) (None, 256) 0 dense_5[0][0]
__________________________________________________________________________________________________
add_5 (Add) (None, 1) 0 dense_2[0][0]
add_4[0][0]
__________________________________________________________________________________________________
fm__layer_2 (FM_Layer) (None, 1) 0 concatenate_6[0][0]
__________________________________________________________________________________________________
dense_6 (Dense) (None, 1) 257 dropout_2[0][0]
__________________________________________________________________________________________________
add_6 (Add) (None, 1) 0 add_5[0][0]
fm__layer_2[0][0]
dense_6[0][0]
__________________________________________________________________________________________________
activation (Activation) (None, 1) 0 add_6[0][0]
==================================================================================================
Total params: 170,125
Trainable params: 170,125
Non-trainable params: 0
__________________________________________________________________________________________________
model.compile(optimizer="adam",
loss="binary_crossentropy",
metrics=["binary_crossentropy", tf.keras.metrics.AUC(name='auc')])
# Convert the input data into dict form
train_model_input = {name: data[name] for name in dense_features + sparse_features}
# Train the model
model.fit(train_model_input, train_data['label'].values,
batch_size=64, epochs=10, validation_split=0.3)
Epoch 1/10
3/3 [==============================] - 6s 418ms/step - loss: 0.6597 - binary_crossentropy: 0.6597 - auc: 0.5713 - val_loss: 0.8220 - val_binary_crossentropy: 0.8220 - val_auc: 0.5642
Epoch 2/10
3/3 [==============================] - 0s 53ms/step - loss: 0.6426 - binary_crossentropy: 0.6426 - auc: 0.5833 - val_loss: 0.8138 - val_binary_crossentropy: 0.8138 - val_auc: 0.5693
Epoch 3/10
3/3 [==============================] - 0s 53ms/step - loss: 0.6861 - binary_crossentropy: 0.6861 - auc: 0.5929 - val_loss: 0.8050 - val_binary_crossentropy: 0.8050 - val_auc: 0.5687
Epoch 4/10
3/3 [==============================] - 0s 52ms/step - loss: 0.5978 - binary_crossentropy: 0.5978 - auc: 0.6370 - val_loss: 0.7967 - val_binary_crossentropy: 0.7967 - val_auc: 0.5712
Epoch 5/10
3/3 [==============================] - 0s 51ms/step - loss: 0.5634 - binary_crossentropy: 0.5634 - auc: 0.6670 - val_loss: 0.7855 - val_binary_crossentropy: 0.7855 - val_auc: 0.5802
Epoch 6/10
3/3 [==============================] - 0s 54ms/step - loss: 0.5294 - binary_crossentropy: 0.5294 - auc: 0.7190 - val_loss: 0.7734 - val_binary_crossentropy: 0.7734 - val_auc: 0.5860
Epoch 7/10
3/3 [==============================] - 0s 54ms/step - loss: 0.4167 - binary_crossentropy: 0.4167 - auc: 0.8086 - val_loss: 0.7695 - val_binary_crossentropy: 0.7695 - val_auc: 0.5956
Epoch 8/10
3/3 [==============================] - 0s 52ms/step - loss: 0.3800 - binary_crossentropy: 0.3800 - auc: 0.8545 - val_loss: 0.7757 - val_binary_crossentropy: 0.7757 - val_auc: 0.6117
Epoch 9/10
3/3 [==============================] - 0s 52ms/step - loss: 0.2956 - binary_crossentropy: 0.2956 - auc: 0.9235 - val_loss: 0.7982 - val_binary_crossentropy: 0.7982 - val_auc: 0.6348
Epoch 10/10
3/3 [==============================] - 0s 51ms/step - loss: 0.1592 - binary_crossentropy: 0.1592 - auc: 0.9863 - val_loss: 0.8413 - val_binary_crossentropy: 0.8413 - val_auc: 0.6457
5、Thoughts
If FM is trained with stochastic gradient descent (SGD), write out the gradient of each model parameter and the complexity of training the FM parameters.
The training complexity is $O(kn)$, as derived below.
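For the gradient part of the question, the standard result from Rendle's FM paper gives the derivative of $\hat{y}(x)$ with respect to each parameter:

$$\frac{\partial \hat{y}(x)}{\partial \theta} = \begin{cases} 1, & \text{if } \theta = w_0 \\ x_i, & \text{if } \theta = w_i \\ x_i \sum_{j=1}^{n} v_{j,f} x_j - v_{i,f} x_i^{2}, & \text{if } \theta = v_{i,f} \end{cases}$$

Since $\sum_{j=1}^{n} v_{j,f} x_j$ does not depend on $i$, it can be precomputed once per factor $f$; each SGD update therefore costs $O(kn)$ overall.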
Based on your understanding of the figure below, what do the differently colored nodes in the Sparse Features layer represent?
The yellow nodes are the entries equal to 1: within each Field, exactly one position is 1 (the one-hot index), and the corresponding gray nodes are 0.
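A tiny illustration (a hypothetical field with four categories): within one field the one-hot vector has exactly one 1, the yellow node, and 0s everywhere else, the gray nodes.

# Hypothetical field with 4 categories; category index 2 is the active one
field_vocab = 4
active_index = 2
one_hot = np.eye(field_vocab, dtype=int)[active_index]
print(one_hot)  # [0 0 1 0] -> the single 1 is the yellow node; the 0s are the gray nodes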