Why do tree-based models still outperform deep learning on tabular data?

WBOY
2023-04-08 14:41:03

Deep learning has made huge progress in areas such as images, language and even audio. On tabular data, however, it performs only moderately well. Because tabular data has heterogeneous features, small sample sizes and extreme values, it is hard to identify the invariances that deep architectures exploit in other domains, such as translation invariance in images.

Tree-based models are not differentiable, so they cannot be trained jointly with deep learning modules; as a result, designing deep learning architectures specifically for tabular data is a very active research area. Many studies have claimed to beat or rival tree-based models, but these claims have been met with considerable skepticism.

The lack of established benchmarks for learning from tabular data gives researchers a lot of freedom when evaluating their methods. Furthermore, most tabular datasets available online are small compared to the benchmarks of other machine learning subfields, which makes evaluation harder.

To address these concerns, researchers from Inria (the French national research institute for computer science and automation), Sorbonne University and other institutions have proposed a tabular data benchmark for evaluating recent deep learning models, and show that tree-based models are still SOTA on medium-sized tabular datasets.

The paper gives conclusive evidence that, on tabular data, good predictions are easier to achieve with tree-based methods than with deep learning (even modern architectures), and the researchers identify reasons for this gap.

Paper address: https://hal.archives-ouvertes.fr/hal-03723551/document

It is worth mentioning that one of the paper's authors is Gaël Varoquaux, one of the leaders of the Scikit-learn project, which has become one of the most popular machine learning libraries on GitHub. The article "Scikit-learn: Machine learning in Python", co-authored by Gaël Varoquaux, has 58,949 citations.

The contribution of this article can be summarized as:

This study creates a new benchmark of 45 selected open datasets and shares these datasets through OpenML, which makes them easy to use.

This study compares deep learning models and tree-based models under various settings on tabular data and considers the cost of selecting hyperparameters. The study also shares raw results from random searches, which will allow researchers to cheaply test new algorithms for a fixed hyperparameter optimization budget.
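
Shared random-search results make it possible to estimate how a model performs under a given tuning budget without rerunning the search. The sketch below assumes each trial record holds one validation score and one test score (higher is better), and that the number of interest is the test score of the configuration with the best validation score among the first n trials, averaged over random orderings of the trials. The helper `best_so_far_curve`, its `n_shuffles` parameter and the scores are hypothetical illustrations, not the paper's released code or data:

```python
import numpy as np

def best_so_far_curve(val_scores, test_scores, n_shuffles=15, seed=0):
    """Expected test score after n random-search trials, for n = 1..len(trials).

    Hypothetical helper: for each random ordering of the trials, keep the
    configuration with the best validation score seen so far and record its
    test score; then average the resulting curves over orderings.
    """
    val = np.asarray(val_scores, dtype=float)
    test = np.asarray(test_scores, dtype=float)
    rng = np.random.default_rng(seed)
    curves = np.empty((n_shuffles, len(val)))
    for s in range(n_shuffles):
        order = rng.permutation(len(val))
        v, t = val[order], test[order]
        best = 0
        for n in range(len(v)):
            if v[n] > v[best]:  # strictly better validation score
                best = n
            curves[s, n] = t[best]
    return curves.mean(axis=0)

# Usage with made-up scores for six hypothetical configurations.
val = [0.71, 0.74, 0.69, 0.80, 0.77, 0.75]
test = [0.70, 0.73, 0.70, 0.78, 0.76, 0.74]
curve = best_so_far_curve(val, test)
print(curve.round(3))
```

Averaging over shuffles removes the dependence on the particular order in which the random search happened to sample configurations.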

On tabular data, tree-based models still outperform deep learning methods

The new benchmark comprises 45 tabular datasets, selected according to the following criteria:

  • Heterogeneous columns: columns should correspond to features of different natures, which excludes image and signal datasets.
  • Not high-dimensional: the ratio d/n of the number of features to the number of samples must be below 1/10.
  • Undocumented datasets: remove datasets for which too little information is available.
  • I.I.D. (independently and identically distributed) data: remove stream-like datasets and time series.
  • Real-world data: remove artificial datasets, but keep some simulated datasets.
  • Not too small: remove datasets with too few features.
  • Not too easy: remove datasets that are too simple.
  • Not deterministic: remove datasets from games such as poker and chess, because these datasets are deterministic in nature.

For tree-based models, the researchers chose three SOTA models: Scikit-learn's RandomForest and GradientBoostingTrees (GBTs), and XGBoost. For deep models, the study benchmarked an MLP, a ResNet, FT Transformer and SAINT. Figures 1 and 2 show the benchmark results for different types of datasets.

Empirical investigation: why tree-based models still outperform deep learning on tabular data

Inductive bias. Tree-based models beat neural networks across a wide range of hyperparameter choices. In fact, the best methods for tabular data share two properties: they are ensemble methods, using either bagging (random forests) or boosting (XGBoost, GBTs), and their weak learners are decision trees.
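
To make "boosting with decision-tree weak learners" concrete, here is a toy sketch of gradient boosting for squared loss using depth-1 trees (stumps) on a single feature. It is an illustrative assumption, far simpler than XGBoost or Scikit-learn's gradient boosting; `fit_stump` and `boost_stumps` are hypothetical names. The fitted model is a sum of step functions, so it is piecewise constant:

```python
import numpy as np

def fit_stump(x, r):
    """Best threshold split on a 1-D feature x for residuals r (squared loss).

    Returns (threshold, left_value, right_value). Assumes x has at least
    two distinct values.
    """
    xs = np.unique(x)
    best_sse, best = np.inf, None
    for t in (xs[:-1] + xs[1:]) / 2:  # candidate thresholds: midpoints
        left = x <= t
        lv, rv = r[left].mean(), r[~left].mean()
        sse = ((r[left] - lv) ** 2).sum() + ((r[~left] - rv) ** 2).sum()
        if sse < best_sse:
            best_sse, best = sse, (t, lv, rv)
    return best

def boost_stumps(x, y, n_rounds=50, learning_rate=0.1):
    """Gradient boosting for squared loss: each stump fits the current residuals."""
    y = np.asarray(y, dtype=float)
    pred = np.full_like(y, y.mean())
    stumps = []
    for _ in range(n_rounds):
        # For squared loss, the negative gradient is simply the residual.
        t, lv, rv = fit_stump(x, y - pred)
        pred += learning_rate * np.where(x <= t, lv, rv)
        stumps.append((t, lv, rv))
    return pred, stumps

x = np.linspace(-1, 1, 101)
y = (x > 0).astype(float)  # a discontinuous target
pred, stumps = boost_stumps(x, y)
print(np.mean((pred - y) ** 2))
```

A single split captures the jump exactly, which is the kind of irregular target that, according to Finding 1, NNs fit less easily.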

Finding 1: Neural networks (NNs) tend to over-smooth solutions

As Figure 3 shows, for small length scales, smoothing the target function on the training set markedly reduces the accuracy of tree-based models but has little impact on NNs. These results indicate that the target functions in these datasets are not smooth, and that NNs have more difficulty fitting such irregular functions than tree-based models. This is consistent with the findings of Rahaman et al., who showed that NNs are biased toward low-frequency functions. Decision-tree-based models learn piecewise constant functions and have no such bias.
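
The smoothing step can be sketched as follows, assuming a Gaussian kernel smoother (Nadaraya-Watson) applied to the training targets; the paper's exact procedure may differ, and `smooth_targets` and `length_scale` are illustrative names. Models would then be trained on the smoothed targets and evaluated on the original test set. The toy usage shows that a larger length scale blurs a sharp jump in the target:

```python
import numpy as np

def smooth_targets(X, y, length_scale):
    """Replace each training target by a Gaussian-kernel weighted average
    of all training targets (Nadaraya-Watson smoother)."""
    X = np.asarray(X, dtype=float)
    y = np.asarray(y, dtype=float)
    # Pairwise squared Euclidean distances between training points.
    sq_dists = ((X[:, None, :] - X[None, :, :]) ** 2).sum(axis=-1)
    weights = np.exp(-sq_dists / (2.0 * length_scale ** 2))
    weights /= weights.sum(axis=1, keepdims=True)  # rows sum to 1
    return weights @ y

x = np.linspace(-1, 1, 201).reshape(-1, 1)
y = (x[:, 0] > 0).astype(float)  # an irregular (discontinuous) target

# Largest change between neighbouring points after smoothing.
jumps = [np.abs(np.diff(smooth_targets(x, y, ls))).max() for ls in (0.01, 0.1, 0.5)]
print(jumps)
```

The smaller the largest neighbour-to-neighbour change, the smoother the target a model is asked to learn.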

Finding 2: Uninformative features affect MLP-like NNs more

Tabular datasets contain many uninformative features. For each dataset, the study removes an increasing proportion of features, starting with the least important ones (as ranked by a random forest). As Figure 4 shows, removing more than half of the features has little impact on the classification accuracy of GBTs.
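
Removing the least important features can be sketched as below. The importance scores are an input here: in practice they could come from a fitted Scikit-learn random forest's `feature_importances_` attribute, but the values in the usage example are made up, and `drop_least_important` is a hypothetical helper:

```python
import numpy as np

def drop_least_important(X, importances, fraction):
    """Remove the given fraction of columns with the lowest importance.

    Returns the reduced matrix and the indices of the kept columns, in
    their original order.
    """
    X = np.asarray(X)
    importances = np.asarray(importances, dtype=float)
    n_drop = int(round(fraction * X.shape[1]))
    order = np.argsort(importances, kind="stable")  # least important first
    keep = np.sort(order[n_drop:])
    return X[:, keep], keep

rng = np.random.default_rng(0)
X = rng.standard_normal((10, 5))
importances = [0.40, 0.05, 0.30, 0.05, 0.20]  # made-up importance scores
X_reduced, kept = drop_least_important(X, importances, fraction=0.4)
print(kept)  # columns 1 and 3 are dropped
```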

As Figure 5a shows, removing uninformative features narrows the performance gap between MLPs (ResNet) and the other models (FT Transformers and tree-based models), while adding uninformative features widens it, indicating that MLPs are less robust to uninformative features. In Figure 5a, removing a larger proportion of features also removes informative features along with the uninformative ones. Figure 5b shows that the accuracy drop caused by losing these informative features is compensated by the removal of uninformative features, which helps MLPs more than the other models (the study also removes redundant features, which does not affect model performance).
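
Adding uninformative features can be sketched by appending noise columns that are drawn independently of the target. Sampling them from a standard Gaussian is an assumption made for this illustration, and `add_uninformative_features` and its `ratio` parameter are hypothetical:

```python
import numpy as np

def add_uninformative_features(X, ratio, seed=0):
    """Append round(ratio * n_features) columns of standard Gaussian noise.

    The new columns are drawn independently of the target, so they carry
    no information about it.
    """
    X = np.asarray(X, dtype=float)
    rng = np.random.default_rng(seed)
    n_extra = int(round(ratio * X.shape[1]))
    noise = rng.standard_normal((X.shape[0], n_extra))
    return np.hstack([X, noise])

X = np.ones((8, 4))
X_noisy = add_uninformative_features(X, ratio=0.5)
print(X_noisy.shape)  # (8, 6)
```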

Finding 3: Data are not invariant under rotation

Why are MLPs more sensitive to uninformative features than other models? One answer is that MLPs are rotation invariant: training an MLP on the training set and evaluating it on the test set gives the same result when the same rotation is applied to the features of both sets. In fact, any rotation-invariant learning procedure has a worst-case sample complexity that grows at least linearly in the number of irrelevant features. Intuitively, to discard useless features, a rotation-invariant algorithm must first recover the original orientation of the features and only then select the least informative ones.
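
The rotation-invariance argument can be checked numerically for the first layer of an MLP. In the sketch below, rotating the features by an orthogonal matrix Q can be exactly undone by the first-layer weights (replace W with Q.T @ W), so the set of functions an MLP can represent is unchanged. Under the further assumption of an isotropic Gaussian weight initialization, gradient training on rotated data then mirrors training on the original data. `random_orthogonal` is a hypothetical helper, and the matrix it returns may be a reflection rather than a proper rotation:

```python
import numpy as np

def random_orthogonal(d, seed=0):
    """Sample a random d x d orthogonal matrix (Haar-distributed, via QR).

    The result may have determinant -1 (a reflection); for this argument
    the distinction does not matter.
    """
    rng = np.random.default_rng(seed)
    A = rng.standard_normal((d, d))
    Q, R = np.linalg.qr(A)
    return Q * np.sign(np.diag(R))  # column sign fix makes the distribution uniform

rng = np.random.default_rng(1)
X = rng.standard_normal((100, 5))  # rows are samples, columns are features
Q = random_orthogonal(5)
X_rot = X @ Q                      # every new column mixes all original ones

# First linear layer of a hypothetical MLP: hidden = X @ W.
W = rng.standard_normal((5, 8))
W_rot = Q.T @ W                    # weights that undo the rotation
print(np.allclose(X_rot @ W_rot, X @ W))  # True
```

Nothing in this layer singles out the original columns, which is why an MLP cannot cheaply ignore an individual uninformative feature.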

Figure 6a shows the change in test accuracy when the datasets are randomly rotated, confirming that only ResNets are rotation invariant. Notably, random rotation reverses the performance order: NNs now rank above tree-based models, and ResNets above FT Transformers, which suggests that rotation invariance is undesirable. Indeed, the columns of tabular data often carry individual meanings, such as age or weight. Figure 6b shows that when the least important half of the features in each dataset is removed (before rotation), rotation still lowers the performance of all models except ResNets, but the decline is smaller than when all features are used.
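
A toy illustration (synthetic data, not the paper's experiment) of why rotation hurts tree-based models: when the label depends on one meaningful column, a single axis-aligned split fits it perfectly, but after that column is mixed with another one by a 45-degree rotation, no single split recovers the boundary. `best_stump_accuracy` is a hypothetical helper:

```python
import numpy as np

def best_stump_accuracy(X, y):
    """Training accuracy of the best single axis-aligned split (depth-1 tree)."""
    best = 0.0
    for j in range(X.shape[1]):
        for t in np.unique(X[:, j]):
            pred = X[:, j] > t
            # Either side of the split may be labelled positive.
            acc = max(np.mean(pred == y), np.mean(pred != y))
            best = max(best, acc)
    return best

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 5))
y = X[:, 0] > 0  # label depends on one meaningful column

# Rotate features 0 and 1 by 45 degrees; the other columns are unchanged.
c = s = np.sqrt(0.5)
R = np.eye(5)
R[0, 0], R[0, 1], R[1, 0], R[1, 1] = c, -s, s, c
X_rot = X @ R

acc_orig = best_stump_accuracy(X, y)
acc_rot = best_stump_accuracy(X_rot, y)
print(acc_orig, acc_rot)
```

After rotation the decision boundary is diagonal in the new coordinates, so a tree needs many axis-aligned splits to approximate what one split captured before.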

Statement: This article is reproduced from 51cto.com.