?? decisiontree.cpp
字號:
#include "DecisionTree.h"using namespace DT;DecisionTree::~DecisionTree() { if (headVertex == NULL) { delete headVertex; } else { delete currLeaf; }}double DecisionTree::getProb(double state, double* allStates) const{ const Leaf* l = getLeaf(allStates);#ifdef DEBUG if (l == NULL) { cout << "ERROR: could not find leaf for states!\n"; // HACK cout << "us = " << var << " (" << allStates[var] << ")\n"; return 0; } // DEBUG if (l->getProb(state) == 0.0) { cout << "ERROR: leaf returned less than zero in DT::getProb!\n"; }#endif return l->getProb(state);}double DecisionTree::getLogProb(double state, double* allStates) const{ const Leaf* l = getLeaf(allStates);#ifdef DEBUG if (l == NULL) { cout << "ERROR: could not find leaf for states!\n"; // HACK cout << "us = " << var << " (" << allStates[var] << ")\n"; return 0; } // DEBUG if (l->getLogProb(state) == 0.0) { cout << "ERROR: leaf returned zero in DT::getLogProb!\n"; }#endif return l->getLogProb(state);}double DecisionTree::sample(double* allStates) const{ const Leaf* l = getLeaf(allStates);#ifdef DEBUG if (l == NULL) { cout << "ERROR: could not find leaf for states!\n"; // HACK cout << "us = " << var << " (" << allStates[var] << ")\n"; return 0; }#endif return l->sample();}list<double> DecisionTree::getSplits(int var) const{ list<double> ret; if (headVertex != NULL) { headVertex->getSplits(var, ret); } return ret;}const Leaf* DecisionTree::getLeaf(double* allStates) const{ if (headVertex == NULL) { if (currLeaf == NULL) { cout << "ERROR: no leaf found in getLeaf()!\n"; return NULL; } else { return currLeaf; } }#if 0 return headVertex->getLeaf(allStates);#else const Vertex* currVertex = headVertex; while (currVertex) { list<Branch*>::const_iterator c = currVertex->children.begin(); //for (c = currVertex->children.begin(); c != currVertex->children.end(); c++) if (!(*c)->inRange(allStates[currVertex->split])) { c++; } if ((*c)->childLeaf) { return (*c)->childLeaf; } else { currVertex = (*c)->childVertex; } } cout << "ERROR: no leaf found!\n"; return NULL;#endif}void DecisionTree::beginVertex(int splitVar){ // DEBUG //cout << "Starting vertex!\n"; currVertex = new Vertex(splitVar, currBranch); if (headVertex == NULL) { // If there's no root vertex, set it as such headVertex = currVertex; } else { // Otherwise, add it as the child of a branch currBranch->setVertex(currVertex); }}void DecisionTree::endVertex(){ // DEBUG //cout << "Ending vertex!\n"; if (currVertex != headVertex) { currVertex = currBranch->getParent(); }}void DecisionTree::beginBranch() { // DEBUG //cout << "Adding branch!\n"; currBranch = new Branch(currVertex); currVertex->addChild(currBranch);}void DecisionTree::endBranch() { // DEBUG //cout << "Ending branch!\n"; // Move up in the tree currBranch = currVertex->getParent();}void DecisionTree::endValues(list<Range> values) { currBranch->setValues(values);}void DecisionTree::beginMultinomial() { currLeaf = currMultinomial = new Multinomial(maxVal); leafList.push_back(currLeaf); if (currBranch != NULL) { currBranch->setLeaf(currLeaf); }}void DecisionTree::endProbs(double* probs) { currMultinomial->setProbs(probs);}void DecisionTree::endBinGaussian() { currLeaf = new BinGaussian(currMean, currSD, currProbMissing); leafList.push_back(currLeaf); if (currBranch != NULL) { currBranch->setLeaf(currLeaf); }}namespace DT {Vertex::~Vertex(){ list<Branch*>::iterator c; for (c = children.begin(); c != children.end(); c++) { delete (*c); }}void Vertex::getSplits(int var, list<double>& splits){ // Search for splits in each branch for (list<Branch*>::iterator i = children.begin(); i != children.end(); i++) { (*i)->getSplits(var, splits); }}Leaf* Vertex::getLeaf(double* allStates){ Leaf* ret = NULL; for (list<Branch*>::iterator i = children.begin(); i != children.end() && ret == NULL; i++) { ret = (*i)->getLeaf(allStates); } if (ret == NULL) { cout << "ERROR: no leaf found, split = " << split << " (" << allStates[split] << ")" << endl; } return ret;}Branch::~Branch(){ // Delete children in destructor (but not parent!) delete childVertex; delete childLeaf; }bool Branch::inRange(double val){ for (list<Range>::iterator i = values.begin(); i != values.end(); i++) { if (i->inRange(val)) { return true; } } return false;}void Branch::getSplits(int var, list<double>& splits){ // Add values from this branch, if we're splitting on the right var if (parent->getSplitVar() == var) { Range r = *(values.begin()); if (!r.missing && r.minV > -HUGE) { splits.push_back(r.minV); } } if (childVertex != NULL) { childVertex->getSplits(var, splits); }}Leaf* Branch::getLeaf(double* allVals) { if (!inRange(allVals[parent->getSplitVar()])) { return NULL; } else if (childLeaf != NULL) { return childLeaf; } else if (childVertex != NULL) { return childVertex->getLeaf(allVals); } else { cout << "ERROR: missing leaf or branch in DecisionTree!\n"; return NULL; }}};
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -