
How to Extract Decision Rules from scikit-learn Decision Trees?

Mary-Kate Olsen
Release: 2024-10-27 09:14:03

Decision Rule Extraction from scikit-learn Decision Trees

Problem Statement:

Can the decision rules underlying a trained decision tree model be extracted as a textual list?

Solution:

The following tree_to_code function (a user-written helper, not part of scikit-learn's API) walks the fitted tree's internal arrays and prints a valid Python function that reproduces the tree's decision rules:

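<code class="python">from sklearn.tree import _tree  # private module; defines the TREE_UNDEFINED (-2) sentinel

def tree_to_code(tree, feature_names):
    """Print a Python function equivalent to a fitted decision tree.

    tree: a fitted DecisionTreeClassifier or DecisionTreeRegressor.
    feature_names: one valid Python identifier per input column.
    """
    tree_ = tree.tree_
    # Map each node's feature index to a name. Leaf nodes store
    # TREE_UNDEFINED (-2), which must not be used as a list index.
    feature_name = [
        feature_names[i] if i != _tree.TREE_UNDEFINED else "undefined!"
        for i in tree_.feature
    ]
    print("def tree({}):".format(", ".join(feature_names)))

    def recurse(node, depth):
        indent = "  " * depth
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: emit the split test, then both subtrees
            name = feature_name[node]
            threshold = tree_.threshold[node]
            print("{}if {} <= {}:".format(indent, name, threshold))
            recurse(tree_.children_left[node], depth + 1)
            print("{}else:  # if {} > {}".format(indent, name, threshold))
            recurse(tree_.children_right[node], depth + 1)
        else:
            # Leaf node: emit the stored prediction (the tree_.value array)
            print("{}return {}".format(indent, tree_.value[node]))

    recurse(0, 1)</code>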
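Since the question asks for a textual list, it may be more convenient to collect one rule per root-to-leaf path instead of printing code. The sketch below is a hypothetical helper (tree_to_rules is not a scikit-learn function) that reuses the same traversal; the toy data (targets equal to inputs 0 to 10) and max_depth=3 are assumptions chosen only for illustration.

<code class="python">import numpy as np
from sklearn.tree import DecisionTreeRegressor, _tree


def tree_to_rules(estimator, feature_names):
    """Hypothetical helper: return one 'if ... then ...' string per leaf."""
    tree_ = estimator.tree_
    rules = []

    def recurse(node, conditions):
        if tree_.feature[node] != _tree.TREE_UNDEFINED:
            # Internal node: follow both branches, recording the test on each path
            name = feature_names[tree_.feature[node]]
            threshold = tree_.threshold[node]
            recurse(tree_.children_left[node], conditions + [f"{name} <= {threshold:.2f}"])
            recurse(tree_.children_right[node], conditions + [f"{name} > {threshold:.2f}"])
        else:
            # Leaf: the accumulated conditions form the rule; tree_.value is the prediction
            condition = " and ".join(conditions) or "True"  # "True" if the root is a leaf
            prediction = tree_.value[node].ravel().tolist()
            rules.append(f"if {condition} then {prediction}")

    recurse(0, [])
    return rules


# Assumed toy data: a regressor asked to return its own input
X = np.arange(11).reshape(-1, 1)
y = np.arange(11)
reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

for rule in tree_to_rules(reg, ["f0"]):
    print(rule)</code>

Each string holds the full conjunction of tests along one path, so the list has exactly one entry per leaf and can be filtered or written to a file without parsing generated code.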

Example:

For example, for a regression tree trained to return its input (a number between 0 and 10), tree_to_code prints the following Python function:

<code class="python">def tree(f0):
  if f0 <= 6.0:
    if f0 <= 1.5:
      return [[ 0.]]
    else:  # if f0 > 1.5
      if f0 <= 4.5:
        if f0 <= 3.5:
          return [[ 3.]]
        else:  # if f0 > 3.5
          return [[ 4.]]
      else:  # if f0 > 4.5
        return [[ 5.]]
  else:  # if f0 > 6.0
    if f0 <= 8.5:
      if f0 <= 7.5:
        return [[ 7.]]
      else:  # if f0 > 7.5
        return [[ 8.]]
    else:  # if f0 > 8.5
      return [[ 9.]]</code>
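A comparable tree can be built and inspected with scikit-learn's built-in sklearn.tree.export_text (available since scikit-learn 0.21), which returns the rules as an indented plain-text string. This is a minimal sketch: the article does not state the training data or hyperparameters behind the example above, so the inputs 0 to 10 with y equal to x and max_depth=3 are assumptions, and the resulting splits will not match the output above exactly.

<code class="python">import numpy as np
from sklearn.tree import DecisionTreeRegressor, export_text

# Assumed setup: targets equal the inputs, integers 0..10
X = np.arange(11).reshape(-1, 1)
y = np.arange(11)
reg = DecisionTreeRegressor(max_depth=3, random_state=0).fit(X, y)

# export_text returns the decision rules as a string, one node per line
report = export_text(reg, feature_names=["f0"])
print(report)</code>

export_text is the simpler choice when a readable listing is enough; a custom traversal such as tree_to_code is only needed when the rules must become executable code or a custom format.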

Caveats:

Avoid the following common issues:

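  • Do not rely on tree_.threshold == -2 to identify leaf nodes, because a genuine split can also have a threshold of -2.0. Check tree_.feature (equal to _tree.TREE_UNDEFINED, i.e. -2, at leaves) or tree_.children_left / tree_.children_right (equal to _tree.TREE_LEAF, i.e. -1, at leaves) instead.
  • Map feature indices to names only for internal nodes. Leaf nodes store -2 in tree_.feature, so feature_names[i] with i = -2 silently returns the second-to-last name, or raises IndexError when there is only one feature.
  • A single if/else per node is enough in the recursive function: each internal node has exactly one test, and the else branch covers the opposite condition.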
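The first two caveats can be checked directly. In this sketch the toy data is chosen for illustration: two training points at -3 and -1 have a midpoint of -2, so the root's split threshold is exactly -2.0. The threshold test then mislabels the root as a leaf, while the feature test and the children test agree with each other.

<code class="python">import numpy as np
from sklearn.tree import DecisionTreeClassifier, _tree

# Two points whose midpoint is -2.0, so the root split threshold is exactly -2.0
X = np.array([[-3.0], [-1.0]])
y = np.array([0, 1])
t = DecisionTreeClassifier().fit(X, y).tree_

naive_leaves = t.threshold == -2                          # unreliable: also matches the real root split
leaves_by_children = t.children_left == _tree.TREE_LEAF   # leaves have no children (-1)
leaves_by_feature = t.feature == _tree.TREE_UNDEFINED     # leaves have no split feature (-2)

print("threshold test:", naive_leaves.tolist())
print("children test: ", leaves_by_children.tolist())
print("feature test:  ", leaves_by_feature.tolist())</code>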

source:php.cn