Task4 NFM


NFM

The paper's abstract points out the shortcomings of FM and of earlier deep-network models:

  • FM: limited to linear modeling and second-order interactions; not expressive enough to capture the nonlinear and complex inherent structure of real-world data

  • Deep networks: models such as Wide&Deep and DeepCross simply concatenate feature embedding vectors, which by itself captures no interactions between features, while the deep nonlinear architectures (DNNs) that could learn such interactions are hard to train and optimize

NFM drops the practice of feeding the concatenated embedding vectors straight into the neural network; instead it inserts a Bi-Interaction layer after the embedding layer to model second-order feature combinations. This enriches the information carried by the low-level input representation and greatly improves the downstream hidden layers' ability to learn high-order nonlinear feature combinations.

import warnings
warnings.filterwarnings("ignore")
import itertools
import pandas as pd
import numpy as np
from tqdm import tqdm
from collections import namedtuple

import tensorflow as tf
from tensorflow import keras 
from tensorflow.keras.layers import *
from tensorflow.keras.models import *

from sklearn.model_selection import train_test_split
from sklearn.preprocessing import  MinMaxScaler, LabelEncoder

from deepctr.feature_column import SparseFeat, DenseFeat, VarLenSparseFeat
# Enable GPU memory growth; otherwise TensorFlow may error out
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    try:
        # Currently, memory growth needs to be the same across GPUs
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # Memory growth must be set before GPUs have been initialized
        print(e)
1 Physical GPUs, 1 Logical GPUs

1、Load Dataset

data = pd.read_csv("../代码/data/criteo_sample.txt")
columns = data.columns.values
columns
array(['label', 'I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9',
       'I10', 'I11', 'I12', 'I13', 'C1', 'C2', 'C3', 'C4', 'C5', 'C6',
       'C7', 'C8', 'C9', 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16',
       'C17', 'C18', 'C19', 'C20', 'C21', 'C22', 'C23', 'C24', 'C25',
       'C26'], dtype=object)
# Split columns into dense and sparse features
dense_features = [i for i in columns if i.startswith('I')]
sparse_features = [i for i in columns if i.startswith('C')]
print(dense_features)
print(sparse_features)
['I1', 'I2', 'I3', 'I4', 'I5', 'I6', 'I7', 'I8', 'I9', 'I10', 'I11', 'I12', 'I13']
['C1', 'C2', 'C3', 'C4', 'C5', 'C6', 'C7', 'C8', 'C9', 'C10', 'C11', 'C12', 'C13', 'C14', 'C15', 'C16', 'C17', 'C18', 'C19', 'C20', 'C21', 'C22', 'C23', 'C24', 'C25', 'C26']

2、Data Process

# Basic feature processing: fill missing values, transform numeric values, label-encode categoricals
def data_process(data_df, dense_features, sparse_features):
    data_df[dense_features] = data_df[dense_features].fillna(0.0)
    for f in dense_features:
        data_df[f] = data_df[f].apply(lambda x: np.log(x+1) if x > -1 else -1)
        
    data_df[sparse_features] = data_df[sparse_features].fillna("-1")
    for f in sparse_features:
        lbe = LabelEncoder()
        data_df[f] = lbe.fit_transform(data_df[f])
    
    return data_df[dense_features + sparse_features]
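The numeric transform above can be sanity-checked in isolation. A minimal sketch (the helper name `dense_transform` is mine; it mirrors the lambda in `data_process`):

```python
import numpy as np

def dense_transform(x):
    # mirrors the lambda in data_process: log(x + 1) when x > -1, else -1
    return np.log(x + 1) if x > -1 else -1

print(dense_transform(0.0))       # 0.0: filled missing values map to log(1) = 0
print(dense_transform(np.e - 1))  # approx 1.0
print(dense_transform(-5.0))      # -1: guarded branch for out-of-range values
```

Note that `fillna(0.0)` plus this transform sends missing dense values to exactly 0.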
# simple data preprocessing
train_data = data_process(data, dense_features, sparse_features)
train_data['label'] = data['label']
train_data.head()
I1 I2 I3 I4 I5 I6 I7 I8 I9 I10 ... C18 C19 C20 C21 C22 C23 C24 C25 C26 label
0 0.0 1.386294 5.564520 0.000000 9.779567 0.000000 0.000000 3.526361 0.000000 0.0 ... 66 0 0 3 0 1 96 0 0 0
1 0.0 -1.000000 2.995732 3.583519 10.317318 5.513429 0.693147 3.583519 5.081404 0.0 ... 52 0 0 47 0 7 112 0 0 0
2 0.0 0.000000 1.098612 2.564949 7.607878 5.105945 1.945910 3.583519 6.261492 0.0 ... 49 0 0 25 0 6 53 0 0 0
3 0.0 2.639057 0.693147 1.609438 9.731334 5.303305 1.791759 1.609438 3.401197 0.0 ... 37 0 0 156 0 0 32 0 0 0
4 0.0 0.000000 4.653960 3.332205 7.596392 4.962845 1.609438 3.496508 3.637586 0.0 ... 14 5 3 9 0 0 5 1 47 0

5 rows × 40 columns

# Same treatment as DeepFM, Wide&Deep, and DeepCrossing: the features are split into a Linear part and a DNN part
linear_feature_columns = [
    SparseFeat(feat, vocabulary_size=data[feat].nunique(),embedding_dim=4)
    for i,feat in enumerate(sparse_features)
]+[
    DenseFeat(feat, 1) for feat in dense_features
]
linear_feature_columns[0],len(linear_feature_columns)
(SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x000002C0F38E6580>, embedding_name='C1', group_name='default_group', trainable=True),
 39)
dnn_feature_columns = [
    SparseFeat(feat, vocabulary_size=data[feat].nunique(),embedding_dim=4)
    for i,feat in enumerate(sparse_features)
]+[
    DenseFeat(feat, 1) for feat in dense_features
]
dnn_feature_columns[0],len(dnn_feature_columns)
(SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x000002C093E1D070>, embedding_name='C1', group_name='default_group', trainable=True),
 39)

3、Model Theory

3.1 NFM Model

As in FM, the first and second terms form the linear-regression part; the third term replaces FM's second-order inner product of latent vectors with the function $f(x)$.
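The NFM scoring function being described (reconstructed here from the NFM paper, since the notebook's equation image did not survive):

$$\hat{y}_{NFM}(\mathbf{x}) = w_0 + \sum_{i=1}^{n} w_i x_i + f(\mathbf{x})$$

where $f(\mathbf{x})$ is modeled by the Bi-Interaction pooling layer followed by the hidden layers.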

3.2 Overall Model Architecture

  • Figure 2: Neural Factorization Machines model (the first-order linear regression part is not shown for clarity).

The paper annotates this figure explicitly: it only shows the modeling of $f(x)$ and omits the first-order part.

4、NFM model build

4.1 Input Layers

def build_input_layers(feature_columns):
    # Build a dict of Input layers and return them as two dicts: dense and sparse
    dense_input_dict, sparse_input_dict = {}, {}

    for fc in feature_columns:
        if isinstance(fc, SparseFeat):
            sparse_input_dict[fc.name] = Input(shape=(1, ), name=fc.name)
        elif isinstance(fc, DenseFeat):
            dense_input_dict[fc.name] = Input(shape=(fc.dimension, ), name=fc.name)
        
    return dense_input_dict, sparse_input_dict
# As before, the inputs are organized as dicts, which makes model construction easier later
# Here the linear and DNN feature lists happen to be identical; in practice choose each per scenario, and don't conflate the two
dense_input_dict, sparse_input_dict = build_input_layers(linear_feature_columns + dnn_feature_columns)
print(dense_input_dict["I1"],len(dense_input_dict),type(dense_input_dict))
print(sparse_input_dict["C1"],len(sparse_input_dict),type(sparse_input_dict))
KerasTensor(type_spec=TensorSpec(shape=(None, 1), dtype=tf.float32, name='I1'), name='I1', description="created by layer 'I1'") 13 <class 'dict'>
KerasTensor(type_spec=TensorSpec(shape=(None, 1), dtype=tf.float32, name='C1'), name='C1', description="created by layer 'C1'") 26 <class 'dict'>
# Build the model's input-layer list: Model() cannot take the dict directly, so convert the dict values into a list
# Note: at feed time, actual inputs are matched to Input() layers via the keys of the input dict and each Input layer's name
input_layers = list(dense_input_dict.values()) + list(sparse_input_dict.values())
input_layers[0],len(input_layers),type(input_layers)
(<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'I1')>,
 39,
 list)
# The reference implementation passes an input_layers_dict argument, but it is never used
def build_embedding_layers(feature_columns, input_layers_dict, is_linear):
    
    embedding_layers_dict = dict()
    
    sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), feature_columns)) if feature_columns else []
    
    if is_linear:
        for fc in sparse_feature_columns:
            embedding_layers_dict[fc.name] = Embedding(fc.vocabulary_size, 1, name="1d_emb_"+fc.name)
    else:
        for fc in sparse_feature_columns:
            embedding_layers_dict[fc.name] = Embedding(fc.vocabulary_size, fc.embedding_dim, name="kd_emb_"+fc.name)
    return embedding_layers_dict
    

4.2 Linear & DNN Sparse Features

# Select the sparse features from the linear part; they get 1-dimensional embeddings below
linear_sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), linear_feature_columns))
# Select all sparse features that feed into the DNN
dnn_sparse_feature_columns = list(filter(lambda x: isinstance(x, SparseFeat), dnn_feature_columns))
print(linear_sparse_feature_columns[0],len(linear_sparse_feature_columns))
print(dnn_sparse_feature_columns[0],len(dnn_sparse_feature_columns))
SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x000002C0F38E6580>, embedding_name='C1', group_name='default_group', trainable=True) 26
SparseFeat(name='C1', vocabulary_size=27, embedding_dim=4, use_hash=False, dtype='int32', embeddings_initializer=<tensorflow.python.keras.initializers.initializers_v1.RandomNormal object at 0x000002C093E1D070>, embedding_name='C1', group_name='default_group', trainable=True) 26

4.3 Linear Logits

  • Linear part
def get_linear_logits(dense_input_dict,sparse_input_dict,sparse_feature_columns):
    
    # dense feature part
    concat_dense_inputs = Concatenate(axis=1)(list(dense_input_dict.values()))
    dense_logits_output = Dense(1)(concat_dense_inputs)
    
    # sparse feature part
    linear_embedding_layers = build_embedding_layers(sparse_feature_columns, sparse_input_dict, is_linear=True)
    # sum the 1-dimensional embeddings
    sparse_1d_embed = []
    for fc in sparse_feature_columns:
        feat_input = sparse_input_dict[fc.name]
        embed = Flatten()(linear_embedding_layers[fc.name](feat_input)) # B(batch)*1
        sparse_1d_embed.append(embed)
    sparse_logits_output = Add()(sparse_1d_embed)
    
    # dense + sparse
    linear_logits = Add()([dense_logits_output, sparse_logits_output])
    return linear_logits
linear_logits = get_linear_logits(dense_input_dict,sparse_input_dict,linear_sparse_feature_columns)
linear_logits
<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'add_1')>

4.4 DNN Logits

# Build the k-dimensional embedding layers, returned as a dict for convenient model assembly
# These embedding layers are used both for the FM cross part and for the DNN input
embedding_layers = build_embedding_layers(dnn_feature_columns, sparse_input_dict, is_linear=False)
embedding_layers["C1"],len(embedding_layers)
(<tensorflow.python.keras.layers.embeddings.Embedding at 0x2c0f37d9e50>, 26)

4.4.1 Bi-Interaction Pooling layer

This is the core structure of the NFM network. Note the difference from DeepFM: $\odot$ here is the element-wise product, not the inner product (dot product).

$\odot$ denotes the element-wise product of two vectors, i.e. the vector whose entries are the products of the corresponding dimensions (not a dot product); for the $k$-th dimension:

$$(v_i \odot v_j)_k = v_{ik} v_{jk}$$

As in FM, the original expression can be simplified (the square is taken element-wise):

$$f_{BI}(\mathcal{V}_x) = \sum_{i=1}^{n}\sum_{j=i+1}^{n} x_i \mathbf{v}_i \odot x_j \mathbf{v}_j = \frac{1}{2}\left[\Big(\sum_{i=1}^{n} x_i \mathbf{v}_i\Big)^2 - \sum_{i=1}^{n} \big(x_i \mathbf{v}_i\big)^2\right]$$

4.4.2 Inner Product (inner product / scalar product)

The inner product is a dimension-reducing operation: for vectors it yields a scalar, and for matrices np.inner returns the matrix of pairwise inner products of the rows.

x = np.array([1, 2])
y = np.array([10, 20])
print("Array inner:")
print(np.inner(x, y))
''' Output:
Array inner:
50
'''

x = np.mat([[1, 2], [3, 4]])
y = np.mat([10, 20])
print("Matrix inner:")
print(np.inner(x, y))
''' Output:
Matrix inner:
[[ 50]
 [110]]
'''

4.4.3 Outer Product

Unlike the dimension-reducing inner product, the outer product raises dimension: the outer product of an m-dimensional vector and an n-dimensional vector is an m×n matrix.

For matrices, np.outer flattens both inputs first, so the outer product of an a1×a2 matrix and a b1×b2 matrix is an (a1·a2)×(b1·b2) matrix.

x = np.array([1, 3])
y = np.array([10, 20])
print("Array outer:")
print(np.outer(x, y))
''' Output:
Array outer:
[[10 20]
 [30 60]]
'''

x = np.mat([[1, 2], [3, 4]])
y = np.mat([10, 20])
print("Matrix outer:")
print(np.outer(x, y))
''' Output:
Matrix outer:
[[10 20]
 [20 40]
 [30 60]
 [40 80]]
'''

Tips: the matrix outer product is the outer product of the two matrices flattened into vectors.

4.4.4 Element-wise Product (element-wise / point-wise / Hadamard product)

x = np.array([1, 3])
y = np.array([10, 20])
print("Array element-wise product:")
print(x * y)
''' Output:
Array element-wise product:
[10 60]
'''
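As a quick cross-check relating the three operations above (a small sketch using the same arrays): the inner product is the sum of the element-wise product, and each row of the outer product is one element of x times the whole of y:

```python
import numpy as np

x = np.array([1, 3])
y = np.array([10, 20])

# inner product equals the sum of the element-wise product
print((x * y).sum() == np.inner(x, y))              # True
# row i of the outer product is x[i] * y
print(np.array_equal(np.outer(x, y)[0], x[0] * y))  # True
```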
class BiInteractionPooling(Layer):
    def __init__(self):
        super(BiInteractionPooling, self).__init__()
    
    def call(self, inputs):
        concated_embeds_value = inputs
        
        square_of_sum = tf.square(tf.reduce_sum(concated_embeds_value, axis=1, keepdims=False))
        sum_of_square = tf.reduce_sum(concated_embeds_value*concated_embeds_value, axis=1, keepdims=False)
        # Difference from DeepFM: element-wise product, not inner product (the inner product would output B x 1; the element-wise version outputs B x k)
        cross_term = 0.5*(square_of_sum-sum_of_square) # size = B x k
        
        return cross_term
    
    def compute_output_shape(self, input_shape):
        return (None, input_shape[2])
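The identity that `call` relies on (the sum of all pairwise element-wise products equals half of the squared sum minus the sum of squares, per embedding dimension) can be checked against a brute-force loop with NumPy; this is a small sketch whose shapes mirror the 26 sparse features with embedding_dim = 4, and the variable names are mine:

```python
import itertools
import numpy as np

rng = np.random.default_rng(42)
V = rng.normal(size=(26, 4))  # 26 embedding vectors of dimension k = 4

# Brute force: sum of element-wise products over all pairs i < j
brute = np.zeros(4)
for i, j in itertools.combinations(range(len(V)), 2):
    brute += V[i] * V[j]

# Identity used in BiInteractionPooling.call
trick = 0.5 * (V.sum(axis=0) ** 2 - (V ** 2).sum(axis=0))

print(np.allclose(brute, trick))  # True
```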
    
def get_bi_interaction_pooling_output(sparse_input_dict, sparse_feature_columns, dnn_embedding_layers):
    
    sparse_kd_embed = []
    for fc in sparse_feature_columns:
        feat_input = sparse_input_dict[fc.name]
        _embed = dnn_embedding_layers[fc.name](feat_input)
        sparse_kd_embed.append(_embed)
    
    concat_sparse_kd_embed = Concatenate(axis=1)(sparse_kd_embed)
    
    pooling_out = BiInteractionPooling()(concat_sparse_kd_embed)
    
    return pooling_out
    
pooling_output = get_bi_interaction_pooling_output(sparse_input_dict, dnn_sparse_feature_columns, embedding_layers)
# normalization
pooling_output = BatchNormalization()(pooling_output)
def get_dnn_logits(pooling_output):
    # DNN with Dropout
    dnn_out = Dropout(0.5)(Dense(1024, activation="relu")(pooling_output))
    dnn_out = Dropout(0.3)(Dense(512, activation="relu")(dnn_out))
    dnn_out = Dropout(0.1)(Dense(256, activation="relu")(dnn_out))
    
    dnn_logits = Dense(1)(dnn_out)
    
    return dnn_logits
dnn_logits = get_dnn_logits(pooling_output)
dnn_logits
<KerasTensor: shape=(None, 1) dtype=float32 (created by layer 'dense_13')>
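As a sanity check on the layer sizes, the parameter count of the first hidden layer can be computed by hand: the Bi-Interaction output has dimension k = 4, so Dense(1024) holds 4·1024 weights plus 1024 biases:

```python
# First DNN layer: input dim k = 4 (Bi-Interaction pooling output), 1024 units
k, units = 4, 1024
print(k * units + units)  # 5120, matching dense_10 in model.summary()
```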

4.5 Output Logits

# Linear + Dnn logits
output_logits = Add()([linear_logits, dnn_logits])
output_layers = Activation("sigmoid")(output_logits)
model = Model(input_layers, output_layers)
model.summary()
Model: "model"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
C1 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
C2 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
C3 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
C4 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
C5 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
C6 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
C7 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
C8 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
C9 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
C10 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C11 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C12 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C13 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C14 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C15 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C16 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C17 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C18 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C19 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C20 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C21 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C22 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C23 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C24 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C25 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
C26 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
kd_emb_C1 (Embedding)           (None, 1, 4)         108         C1[0][0]                         
__________________________________________________________________________________________________
kd_emb_C2 (Embedding)           (None, 1, 4)         368         C2[0][0]                         
__________________________________________________________________________________________________
kd_emb_C3 (Embedding)           (None, 1, 4)         688         C3[0][0]                         
__________________________________________________________________________________________________
kd_emb_C4 (Embedding)           (None, 1, 4)         628         C4[0][0]                         
__________________________________________________________________________________________________
kd_emb_C5 (Embedding)           (None, 1, 4)         48          C5[0][0]                         
__________________________________________________________________________________________________
kd_emb_C6 (Embedding)           (None, 1, 4)         28          C6[0][0]                         
__________________________________________________________________________________________________
kd_emb_C7 (Embedding)           (None, 1, 4)         732         C7[0][0]                         
__________________________________________________________________________________________________
kd_emb_C8 (Embedding)           (None, 1, 4)         76          C8[0][0]                         
__________________________________________________________________________________________________
kd_emb_C9 (Embedding)           (None, 1, 4)         8           C9[0][0]                         
__________________________________________________________________________________________________
kd_emb_C10 (Embedding)          (None, 1, 4)         568         C10[0][0]                        
__________________________________________________________________________________________________
kd_emb_C11 (Embedding)          (None, 1, 4)         692         C11[0][0]                        
__________________________________________________________________________________________________
kd_emb_C12 (Embedding)          (None, 1, 4)         680         C12[0][0]                        
__________________________________________________________________________________________________
kd_emb_C13 (Embedding)          (None, 1, 4)         664         C13[0][0]                        
__________________________________________________________________________________________________
kd_emb_C14 (Embedding)          (None, 1, 4)         56          C14[0][0]                        
__________________________________________________________________________________________________
kd_emb_C15 (Embedding)          (None, 1, 4)         680         C15[0][0]                        
__________________________________________________________________________________________________
kd_emb_C16 (Embedding)          (None, 1, 4)         672         C16[0][0]                        
__________________________________________________________________________________________________
kd_emb_C17 (Embedding)          (None, 1, 4)         36          C17[0][0]                        
__________________________________________________________________________________________________
kd_emb_C18 (Embedding)          (None, 1, 4)         508         C18[0][0]                        
__________________________________________________________________________________________________
kd_emb_C19 (Embedding)          (None, 1, 4)         176         C19[0][0]                        
__________________________________________________________________________________________________
kd_emb_C20 (Embedding)          (None, 1, 4)         16          C20[0][0]                        
__________________________________________________________________________________________________
kd_emb_C21 (Embedding)          (None, 1, 4)         676         C21[0][0]                        
__________________________________________________________________________________________________
kd_emb_C22 (Embedding)          (None, 1, 4)         24          C22[0][0]                        
__________________________________________________________________________________________________
kd_emb_C23 (Embedding)          (None, 1, 4)         40          C23[0][0]                        
__________________________________________________________________________________________________
kd_emb_C24 (Embedding)          (None, 1, 4)         500         C24[0][0]                        
__________________________________________________________________________________________________
kd_emb_C25 (Embedding)          (None, 1, 4)         80          C25[0][0]                        
__________________________________________________________________________________________________
kd_emb_C26 (Embedding)          (None, 1, 4)         360         C26[0][0]                        
__________________________________________________________________________________________________
concatenate_3 (Concatenate)     (None, 26, 4)        0           kd_emb_C1[2][0]                  
                                                                 kd_emb_C2[2][0]                  
                                                                 kd_emb_C3[2][0]                  
                                                                 kd_emb_C4[2][0]                  
                                                                 kd_emb_C5[2][0]                  
                                                                 kd_emb_C6[2][0]                  
                                                                 kd_emb_C7[2][0]                  
                                                                 kd_emb_C8[2][0]                  
                                                                 kd_emb_C9[2][0]                  
                                                                 kd_emb_C10[2][0]                 
                                                                 kd_emb_C11[2][0]                 
                                                                 kd_emb_C12[2][0]                 
                                                                 kd_emb_C13[2][0]                 
                                                                 kd_emb_C14[2][0]                 
                                                                 kd_emb_C15[2][0]                 
                                                                 kd_emb_C16[2][0]                 
                                                                 kd_emb_C17[2][0]                 
                                                                 kd_emb_C18[2][0]                 
                                                                 kd_emb_C19[2][0]                 
                                                                 kd_emb_C20[2][0]                 
                                                                 kd_emb_C21[2][0]                 
                                                                 kd_emb_C22[2][0]                 
                                                                 kd_emb_C23[2][0]                 
                                                                 kd_emb_C24[2][0]                 
                                                                 kd_emb_C25[2][0]                 
                                                                 kd_emb_C26[2][0]                 
__________________________________________________________________________________________________
bi_interaction_pooling_2 (BiInt (None, 4)            0           concatenate_3[0][0]              
__________________________________________________________________________________________________
batch_normalization (BatchNorma (None, 4)            16          bi_interaction_pooling_2[0][0]   
__________________________________________________________________________________________________
dense_10 (Dense)                (None, 1024)         5120        batch_normalization[0][0]        
__________________________________________________________________________________________________
dropout_9 (Dropout)             (None, 1024)         0           dense_10[0][0]                   
__________________________________________________________________________________________________
dense_11 (Dense)                (None, 512)          524800      dropout_9[0][0]                  
__________________________________________________________________________________________________
I1 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
I2 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
I3 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
I4 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
I5 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
I6 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
I7 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
I8 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
I9 (InputLayer)                 [(None, 1)]          0                                            
__________________________________________________________________________________________________
I10 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
I11 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
I12 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
I13 (InputLayer)                [(None, 1)]          0                                            
__________________________________________________________________________________________________
1d_emb_C1 (Embedding)           (None, 1, 1)         27          C1[0][0]                         
__________________________________________________________________________________________________
1d_emb_C2 (Embedding)           (None, 1, 1)         92          C2[0][0]                         
__________________________________________________________________________________________________
1d_emb_C3 (Embedding)           (None, 1, 1)         172         C3[0][0]                         
__________________________________________________________________________________________________
1d_emb_C4 (Embedding)           (None, 1, 1)         157         C4[0][0]                         
__________________________________________________________________________________________________
1d_emb_C5 (Embedding)           (None, 1, 1)         12          C5[0][0]                         
__________________________________________________________________________________________________
1d_emb_C6 (Embedding)           (None, 1, 1)         7           C6[0][0]                         
__________________________________________________________________________________________________
1d_emb_C7 (Embedding)           (None, 1, 1)         183         C7[0][0]                         
__________________________________________________________________________________________________
1d_emb_C8 (Embedding)           (None, 1, 1)         19          C8[0][0]                         
__________________________________________________________________________________________________
1d_emb_C9 (Embedding)           (None, 1, 1)         2           C9[0][0]                         
__________________________________________________________________________________________________
1d_emb_C10 (Embedding)          (None, 1, 1)         142         C10[0][0]                        
__________________________________________________________________________________________________
1d_emb_C11 (Embedding)          (None, 1, 1)         173         C11[0][0]                        
__________________________________________________________________________________________________
1d_emb_C12 (Embedding)          (None, 1, 1)         170         C12[0][0]                        
__________________________________________________________________________________________________
1d_emb_C13 (Embedding)          (None, 1, 1)         166         C13[0][0]                        
__________________________________________________________________________________________________
1d_emb_C14 (Embedding)          (None, 1, 1)         14          C14[0][0]                        
__________________________________________________________________________________________________
1d_emb_C15 (Embedding)          (None, 1, 1)         170         C15[0][0]                        
__________________________________________________________________________________________________
1d_emb_C16 (Embedding)          (None, 1, 1)         168         C16[0][0]                        
__________________________________________________________________________________________________
1d_emb_C17 (Embedding)          (None, 1, 1)         9           C17[0][0]                        
__________________________________________________________________________________________________
1d_emb_C18 (Embedding)          (None, 1, 1)         127         C18[0][0]                        
__________________________________________________________________________________________________
1d_emb_C19 (Embedding)          (None, 1, 1)         44          C19[0][0]                        
__________________________________________________________________________________________________
1d_emb_C20 (Embedding)          (None, 1, 1)         4           C20[0][0]                        
__________________________________________________________________________________________________
1d_emb_C21 (Embedding)          (None, 1, 1)         169         C21[0][0]                        
__________________________________________________________________________________________________
1d_emb_C22 (Embedding)          (None, 1, 1)         6           C22[0][0]                        
__________________________________________________________________________________________________
1d_emb_C23 (Embedding)          (None, 1, 1)         10          C23[0][0]                        
__________________________________________________________________________________________________
1d_emb_C24 (Embedding)          (None, 1, 1)         125         C24[0][0]                        
__________________________________________________________________________________________________
1d_emb_C25 (Embedding)          (None, 1, 1)         20          C25[0][0]                        
__________________________________________________________________________________________________
1d_emb_C26 (Embedding)          (None, 1, 1)         90          C26[0][0]                        
__________________________________________________________________________________________________
dropout_10 (Dropout)            (None, 512)          0           dense_11[0][0]                   
__________________________________________________________________________________________________
concatenate (Concatenate)       (None, 13)           0           I1[0][0]                         
                                                                 I2[0][0]                         
                                                                 I3[0][0]                         
                                                                 I4[0][0]                         
                                                                 I5[0][0]                         
                                                                 I6[0][0]                         
                                                                 I7[0][0]                         
                                                                 I8[0][0]                         
                                                                 I9[0][0]                         
                                                                 I10[0][0]                        
                                                                 I11[0][0]                        
                                                                 I12[0][0]                        
                                                                 I13[0][0]                        
__________________________________________________________________________________________________
flatten (Flatten)               (None, 1)            0           1d_emb_C1[0][0]                  
__________________________________________________________________________________________________
flatten_1 (Flatten)             (None, 1)            0           1d_emb_C2[0][0]                  
__________________________________________________________________________________________________
flatten_2 (Flatten)             (None, 1)            0           1d_emb_C3[0][0]                  
__________________________________________________________________________________________________
flatten_3 (Flatten)             (None, 1)            0           1d_emb_C4[0][0]                  
__________________________________________________________________________________________________
flatten_4 (Flatten)             (None, 1)            0           1d_emb_C5[0][0]                  
__________________________________________________________________________________________________
flatten_5 (Flatten)             (None, 1)            0           1d_emb_C6[0][0]                  
__________________________________________________________________________________________________
flatten_6 (Flatten)             (None, 1)            0           1d_emb_C7[0][0]                  
__________________________________________________________________________________________________
flatten_7 (Flatten)             (None, 1)            0           1d_emb_C8[0][0]                  
__________________________________________________________________________________________________
flatten_8 (Flatten)             (None, 1)            0           1d_emb_C9[0][0]                  
__________________________________________________________________________________________________
flatten_9 (Flatten)             (None, 1)            0           1d_emb_C10[0][0]                 
__________________________________________________________________________________________________
flatten_10 (Flatten)            (None, 1)            0           1d_emb_C11[0][0]                 
__________________________________________________________________________________________________
flatten_11 (Flatten)            (None, 1)            0           1d_emb_C12[0][0]                 
__________________________________________________________________________________________________
flatten_12 (Flatten)            (None, 1)            0           1d_emb_C13[0][0]                 
__________________________________________________________________________________________________
flatten_13 (Flatten)            (None, 1)            0           1d_emb_C14[0][0]                 
__________________________________________________________________________________________________
flatten_14 (Flatten)            (None, 1)            0           1d_emb_C15[0][0]                 
__________________________________________________________________________________________________
flatten_15 (Flatten)            (None, 1)            0           1d_emb_C16[0][0]                 
__________________________________________________________________________________________________
flatten_16 (Flatten)            (None, 1)            0           1d_emb_C17[0][0]                 
__________________________________________________________________________________________________
flatten_17 (Flatten)            (None, 1)            0           1d_emb_C18[0][0]                 
__________________________________________________________________________________________________
flatten_18 (Flatten)            (None, 1)            0           1d_emb_C19[0][0]                 
__________________________________________________________________________________________________
flatten_19 (Flatten)            (None, 1)            0           1d_emb_C20[0][0]                 
__________________________________________________________________________________________________
flatten_20 (Flatten)            (None, 1)            0           1d_emb_C21[0][0]                 
__________________________________________________________________________________________________
flatten_21 (Flatten)            (None, 1)            0           1d_emb_C22[0][0]                 
__________________________________________________________________________________________________
flatten_22 (Flatten)            (None, 1)            0           1d_emb_C23[0][0]                 
__________________________________________________________________________________________________
flatten_23 (Flatten)            (None, 1)            0           1d_emb_C24[0][0]                 
__________________________________________________________________________________________________
flatten_24 (Flatten)            (None, 1)            0           1d_emb_C25[0][0]                 
__________________________________________________________________________________________________
flatten_25 (Flatten)            (None, 1)            0           1d_emb_C26[0][0]                 
__________________________________________________________________________________________________
dense_12 (Dense)                (None, 256)          131328      dropout_10[0][0]                 
__________________________________________________________________________________________________
dense (Dense)                   (None, 1)            14          concatenate[0][0]                
__________________________________________________________________________________________________
add (Add)                       (None, 1)            0           flatten[0][0]                    
                                                                 flatten_1[0][0]                  
                                                                 flatten_2[0][0]                  
                                                                 flatten_3[0][0]                  
                                                                 flatten_4[0][0]                  
                                                                 flatten_5[0][0]                  
                                                                 flatten_6[0][0]                  
                                                                 flatten_7[0][0]                  
                                                                 flatten_8[0][0]                  
                                                                 flatten_9[0][0]                  
                                                                 flatten_10[0][0]                 
                                                                 flatten_11[0][0]                 
                                                                 flatten_12[0][0]                 
                                                                 flatten_13[0][0]                 
                                                                 flatten_14[0][0]                 
                                                                 flatten_15[0][0]                 
                                                                 flatten_16[0][0]                 
                                                                 flatten_17[0][0]                 
                                                                 flatten_18[0][0]                 
                                                                 flatten_19[0][0]                 
                                                                 flatten_20[0][0]                 
                                                                 flatten_21[0][0]                 
                                                                 flatten_22[0][0]                 
                                                                 flatten_23[0][0]                 
                                                                 flatten_24[0][0]                 
                                                                 flatten_25[0][0]                 
__________________________________________________________________________________________________
dropout_11 (Dropout)            (None, 256)          0           dense_12[0][0]                   
__________________________________________________________________________________________________
add_1 (Add)                     (None, 1)            0           dense[0][0]                      
                                                                 add[0][0]                        
__________________________________________________________________________________________________
dense_13 (Dense)                (None, 1)            257         dropout_11[0][0]                 
__________________________________________________________________________________________________
add_2 (Add)                     (None, 1)            0           add_1[0][0]                      
                                                                 dense_13[0][0]                   
__________________________________________________________________________________________________
activation (Activation)         (None, 1)            0           add_2[0][0]                      
==================================================================================================
Total params: 672,925
Trainable params: 672,917
Non-trainable params: 8
__________________________________________________________________________________________________
model.compile(optimizer="adam", 
            loss="binary_crossentropy", 
            metrics=["binary_crossentropy", tf.keras.metrics.AUC(name='auc')])

# Pack the inputs into a dict keyed by feature name
train_model_input = {name: data[name] for name in dense_features + sparse_features}
# Train the model
model.fit(train_model_input, data['label'].values,
        batch_size=64, epochs=10, validation_split=0.3)
Epoch 1/10
3/3 [==============================] - 6s 378ms/step - loss: 0.3549 - binary_crossentropy: 0.3549 - auc: 0.8915 - val_loss: 0.6528 - val_binary_crossentropy: 0.6528 - val_auc: 0.6463
Epoch 2/10
3/3 [==============================] - 0s 50ms/step - loss: 0.2811 - binary_crossentropy: 0.2811 - auc: 0.9425 - val_loss: 0.6455 - val_binary_crossentropy: 0.6455 - val_auc: 0.6489
Epoch 3/10
3/3 [==============================] - 0s 52ms/step - loss: 0.2301 - binary_crossentropy: 0.2301 - auc: 0.9890 - val_loss: 0.6382 - val_binary_crossentropy: 0.6382 - val_auc: 0.6521
Epoch 4/10
3/3 [==============================] - 0s 52ms/step - loss: 0.1785 - binary_crossentropy: 0.1785 - auc: 0.9961 - val_loss: 0.6287 - val_binary_crossentropy: 0.6287 - val_auc: 0.6630
Epoch 5/10
3/3 [==============================] - 0s 52ms/step - loss: 0.1622 - binary_crossentropy: 0.1622 - auc: 0.9981 - val_loss: 0.6194 - val_binary_crossentropy: 0.6194 - val_auc: 0.6720
Epoch 6/10
3/3 [==============================] - 0s 52ms/step - loss: 0.1363 - binary_crossentropy: 0.1363 - auc: 0.9977 - val_loss: 0.6095 - val_binary_crossentropy: 0.6095 - val_auc: 0.6874
Epoch 7/10
3/3 [==============================] - 0s 51ms/step - loss: 0.0882 - binary_crossentropy: 0.0882 - auc: 0.9990 - val_loss: 0.6001 - val_binary_crossentropy: 0.6001 - val_auc: 0.7015
Epoch 8/10
3/3 [==============================] - 0s 52ms/step - loss: 0.0372 - binary_crossentropy: 0.0372 - auc: 1.0000 - val_loss: 0.5963 - val_binary_crossentropy: 0.5963 - val_auc: 0.7163
Epoch 9/10
3/3 [==============================] - 0s 51ms/step - loss: 0.0245 - binary_crossentropy: 0.0245 - auc: 1.0000 - val_loss: 0.6320 - val_binary_crossentropy: 0.6320 - val_auc: 0.7227
Epoch 10/10
3/3 [==============================] - 0s 52ms/step - loss: 0.0126 - binary_crossentropy: 0.0126 - auc: 1.0000 - val_loss: 0.7111 - val_binary_crossentropy: 0.7111 - val_auc: 0.7240
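The log above shows classic overfitting on this tiny sample: training AUC reaches 1.0 by epoch 8 while validation loss starts climbing again. One common mitigation, if desired, is Keras's `EarlyStopping` callback; a minimal sketch (not part of the original notebook):

```python
import tensorflow as tf

# Stop training once validation AUC stops improving, and roll back
# to the weights from the best epoch.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_auc", mode="max", patience=3, restore_best_weights=True)

# Pass it to fit() via the callbacks argument, e.g.:
# model.fit(train_model_input, data['label'].values,
#           batch_size=64, epochs=10, validation_split=0.3,
#           callbacks=[early_stop])
```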

5. Discussion Questions

  1. How does feature crossing in NFM differ from, and resemble, feature crossing in FM? Compare them in terms of both the underlying principle and the code implementation.
  • FM only models second-order interactions; NFM applies a Bi-Interaction pooling layer and then feeds the pooled vector into a DNN to learn higher-order interactions
  • FM collapses each pairwise interaction to an inner product (a scalar), while NFM keeps the element-wise (Hadamard) product (a k-dimensional vector)
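The Bi-Interaction pooling operation from the NFM paper makes this contrast concrete. For the non-zero features with embedding vectors $\mathbf{v}_i$, it returns a $k$-dimensional vector rather than FM's scalar:

$$
f_{BI}(\mathcal{V}_x) = \sum_{i=1}^{n}\sum_{j=i+1}^{n} x_i\mathbf{v}_i \odot x_j\mathbf{v}_j
= \frac{1}{2}\left[\Big(\sum_{i=1}^{n} x_i\mathbf{v}_i\Big)^{\odot 2} - \sum_{i=1}^{n} (x_i\mathbf{v}_i)^{\odot 2}\right]
$$

where $\odot$ is the element-wise product and $(\cdot)^{\odot 2}$ squares element-wise. Summing the $k$ components of this vector recovers exactly FM's second-order interaction term.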

On the implementation side:

FM:

square_of_sum = tf.square(tf.reduce_sum(concated_embeds_value, axis=1, keepdims=True))  # B x 1 x k
sum_of_square = tf.reduce_sum(concated_embeds_value * concated_embeds_value, axis=1, keepdims=True)  # B x 1 x k
cross_term = square_of_sum - sum_of_square  # B x 1 x k
cross_term = 0.5 * tf.reduce_sum(cross_term, axis=2, keepdims=False)  # B x 1

NFM:

square_of_sum = tf.square(tf.reduce_sum(concated_embeds_value, axis=1, keepdims=False))  # B x k
sum_of_square = tf.reduce_sum(concated_embeds_value * concated_embeds_value, axis=1, keepdims=False)  # B x k
cross_term = 0.5 * (square_of_sum - sum_of_square)  # B x k
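The two snippets above can be checked numerically: summing NFM's Bi-Interaction output over the embedding dimension recovers FM's scalar pairwise-interaction term, and both match the explicit sum of pairwise products. A minimal NumPy sketch with hypothetical toy embeddings:

```python
import numpy as np

rng = np.random.default_rng(0)
B, n, k = 2, 4, 3                     # batch size, num fields, embedding dim
embeds = rng.normal(size=(B, n, k))   # plays the role of concated_embeds_value

# FM: collapse interactions to a scalar per sample, shape (B, 1)
square_of_sum = np.square(embeds.sum(axis=1, keepdims=True))   # B x 1 x k
sum_of_square = (embeds * embeds).sum(axis=1, keepdims=True)   # B x 1 x k
fm_term = 0.5 * (square_of_sum - sum_of_square).sum(axis=2)    # B x 1

# NFM: Bi-Interaction pooling keeps the k-dim vector, shape (B, k)
bi_pool = 0.5 * (np.square(embeds.sum(axis=1)) - (embeds * embeds).sum(axis=1))

# Summing the pooled vector over k reproduces FM's second-order term
assert np.allclose(bi_pool.sum(axis=1, keepdims=True), fm_term)

# Direct check against the explicit pairwise element-wise products
pairwise = sum(embeds[:, i] * embeds[:, j]
               for i in range(n) for j in range(i + 1, n))
assert np.allclose(bi_pool, pairwise)
```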

Author: Terence Cai
Copyright: Unless otherwise stated, all posts on this blog are licensed under CC BY 4.0. Please credit Terence Cai when reposting!