首頁 > 科技週邊 > 人工智慧 > TabTransformer轉換器提升多層感知機效能深度解析

TabTransformer轉換器提升多層感知機效能深度解析

WBOY
發布: 2023-04-17 15:25:03
轉載
1555 人瀏覽過

​如今,转换器(Transformers)成为大多数先进的自然语言处理(NLP)和计算机视觉(CV)体系结构中的关键模块。然而,表格式数据领域仍然主要以梯度提升决策树(GBDT)算法为主导。于是,有人试图弥合这一差距。其中,第一篇基于转换器的表格数据建模论文是由Huang等人于2020年发表的论文《TabTransformer:使用上下文嵌入的表格数据建模》。

本文旨在提供该论文内容的基本展示,同时将深入探讨TabTransformer模型的实现细节,并向您展示如何针对我们自己的数据来具体使用TabTransformer。

一、论文概述

上述论文的主要思想是,如果使用转换器将常规的分类嵌入转换为上下文嵌入,那么,常规的多层感知器(MLP)的性能将会得到显著提高。接下来,让我们更为深入地理解这一描述。

1.分类嵌入(Categorical Embeddings)

在深度学习模型中,使用分类特征的经典方法是训练其嵌入性。这意味着,每个类别值都有一个唯一的密集型向量表示,并且可以传递给下一层。例如,由下图您可以看到,每个分类特征都使用一个四维数组表示。然后,这些嵌入与数字特征串联,并用作MLP的输入。

TabTransformer轉換器提升多層感知機效能深度解析

带有分类嵌入的MLP

2.上下文嵌入(Contextual Embeddings)

论文作者认为,分类嵌入缺乏上下文含义,即它们并没有对分类变量之间的任何交互和关系信息进行编码。为了将嵌入内容更加具体化,有人建议使用NLP领域当前所使用的转换器来实现这一目的。

TabTransformer轉換器提升多層感知機效能深度解析

TabTransformer转换器中的上下文嵌入

为了以可视化方式形象地展示上述想法,我们不妨考虑下面这个训练后得到的上下文嵌入图像。其中,突出显示了两个分类特征:关系(黑色)和婚姻状况(蓝色)。这些特征是相关的;所以,“已婚(Married)”、“丈夫(Husband)”和“妻子(Wife)”的值应该在向量空间中彼此接近,即使它们来自不同的变量。

TabTransformer轉換器提升多層感知機效能深度解析

经训练后的TabTransformer转换器嵌入结果示例

通过上图中经过训练的上下文嵌入结果,我们可以看到,“已婚(Married)”的婚姻状况更接近“丈夫(Husband)”和“妻子(Wife)”的关系水平,而“未结婚(non-married)”的分类值则来自右侧的单独数据簇。这种类型的上下文使这样的嵌入更加有用,而使用简单形式的类别嵌入技术是不可能实现这种效果的。

3.TabTransformer架构

为了达到上述目的,论文作者提出了以下架构:

TabTransformer轉換器提升多層感知機效能深度解析

TabTransformer转换器架构示意图

(摘取自Huang等人2020年发表的论文)

我们可以将此体系结构分解为5个步骤:

  • 标准化数字特征并向前传递
  • 嵌入分类特征
  • 嵌入经过N次转换器块处理,以便获得上下文嵌入
  • 把上下文分类嵌入与数字特征进行串联
  • 通过MLP进行串联获得所需的预测

虽然模型架构非常简单,但论文作者表示,添加转换器层可以显著提高计算性能。当然,所有的“魔术”发生在这些转换器块内部;所以,接下来让我们更加详细地研究一下其中的实现过程。

4.转换器

TabTransformer轉換器提升多層感知機效能深度解析

转换器(Transformer)架构示意

(选自Vaswani等人于2017年发表的论文)

您可能以前见过转换器架构,但为了快速介绍起见,请记住该转换器是由编码器和解码器两部分组成(见上图)。对于TabTransformer,我们只关心将输入的嵌入内容上下文化的编码器部分(解码器部分将这些嵌入内容转换为最终输出结果)。但它到底是如何做到的呢?答案是——多头注意力机制。

5.多头注意力机制(Multi-head-attention)

引用我最喜歡的關於注意力機制的文章的描述,是這樣的:

#「自我關注(self attention)背後的關鍵概念是,這種機制允許神經網路學習如何在輸入序列的各個片段之間以最好的路由方案進行資訊調度。」

換句話說,自我關注(self-attention)有助於模型找出在表示某個單字/類別時,輸入的哪些部分更重要,哪些部分相對不重要。為此,我強烈建議您閱讀一下上面引用的這篇文章,以便對自我關注為什麼如此有效有一個更直觀的理解。

TabTransformer轉換器提升多層感知機效能深度解析

多頭注意力機制

(選自Vaswani等人於2017年發表的論文)

#注意力是透過3個學習過的矩陣來計算的-Q、K和V,它們代表查詢(Query)、鍵(Key)和值(Value)。首先,我們將矩陣Q和K相乘得到注意力矩陣。此矩陣被縮放並通過softmax層傳遞。然後,我們將其乘以V矩陣,得出最終值。為了更直觀地理解起見,請考慮下面的示意圖,它顯示了我們如何使用矩陣Q、K和V實現從輸入嵌入轉換到上下文嵌入。

TabTransformer轉換器提升多層感知機效能深度解析

自我關注流程視覺化

透過重複流程h次(使用不同的Q、K 、V矩陣),我們就能夠得到多個脈絡嵌入,它們形成我們最終的多頭注意力。

6.簡短回顧

讓我們總結一下上面所介紹的內容:

  • 簡單的分類嵌入不包含上下文訊息
  • 透過轉換器編碼器傳遞分類嵌入,我們就能夠將嵌入上下文化
  • 轉換器部分能夠將嵌入上下文化,因為它使用了多頭注意力機制
  • 多頭注意力機制在編碼變數時使用矩陣Q、K和V來尋找有用的交互作用和相關性資訊
  • 在TabTransformer中,被上下文化的嵌入與數位輸入相串聯,並透過一個簡單的MLP輸出預測

#雖然TabTransformer背後的想法很簡單,但您可能需要一些時間才能掌握注意力機制。因此,我強烈建議您重新閱讀以上解釋。如果您感到有些迷茫,請認真閱讀本文中所有建議的連結相關內容。我保證,做到這些後,您就不難搞明白注意力機制的原理了。

7.試驗結果展示

TabTransformer轉換器提升多層感知機效能深度解析

#結果資料(選自Huang等人2020年發表的論文)

根據報告的結果,TabTransformer轉換器優於所有其他深度學習表格模型,此外,它接近GBDT的性能水平,這非常令人鼓舞。該模型對缺失資料和雜訊資料也相對穩健,並且在半監督環境下優於其他模型。然而,這些資料集顯然不是詳盡無遺的,正如以後發表的一些相關論文所證實的那樣,仍有很大的改進空間。

二、建立我們自己的範例程式

#現在,讓我們最終來確定如何將模型應用於我們自己的資料。接下來的範例數據取自著名的Tabular Playground Kaggle比賽。為了方便使用TabTransformer轉換器,我建立了一個tabtransformertf套件。它可以使用以下pip命令進行安裝:

pip install tabtransformertf
登入後複製

並允許我們使用該模型,而無需進行大量預處理。

1.資料預處理

第一步是設定適當的資料類型,並將我們的訓練和驗證資料轉換為TF數據集。其中,前面安裝的軟體包中就提供了一個很好的實用程式可以做到這一點。

from tabtransformertf.utils.preprocessing import df_to_dataset, build_categorical_prep

# 设置数据类型
train_data[CATEGORICAL_FEATURES] = train_data[CATEGORICAL_FEATURES].astype(str)
val_data[CATEGORICAL_FEATURES] = val_data[CATEGORICAL_FEATURES].astype(str)

train_data[NUMERIC_FEATURES] = train_data[NUMERIC_FEATURES].astype(float)
val_data[NUMERIC_FEATURES] = val_data[NUMERIC_FEATURES].astype(float)

# 转换成TF数据集
train_dataset = df_to_dataset(train_data[FEATURES + [LABEL]], LABEL, batch_size=1024)
val_dataset = df_to_dataset(val_data[FEATURES + [LABEL]], LABEL, shuffle=False, batch_size=1024)
登入後複製

下一步是為分類資料準備預處理層。該分類資料稍後將傳遞給我們的主模型。

from tabtransformertf.utils.preprocessing import build_categorical_prep

category_prep_layers = build_categorical_prep(train_data, CATEGORICAL_FEATURES)

# 输出结果是一个字典结构,其中键部分是特征名称,值部分是StringLookup层
# category_prep_layers ->
# {'product_code': <keras.layers.preprocessing.string_lookup.StringLookup at 0x7f05d28ee4e0>,
#'attribute_0': <keras.layers.preprocessing.string_lookup.StringLookup at 0x7f05ca4fb908>,
#'attribute_1': <keras.layers.preprocessing.string_lookup.StringLookup at 0x7f05ca4da5f8>}
登入後複製

這就是預處理!現在,我們可以開始建立模型了。

2.建構TabTransformer模型

#初始化模型很容易。其中,有幾個參數需要指定,但最重要的幾個參數是:embeding_dim、depth和heads。所有參數都是在超參數調整後選擇的。

from tabtransformertf.models.tabtransformer import TabTransformer

tabtransformer = TabTransformer(
numerical_features = NUMERIC_FEATURES,# 带有数字特征名称的列表
categorical_features = CATEGORICAL_FEATURES, # 带有分类特征名称的列表
categorical_lookup=category_prep_layers, # 带StringLookup层的Dict
numerical_discretisers=None,# None代表我们只是简单地传递数字特征
embedding_dim=32,# 嵌入维数
out_dim=1,# Dimensionality of output (binary task)
out_activatinotallow='sigmoid',# 输出层激活
depth=4,# 转换器块层的个数
heads=8,# 转换器块中注意力头的个数
attn_dropout=0.1,# 在转换器块中的丢弃率
ff_dropout=0.1,# 在最后MLP中的丢弃率
mlp_hidden_factors=[2, 4],# 我们为每一层划分最终嵌入的因子
use_column_embedding=True,#如果我们想使用列嵌入,设置此项为真
)

# 模型运行中摘要输出:
# 总参数个数: 1,778,884
# 可训练的参数个数: 1,774,064
# 不可训练的参数个数: 4,820
登入後複製

模型初始化後,我們可以像其他Keras模型一樣安裝它。訓練參數也可以調整,所以可以隨意調整學習速度和提前停止。

LEARNING_RATE = 0.0001
WEIGHT_DECAY = 0.0001
NUM_EPOCHS = 1000

optimizer = tfa.optimizers.AdamW(
learning_rate=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)

tabtransformer.compile(
optimizer = optimizer,
loss = tf.keras.losses.BinaryCrossentropy(),
metrics= [tf.keras.metrics.AUC(name="PR AUC", curve='PR')],
)

out_file = './tabTransformerBasic'
checkpoint = ModelCheckpoint(
out_file, mnotallow="val_loss", verbose=1, save_best_notallow=True, mode="min"
)
early = EarlyStopping(mnotallow="val_loss", mode="min", patience=10, restore_best_weights=True)
callback_list = [checkpoint, early]

history = tabtransformer.fit(
train_dataset,
epochs=NUM_EPOCHS,
validation_data=val_dataset,
callbacks=callback_list
)
登入後複製

3.評價

競賽中最關鍵的指標是ROC AUC。因此,讓我們將其與PR AUC指標一起輸出來評估模型的表現。

val_preds = tabtransformer.predict(val_dataset)

print(f"PR AUC: {average_precision_score(val_data['isFraud'], val_preds.ravel())}")
print(f"ROC AUC: {roc_auc_score(val_data['isFraud'], val_preds.ravel())}")

# PR AUC: 0.26
# ROC AUC: 0.58
登入後複製

您也可以自己给测试集评分,然后将结果值提交给Kaggle官方。我现在选择的这个解决方案使我跻身前35%,这并不坏,但也不太好。那么,为什么TabTransfromer在上述方案中表现不佳呢?可能有以下几个原因:

  • 数据集太小,而深度学习模型以需要大量数据著称
  • TabTransformer很容易在表格式数据示例领域出现过拟合
  • 没有足够的分类特征使模型有用

三、结论

本文探讨了TabTransformer背后的主要思想,并展示了如何使用Tabtransformertf包来具体应用此转换器。

归纳起来看,TabTransformer的确是一种有趣的体系结构,它在当时的表现明显优于大多数深度表格模型。它的主要优点是将分类嵌入语境化,从而增强其表达能力。它使用在分类特征上的多头注意力机制来实现这一点,而这是在表格数据领域使用转换器的第一个应用实例。

TabTransformer体系结构的一个明显缺点是,数字特征被简单地传递到最终的MLP层。因此,它们没有语境化,它们的价值也没有在分类嵌入中得到解释。在下一篇文章中,我将探讨如何修复此缺陷并进一步提高性能。

译者介绍

朱先忠,51CTO社区编辑,51CTO专家博客、讲师,潍坊一所高校计算机教师,自由编程界老兵一枚。

原文链接:https://towardsdatascience.com/transformers-for-tabular-data-tabtransformer-deep-dive-5fb2438da820?source=collection_home---------4----------------------------

以上是TabTransformer轉換器提升多層感知機效能深度解析的詳細內容。更多資訊請關注PHP中文網其他相關文章!

相關標籤:
來源:51cto.com
本網站聲明
本文內容由網友自願投稿,版權歸原作者所有。本站不承擔相應的法律責任。如發現涉嫌抄襲或侵權的內容,請聯絡admin@php.cn
熱門教學
更多>
最新下載
更多>
網站特效
網站源碼
網站素材
前端模板