Decision Trees: Algorithm and Code Implementation

Author: Internet works  Date: 2007-04-04

The ML classes discussed in this section implement the Classification And Regression Tree algorithms, which are described in [Brieman84].

The class CvDTree represents a single decision tree that may be used alone, or as a base class in tree ensembles (see Boosting and Random Trees).

A decision tree is a binary tree (i.e. a tree in which each non-leaf node has exactly 2 child nodes). It can be used either for classification, when each tree leaf is marked with some class label (multiple leaves may have the same label), or for regression, when each tree leaf is assigned a constant (so the approximation function is piecewise constant).

Predicting with Decision Trees

To reach a leaf node, and thus obtain a response for the input feature vector, the prediction procedure starts at the root node. From each non-leaf node the procedure goes either to the left (i.e. selects the left child node as the next observed node) or to the right, based on the value of a certain variable whose index is stored in the observed node. The variable can be either ordered or categorical. In the first case, the variable value is compared with a threshold (which is also stored in the node); if the value is less than the threshold, the procedure goes to the left, otherwise to the right (for example, if the weight is less than 1 kilo, the procedure goes to the left, else to the right). In the second case, the discrete variable value is tested for membership in a certain subset of values (also stored in the node) drawn from the limited set of values the variable can take; if it belongs to the subset, the procedure goes to the left, otherwise to the right (for example, if the color is green or red, go to the left, else to the right). That is, each node uses a pair of entities (<variable_index>, <decision_rule (threshold/subset)>). This pair is called a split (a split on the variable #<variable_index>). Once a leaf node is reached, the value assigned to this node is used as the output of the prediction procedure.
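The traversal described above can be sketched in a few lines of C++. This is only an illustration under simplifying assumptions: SketchNode and sketch_predict are hypothetical names invented here, not part of the ML library, and categorical values are assumed to be small integers stored in the same vector as the ordered values.

```cpp
#include <cassert>
#include <cstddef>
#include <set>
#include <vector>

// Hypothetical, simplified node type for illustration (not the ML structures).
struct SketchNode {
    bool is_categorical = false;   // kind of variable tested at this node
    std::size_t var_index = 0;     // index of the tested variable
    double threshold = 0.0;        // used when the variable is ordered
    std::set<int> left_subset;     // used when the variable is categorical
    double value = 0.0;            // response stored in the node
    const SketchNode* left = nullptr;
    const SketchNode* right = nullptr;
};

// Walks from the root to a leaf and returns the value stored there.
double sketch_predict(const SketchNode* root, const std::vector<double>& sample) {
    assert(root != nullptr);
    const SketchNode* node = root;
    while (node->left != nullptr && node->right != nullptr) {
        assert(node->var_index < sample.size());
        const double v = sample[node->var_index];
        bool go_left;
        if (node->is_categorical) {
            // Categorical: go left if the value belongs to the stored subset.
            go_left = node->left_subset.count(static_cast<int>(v)) > 0;
        } else {
            // Ordered: go left if the value is below the threshold.
            go_left = v < node->threshold;
        }
        node = go_left ? node->left : node->right;
    }
    return node->value;
}
```

A caller would build the nodes by hand (or from a trained model) and pass the root together with one feature vector; the returned number is the class label or the regression estimate, depending on what the leaves store.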

Sometimes certain features of the input vector are missing (for example, in the darkness it is difficult to determine the object color), and the prediction procedure may get stuck at a node (in the example above, at a node that is split by color). To avoid such situations, decision trees use so-called surrogate splits. That is, in addition to the best "primary" split, every tree node may also be split on one or more other variables with nearly the same results.
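A minimal sketch of how surrogate splits might be consulted when a value is missing is shown below. It assumes ordered variables only and a split list in which the primary split comes first and the surrogates follow, best first; SketchSplit and sketch_direction are hypothetical names used for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical ordered split used only for this sketch.
struct SketchSplit {
    std::size_t var_index;  // variable tested by the split
    double threshold;       // go left when value < threshold
};

// Returns -1 for "left", +1 for "right", 0 if every split variable is missing.
// splits[0] is the primary split; the rest are surrogates, best first.
int sketch_direction(const std::vector<SketchSplit>& splits,
                     const std::vector<double>& sample,
                     const std::vector<bool>& missing) {
    assert(sample.size() == missing.size());
    for (const SketchSplit& s : splits) {
        assert(s.var_index < sample.size());
        if (missing[s.var_index]) continue;  // value unknown, try the next split
        return sample[s.var_index] < s.threshold ? -1 : +1;
    }
    return 0;  // caller must fall back to another rule
}
```

When the function returns 0, none of the split variables was available, so the caller needs some other tie-breaking rule for that node.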

Training Decision Trees

The tree is built recursively, starting from the root node. The whole training data (feature vectors and responses) is used to split the root node. In each node the optimum decision rule (i.e. the best "primary" split) is found based on some criterion (in ML the gini "purity" criterion is used for classification, and the sum of squared errors is used for regression). Then, if necessary, surrogate splits are found that most closely resemble the results of the primary split on the training data; all data are divided between the left and the right child nodes using the primary and the surrogate splits (just as in the prediction procedure). Then the procedure recursively splits both the left and the right nodes, and so on. At each node the recursive procedure may stop (i.e. stop splitting the node further) in one of the following cases:

  • the depth of the tree branch being constructed has reached the specified maximum value.
  • the number of training samples in the node is less than the specified threshold, i.e. the set is not statistically representative enough to split the node further.
  • all the samples in the node belong to the same class (or, in case of regression, the variation is too small).
  • the best split found does not give any noticeable improvement compared to just a random choice.
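To make the classification criterion concrete, here is a textbook-style sketch of choosing a threshold for one ordered variable by minimizing the weighted Gini impurity of the two children. It is an assumption-laden illustration, not the ML implementation: the function names are invented here, class labels are assumed to be already normalized to 0..num_classes-1, and the exact quality measure used inside the library may differ.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Gini impurity of a node given per-class sample counts: 1 - sum(p_k^2).
double gini_impurity(const std::vector<int>& class_counts) {
    int total = 0;
    for (int c : class_counts) total += c;
    if (total == 0) return 0.0;
    double sum_sq = 0.0;
    for (int c : class_counts) {
        const double p = static_cast<double>(c) / total;
        sum_sq += p * p;
    }
    return 1.0 - sum_sq;
}

// Scans the candidate thresholds of one ordered variable and returns the one
// with the lowest weighted impurity of the two children.
// samples: (value, class label in 0..num_classes-1) pairs.
double best_threshold(std::vector<std::pair<double, int> > samples, int num_classes) {
    assert(samples.size() >= 2 && num_classes > 0);
    std::sort(samples.begin(), samples.end());
    std::vector<int> left(num_classes, 0), right(num_classes, 0);
    for (const auto& s : samples) {
        assert(s.second >= 0 && s.second < num_classes);
        ++right[s.second];
    }
    const double n = static_cast<double>(samples.size());
    double best_score = std::numeric_limits<double>::infinity();
    double best = samples.front().first;
    for (std::size_t i = 0; i + 1 < samples.size(); ++i) {
        // Move sample i from the right child to the left child.
        ++left[samples[i].second];
        --right[samples[i].second];
        if (samples[i].first == samples[i + 1].first) continue;  // no boundary here
        const double n_left = static_cast<double>(i + 1);
        const double score = (n_left / n) * gini_impurity(left) +
                             ((n - n_left) / n) * gini_impurity(right);
        if (score < best_score) {
            best_score = score;
            best = 0.5 * (samples[i].first + samples[i + 1].first);
        }
    }
    return best;
}
```

Placing the threshold halfway between two neighbouring values is a design choice of this sketch; any value strictly between them separates the training samples identically.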

When the tree is built, it may be pruned using a cross-validation procedure, if needed. That is, some branches of the tree that may lead to model overfitting are cut off. Normally, this procedure is only applied to standalone decision trees, while tree ensembles usually build small enough trees and use their own protection schemes against overfitting.

Variable importance

Besides their obvious use, prediction, decision trees can also be used for various kinds of data analysis. One of the key properties of the constructed decision tree is that it is possible to compute the importance (relative decisive power) of each variable. For example, in a spam filter that uses the set of words occurring in a message as the feature vector, the variable importance rating can be used to determine the most "spam-indicating" words and thus help to keep the dictionary size reasonable.

The importance of each variable is computed over all the splits on this variable in the tree, both primary and surrogate. Thus, to compute variable importance correctly, the surrogate splits must be enabled in the training parameters, even if there is no missing data.
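One plausible way to turn this description into code is to add up a quality score for every split (primary and surrogate) per variable and then scale the totals. The sketch below does exactly that; the types, the function names, and the final normalization to a sum of 1 are assumptions made for illustration and are not taken from the ML sources.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical split record: which variable it tests and its quality score.
struct ImportanceSplit {
    std::size_t var_index;
    double quality;
};

// Hypothetical node: all of its splits (primary first, then surrogates)
// and its children.
struct ImportanceNode {
    std::vector<ImportanceSplit> splits;
    const ImportanceNode* left = nullptr;
    const ImportanceNode* right = nullptr;
};

// Adds the quality of every split in the subtree to its variable's total.
void accumulate(const ImportanceNode* node, std::vector<double>& importance) {
    if (node == nullptr) return;
    for (const ImportanceSplit& s : node->splits) {
        assert(s.var_index < importance.size());
        importance[s.var_index] += s.quality;
    }
    accumulate(node->left, importance);
    accumulate(node->right, importance);
}

// Sums split quality per variable over the whole tree, then scales to sum 1.
std::vector<double> sketch_importance(const ImportanceNode* root, std::size_t var_count) {
    std::vector<double> importance(var_count, 0.0);
    accumulate(root, importance);
    double total = 0.0;
    for (double v : importance) total += v;
    if (total > 0.0)
        for (double& v : importance) v /= total;
    return importance;
}
```

If surrogate splits are not stored in the nodes, variables that only ever appear as surrogates receive zero importance in this sketch, which mirrors the warning in the text.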

[Brieman84] Breiman, L., Friedman, J., Olshen, R. and Stone, C. (1984), "Classification and Regression Trees", Wadsworth.

CvDTreeSplit

Decision tree node split

struct CvDTreeSplit
{
    int var_idx;
    int inversed;
    float quality;
    CvDTreeSplit* next;
    union
    {
        int subset[2];
        struct
        {
            float c;
            int split_point;
        }
        ord;
    };
};
 

var_idx
Index of the variable used in the split
inversed
When it equals 1, the inverse split rule is used (i.e. the left and right branches are exchanged in the expressions below)
quality
The split quality, a positive number. It is used to choose the best primary split, then to choose and sort the surrogate splits. After the tree is constructed, it is also used to compute variable importance.
next
Pointer to the next split in the node split list.
subset
Bit array indicating the value subset in case of split on a categorical variable. The rule is: if var_value in subset then next_node<-left else next_node<-right
c
The threshold value in case of split on an ordered variable. The rule is: if var_value < c then next_node<-left else next_node<-right
split_point
Used internally by the training algorithm.
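The two rules above, together with the inversed flag, can be illustrated with a small sketch. It assumes that a categorical value has already been converted to an index in 0..63 and that bit i of the two-word array marks membership of category i; SplitSketch, add_category and goes_left are hypothetical names, and unsigned words are used instead of int to keep the bit shifts well defined.

```cpp
#include <cassert>
#include <cstdint>

// Hypothetical mirror of the documented fields, for illustration only.
struct SplitSketch {
    bool categorical;         // which branch of the union is meaningful
    int inversed;             // 1 swaps the left and right outcomes
    std::uint32_t subset[2];  // bit i set => category i belongs to the subset
    float c;                  // threshold for an ordered variable
};

// Marks category index idx (0..63) as a member of the subset.
void add_category(SplitSketch& s, int idx) {
    assert(idx >= 0 && idx < 64);
    s.subset[idx >> 5] |= std::uint32_t(1) << (idx & 31);
}

// Returns true when the sample goes to the left child.
bool goes_left(const SplitSketch& s, float ordered_value, int category_idx) {
    bool left;
    if (s.categorical) {
        assert(category_idx >= 0 && category_idx < 64);
        // Word idx/32 holds the bit; idx%32 is the bit position in that word.
        left = ((s.subset[category_idx >> 5] >> (category_idx & 31)) & 1u) != 0;
    } else {
        left = ordered_value < s.c;
    }
    return s.inversed ? !left : left;
}
```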

CvDTreeNode

Decision tree node

struct CvDTreeNode
{
    int class_idx;
    int Tn;
    double value;

    CvDTreeNode* parent;
    CvDTreeNode* left;
    CvDTreeNode* right;

    CvDTreeSplit* split;

    int sample_count;
    int depth;
    ...
};
value
The value assigned to the tree node. It is either a class label, or the estimated function value.
class_idx
The normalized class index assigned to the node (in the range 0..class_count-1). It is used internally in classification trees and tree ensembles.
Tn
The tree index in an ordered sequence of trees. The indices are used during and after the pruning procedure. The root node has the maximum value Tn of the whole tree, child nodes have Tn less than or equal to the parent's Tn, and the nodes with Tn≤CvDTree::pruned_tree_idx are not taken into consideration at the prediction stage (the corresponding branches are considered as cut off), even if they have not been physically deleted from the tree at the pruning stage.
parent, left, right
Pointers to the parent node, left and right child nodes.
split
Pointer to the first (primary) split.
sample_count
The number of samples that fall into the node at the training stage. It is used to resolve the difficult cases: when the variable for the primary split is missing, and all the variables for the surrogate splits are missing too, the sample is directed to the left if left->sample_count>right->sample_count and to the right otherwise.
depth
The node depth. The root node has depth 0, and each child node's depth is its parent's depth + 1.

Numerous other fields of CvDTreeNode are used internally at the training stage.
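The roles of Tn and sample_count at prediction time can be sketched as follows. This is a simplified reading of the field descriptions above, not library code: NodeSketch and the helper functions are hypothetical, and the walk uses only the sample_count fallback (as if every split variable were missing) to keep the example short.

```cpp
#include <cassert>

// Hypothetical subset of the documented node fields.
struct NodeSketch {
    int Tn;
    int sample_count;
    double value;
    const NodeSketch* left;
    const NodeSketch* right;
};

// A node acts as a leaf if it has no children or if its Tn marks it as pruned.
bool acts_as_leaf(const NodeSketch& node, int pruned_tree_idx) {
    return node.left == nullptr || node.right == nullptr || node.Tn <= pruned_tree_idx;
}

// Fallback when no split variable is available: follow the child that
// received more training samples.
const NodeSketch* majority_child(const NodeSketch& node) {
    assert(node.left != nullptr && node.right != nullptr);
    return node.left->sample_count > node.right->sample_count ? node.left : node.right;
}

// Walks down using only the fallback rule until a (possibly pruned) leaf.
double value_without_features(const NodeSketch* root, int pruned_tree_idx) {
    assert(root != nullptr);
    const NodeSketch* node = root;
    while (!acts_as_leaf(*node, pruned_tree_idx)) node = majority_child(*node);
    return node->value;
}
```

Raising pruned_tree_idx makes the walk stop higher in the tree without deleting any node, which is why the value stored in inner nodes matters as well.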

CvDTreeParams

Decision tree training parameters

struct CvDTreeParams
{
    int max_categories;
    int max_depth;
    int min_sample_count;
    int cv_folds;
    bool use_surrogates;
    bool use_1se_rule;
    bool truncate_pruned_tree;
    float regression_accuracy;
    const float* priors;

    CvDTreeParams() : max_categories(10), max_depth(INT_MAX), min_sample_count(10),
        cv_folds(10), use_surrogates(true), use_1se_rule(true),
        truncate_pruned_tree(true), regression_accuracy(0.01f), priors(0)
    {}

    CvDTreeParams( int _max_depth, int _min_sample_count,
                   float _regression_accuracy, bool _use_surrogates,
                   int _max_categories, int _cv_folds,
                   bool _use_1se_rule, bool _truncate_pruned_tree,
                   const float* _priors );
};
 

max_depth
This parameter specifies the maximum possible depth of the tree. That is, the training algorithm attempts to split a node while its depth is less than max_depth. The actual depth may be smaller if the other termination criteria are met (see the outline of the training procedure in the beginning of the section), and/or if the tree is pruned.
min_sample_count
A node is not split if the number of samples directed to the node is less than the parameter value.
regression_accuracy
Another stopping criterion, used only for regression trees. As soon as the estimated node value differs from the responses of the node's training samples by less than the parameter value, the node is not split further.
use_surrogates
If true, surrogate splits are built. Surrogate splits are needed to handle missing measurements and for variable importance estimation.
max_categories
If a discrete variable, on which the training procedure tries to make a split, takes more than max_categories values, the precise best subset estimation may take a very long time (as the algorithm is exponential). Instead, many decision tree engines (including ML) try to find a sub-optimal split in this case by clustering all the samples into max_categories clusters (i.e. some categories are merged together).
   Note that this technique is used only in N(>2)-class classification problems. In case of regression and 2-class classification the optimal split can be found efficiently without employing clustering, thus the parameter is not used in these cases. 
 

cv_folds
If this parameter is >1, the tree is pruned using cv_folds-fold cross validation.
use_1se_rule
If true, the pruning procedure truncates the tree a bit more. This yields a more compact decision tree that is more resistant to noise in the training data, but a bit less accurate.
truncate_pruned_tree
If true, the cut off nodes (with Tn≤CvDTree::pruned_tree_idx) are physically removed from the tree. Otherwise they are kept, and by decreasing CvDTree::pruned_tree_idx (e.g. setting it to -1) it is still possible to get the results from the original unpruned (or pruned less aggressively) tree.
priors
The array of a priori class probabilities, sorted by the class label value. The parameter can be used to tune the decision tree preferences toward a certain class. For example, if users want to detect some rare anomaly occurrence, the training base will likely contain many more normal cases than anomalies, so a very good classification performance will be achieved just by considering every case as normal. To avoid this, the priors can be specified so that the anomaly probability is artificially increased (up to 0.5 or even greater); the weight of the misclassified anomalies then becomes much bigger, and the tree is adjusted properly.
A note about memory management: the field priors is a pointer to an array of floats. The array should be allocated by the user and can be released as soon as the CvDTreeParams structure has been passed to the CvDTreeTrainData or CvDTree constructors/methods (as the methods make a copy of the array).

The structure contains all the decision tree training parameters. There is a default constructor that initializes all the parameters with default values tuned for a standalone classification tree. Any of the parameters can then be overridden, or the structure may be fully initialized using the advanced variant of the constructor.
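As a rough illustration of how max_depth, min_sample_count and regression_accuracy could act as stopping tests, consider the sketch below. StopParams, NodeStats and should_stop are hypothetical names; the comparisons follow the wording of the parameter descriptions, and reading regression_accuracy as a bound on the largest absolute residual is an assumption of this sketch, so the library's exact tests may differ.

```cpp
#include <cassert>
#include <climits>

// Hypothetical subset of the training parameters, with the defaults quoted above.
struct StopParams {
    int max_depth = INT_MAX;
    int min_sample_count = 10;
    float regression_accuracy = 0.01f;
};

// Summary of a node as seen by the stopping test (names are illustrative).
struct NodeStats {
    int depth;
    int sample_count;
    bool is_classifier;
    bool single_class;        // classification: all samples share one label
    double max_abs_residual;  // regression: largest |response - node value|
};

// Returns true when the node should not be split any further.
bool should_stop(const NodeStats& n, const StopParams& p) {
    assert(n.depth >= 0 && n.sample_count >= 0);
    if (n.depth >= p.max_depth) return true;               // depth limit reached
    if (n.sample_count < p.min_sample_count) return true;  // too few samples
    if (n.is_classifier) return n.single_class;            // pure node
    return n.max_abs_residual < p.regression_accuracy;     // accurate enough
}
```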

CvDTreeTrainData

Decision tree training data and shared data for tree ensembles

struct CvDTreeTrainData
{
    CvDTreeTrainData();
    CvDTreeTrainData( const CvMat* _train_data, int _tflag,
                      const CvMat* _responses, const CvMat* _var_idx=0,
                      const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                      const CvMat* _missing_mask=0,
                      const CvDTreeParams& _params=CvDTreeParams(),
                      bool _shared=false, bool _add_labels=false );
    virtual ~CvDTreeTrainData();

    virtual void set_data( const CvMat* _train_data, int _tflag,
                          const CvMat* _responses, const CvMat* _var_idx=0,
                          const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                          const CvMat* _missing_mask=0,
                          const CvDTreeParams& _params=CvDTreeParams(),
                          bool _shared=false, bool _add_labels=false,
                          bool _update_data=false );

    virtual void get_vectors( const CvMat* _subsample_idx,
         float* values, uchar* missing, float* responses, bool get_class_idx=false );

    virtual CvDTreeNode* subsample_data( const CvMat* _subsample_idx );

    virtual void write_params( CvFileStorage* fs );
    virtual void read_params( CvFileStorage* fs, CvFileNode* node );

    // release all the data
    virtual void clear();

    int get_num_classes() const;
    int get_var_type(int vi) const;
    int get_work_var_count() const;

    virtual int* get_class_labels( CvDTreeNode* n );
    virtual float* get_ord_responses( CvDTreeNode* n );
    virtual int* get_labels( CvDTreeNode* n );
    virtual int* get_cat_var_data( CvDTreeNode* n, int vi );
    virtual CvPair32s32f* get_ord_var_data( CvDTreeNode* n, int vi );
    virtual int get_child_buf_idx( CvDTreeNode* n );

    ////////////////////////////////////

    virtual bool set_params( const CvDTreeParams& params );
    virtual CvDTreeNode* new_node( CvDTreeNode* parent, int count,
                                   int storage_idx, int offset );

    virtual CvDTreeSplit* new_split_ord( int vi, float cmp_val,
                int split_point, int inversed, float quality );
    virtual CvDTreeSplit* new_split_cat( int vi, float quality );
    virtual void free_node_data( CvDTreeNode* node );
    virtual void free_train_data();
    virtual void free_node( CvDTreeNode* node );

    int sample_count, var_all, var_count, max_c_count;
    int ord_var_count, cat_var_count;
    bool have_labels, have_priors;
    bool is_classifier;

    int buf_count, buf_size;
    bool shared;

    CvMat* cat_count;
    CvMat* cat_ofs;
    CvMat* cat_map;

    CvMat* counts;
    CvMat* buf;
    CvMat* direction;
    CvMat* split_buf;

    CvMat* var_idx;
    CvMat* var_type; // i-th element =
                     //   k<0  - ordered
                     //   k>=0 - categorical, see k-th element of cat_* arrays
    CvMat* priors;

    CvDTreeParams params;

    CvMemStorage* tree_storage;
    CvMemStorage* temp_storage;

    CvDTreeNode* data_root;

    CvSet* node_heap;
    CvSet* split_heap;
    CvSet* cv_heap;
    CvSet* nv_heap;

    CvRNG rng;
};
  

This structure is mostly used internally for storing both standalone trees and tree ensembles efficiently. Basically, it contains 3 types of information:

  1. The training parameters, CvDTreeParams instance.
  2. The training data, preprocessed in order to find the best splits more efficiently. For tree ensembles this preprocessed data is reused by all the trees. Additionally, the training data characteristics that are shared by all trees in the ensemble are stored here: variable types, the number of classes, class label compression map etc.
  3. Buffers, memory storages for tree nodes, splits and other elements of the trees constructed.
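The "class label compression map" mentioned in item 2 can be pictured as a sorted list of the distinct labels, where a label's position is its normalized index in 0..class_count-1. The sketch below shows that idea with hypothetical helper names; it is not the layout used inside CvDTreeTrainData.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// Builds the sorted list of distinct labels; position in the list is the
// normalized index in 0..class_count-1.
std::vector<int> build_label_map(std::vector<int> labels) {
    std::sort(labels.begin(), labels.end());
    labels.erase(std::unique(labels.begin(), labels.end()), labels.end());
    return labels;
}

// Returns the normalized index of a label, or -1 if it was never seen.
int normalize_label(const std::vector<int>& label_map, int label) {
    auto it = std::lower_bound(label_map.begin(), label_map.end(), label);
    if (it == label_map.end() || *it != label) return -1;
    return static_cast<int>(it - label_map.begin());
}

// Maps a normalized index back to the original label.
int original_label(const std::vector<int>& label_map, int index) {
    assert(index >= 0 && index < static_cast<int>(label_map.size()));
    return label_map[index];
}
```

Because the map is sorted, the normalized indices preserve the order of the original label values.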

There are 2 ways of using this structure. In simple cases (e.g. a standalone tree, or a ready-to-use "black box" tree ensemble from ML, like Random Trees or Boosting) there is no need to care or even to know about the structure: just construct the needed statistical model, train it and use it. The CvDTreeTrainData structure will be constructed and used internally. However, for custom tree algorithms, or other sophisticated cases, the structure may be constructed and used explicitly. The scheme is the following:

  1. The structure is initialized using the default constructor, followed by set_data (or it is built using the full form of constructor). The parameter _shared must be set to true.
  2. One or more trees are trained using this data, see the special form of the method CvDTree::train.
  3. Finally, the structure can be released only after all the trees using it are released.

CvDTree

Decision tree

class CvDTree : public CvStatModel
{
public:
    CvDTree();
    virtual ~CvDTree();

    virtual bool train( const CvMat* _train_data, int _tflag,
                        const CvMat* _responses, const CvMat* _var_idx=0,
                        const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                        const CvMat* _missing_mask=0,
                        CvDTreeParams params=CvDTreeParams() );

    virtual bool train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );

    virtual CvDTreeNode* predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
                                  bool raw_mode=false ) const;
    virtual const CvMat* get_var_importance();
    virtual void clear();

    virtual void read( CvFileStorage* fs, CvFileNode* node );
    virtual void write( CvFileStorage* fs, const char* name );

    // special read & write methods for trees in the tree ensembles
    virtual void read( CvFileStorage* fs, CvFileNode* node,
                       CvDTreeTrainData* data );
    virtual void write( CvFileStorage* fs );

    const CvDTreeNode* get_root() const;
    int get_pruned_tree_idx() const;
    CvDTreeTrainData* get_data();

protected:

    virtual bool do_train( const CvMat* _subsample_idx );

    virtual void try_split_node( CvDTreeNode* n );
    virtual void split_node_data( CvDTreeNode* n );
    virtual CvDTreeSplit* find_best_split( CvDTreeNode* n );
    virtual CvDTreeSplit* find_split_ord_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_class( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_ord_reg( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_split_cat_reg( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_surrogate_split_ord( CvDTreeNode* n, int vi );
    virtual CvDTreeSplit* find_surrogate_split_cat( CvDTreeNode* n, int vi );
    virtual double calc_node_dir( CvDTreeNode* node );
    virtual void complete_node_dir( CvDTreeNode* node );
    virtual void cluster_categories( const int* vectors, int vector_count,
        int var_count, int* sums, int k, int* cluster_labels );

    virtual void calc_node_value( CvDTreeNode* node );

    virtual void prune_cv();
    virtual double update_tree_rnc( int T, int fold );
    virtual int cut_tree( int T, int fold, double min_alpha );
    virtual void free_prune_data(bool cut_tree);
    virtual void free_tree();

    virtual void write_node( CvFileStorage* fs, CvDTreeNode* node );
    virtual void write_split( CvFileStorage* fs, CvDTreeSplit* split );
    virtual CvDTreeNode* read_node( CvFileStorage* fs, CvFileNode* node, CvDTreeNode* parent );
    virtual CvDTreeSplit* read_split( CvFileStorage* fs, CvFileNode* node );
    virtual void write_tree_nodes( CvFileStorage* fs );
    virtual void read_tree_nodes( CvFileStorage* fs, CvFileNode* node );

    CvDTreeNode* root;

    int pruned_tree_idx;
    CvMat* var_importance;

    CvDTreeTrainData* data;
};
 

CvDTree::train

Trains decision tree

bool CvDTree::train( const CvMat* _train_data, int _tflag,
                     const CvMat* _responses, const CvMat* _var_idx=0,
                     const CvMat* _sample_idx=0, const CvMat* _var_type=0,
                     const CvMat* _missing_mask=0,
                     CvDTreeParams params=CvDTreeParams() );

bool CvDTree::train( CvDTreeTrainData* _train_data, const CvMat* _subsample_idx );

There are 2 train methods in CvDTree.

The first method follows the generic CvStatModel::train conventions and is the most complete form of it. Both data layouts (_tflag=CV_ROW_SAMPLE and _tflag=CV_COL_SAMPLE) are supported, as well as sample and variable subsets, missing measurements, arbitrary combinations of input and output variable types, etc. The last parameter contains all the necessary training parameters; see the CvDTreeParams description.

The second train method is mostly used for building tree ensembles. It takes a pre-constructed CvDTreeTrainData instance and an optional subset of the training set. The indices in _subsample_idx are counted relative to the _sample_idx passed to the CvDTreeTrainData constructor. For example, if _sample_idx=[1, 5, 7, 100], then _subsample_idx=[0,3] means that the samples [1, 100] of the original training set are used.
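The relative indexing in the example above amounts to a simple lookup, sketched here with plain vectors instead of CvMat (resolve_subsample is a hypothetical helper written for this illustration):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Maps indices that are relative to sample_idx back to indices in the
// original training set.
std::vector<int> resolve_subsample(const std::vector<int>& sample_idx,
                                   const std::vector<int>& subsample_idx) {
    std::vector<int> original;
    original.reserve(subsample_idx.size());
    for (int i : subsample_idx) {
        assert(i >= 0 && static_cast<std::size_t>(i) < sample_idx.size());
        original.push_back(sample_idx[i]);
    }
    return original;
}
```

Called with sample_idx = [1, 5, 7, 100] and subsample_idx = [0, 3], the helper yields the original sample indices 1 and 100, matching the example in the text.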

CvDTree::predict

Returns the leaf node of decision tree corresponding to the input vector

CvDTreeNode* CvDTree::predict( const CvMat* _sample, const CvMat* _missing_data_mask=0,
                               bool raw_mode=false ) const;
  

The method takes the feature vector and the optional missing measurement mask as input, traverses the decision tree, and returns the reached leaf node as output. The prediction result, either the class label or the estimated function value, may be retrieved from the value field of the CvDTreeNode structure, for example: dtree->predict(sample,mask)->value

The last parameter is normally set to false, which implies regular input. If it is true, the method assumes that all the values of the discrete input variables have already been normalized to the ranges 0..<num_of_categories_i>-1 (as the decision tree uses such a normalized representation internally). This is useful for faster prediction with tree ensembles. For ordered input variables the flag is not used.

Example. Building Tree for Classifying Mushrooms

See the mushroom.cpp sample, which demonstrates how to build and use a decision tree.
