// markovnet.cpp
// (copied from a web source viewer; "字號:" was its font-size control label)
#include "MarkovNet.h"

// Debug tracing switches (uncomment to enable):
//   DEBUG  - network construction and BP convergence progress
//   DEBUG2 - messages and marginals computed during BP
//   DEBUG3 - every potential passed to addEdge()
//#define DEBUG
//#define DEBUG2
//#define DEBUG3

// Builds a pairwise Markov network from the Bayes net `bn`.
//
// One node is created per Bayes-net variable; then each variable is linked
// to its parents:
//   - no parents:  the node's prior `phi` is filled from its decision tree;
//   - one parent:  one edge (child, parent) carries the potential built from
//                  the child's decision tree;
//   - k > 1:       a "mediator" node whose domain is the joint assignment of
//                  all k parents is added. A 0/1 edge potential ties each
//                  joint value to the matching value of each parent, and the
//                  child is linked to the mediator with its decision-tree
//                  potential.
// At most one mediator is added per variable, which is why `nodes` is
// allocated with 2 * numVars slots.
MarkovNet::MarkovNet(const BayesNet& bn)
    : nodes(new Node[bn.getNumVars() * 2]), numNodes(0)
{
    VarSchema schema = bn.getSchema();

    // Create one node per original variable; fixedValue == -1 marks the
    // node as not fixed (no evidence).
    int numVars = schema.getNumVars();
    for (int i = 0; i < numVars; i++) {
        nodes[i].index = i;
        nodes[i].fixedValue = -1;
        nodes[i].marginal = Distribution(schema.getRange(i));
        nodes[i].phi = Distribution(schema.getRange(i));
        numNodes++;
    }
#ifdef DEBUG
    cout << "Added one node per original variable.\n";
#endif

    // Link nodes with edges and potential functions; add extra nodes
    // as necessary.
    for (int i = 0; i < numVars; i++) {
#ifdef DEBUG
        cout << "Linking node " << i << endl;
#endif
        DecisionTree& dtree = *bn.decisionTrees[i];
        const list<int>& parents = bn.parents[i];

        if (parents.empty()) {
#ifdef DEBUG
            cout << " No parents; done.\n";
#endif
            // Construct a marginal distribution from the decision tree
            nodes[i].phi = Distribution(schema.getRange(i));
            for (int val = 0; val < schema.getRange(i); val++) {
                nodes[i].phi[val] = dtree.getProb(val, NULL);
            }
        } else if (parents.size() == 1) {
#ifdef DEBUG
            cout << " One parent; done.\n";
#endif
            Potential psi(dtree, schema);
            addEdge(&nodes[i], &nodes[parents.front()], psi);
        } else {
            // Multiple parents case
#ifdef DEBUG
            cout << " Multiple parents.\n";
#endif
            // Get total dimension of inputs to decision tree
            int inputDim = 1;
            list<int>::const_iterator p;
            for (p = parents.begin(); p != parents.end(); ++p) {
                inputDim *= schema.getRange(*p);
            }

            // Add mediator node to represent all values of all parents
            Node* mediator = &nodes[numNodes++];
            mediator->index = numNodes - 1;
            mediator->fixedValue = -1;
            mediator->marginal = Distribution(inputDim);
            mediator->phi = Distribution(inputDim);

            // Add edges from mediator node to actual parents, to enforce
            // value consistency. Joint value r encodes the parents in list
            // order with the first parent varying fastest:
            //   value of parent j = (r / prod_{k<j} range_k) % range_j.
            // NOTE(review): Potential(dtree, schema) below must index the
            // parents' joint value in this same order -- confirm.
            int cumulativeDim = 1;
            for (p = parents.begin(); p != parents.end(); ++p) {
                int parentDim = schema.getRange(*p);
                Potential psi(inputDim, parentDim);
                for (int r = 0; r < inputDim; r++) {
                    for (int c = 0; c < parentDim; c++) {
                        psi.set(r, c, 0.0);
                    }
                    int parentVal = (r / cumulativeDim) % parentDim;
                    psi.set(r, parentVal, 1.0);
                }
                cumulativeDim *= parentDim;
                addEdge(mediator, &nodes[*p], psi);
                // DEBUG
                //cout << "psi(" << *p << ") = " << psi << endl;
            }

            // Finally, add an edge from the child to the mediator
            // using its distribution over parents from the Bayes net.
            Potential psi(dtree, schema);
            addEdge(&nodes[i], mediator, psi);
        }
    }
}

// Appends an edge between n1 and n2 carrying a copy of `psi`, and registers
// a pointer to the stored edge in both endpoints' edge lists.
// NOTE(review): the nodes keep raw pointers to elements of `edges`; that is
// only safe if `edges` never relocates its elements on push_back (true for
// std::list / std::deque, NOT for std::vector) -- confirm in MarkovNet.h.
void MarkovNet::addEdge(Node* n1, Node* n2, Potential& psi)
{
    edges.push_back(Edge());
    Edge& e = edges.back();
    e.n1 = n1;
    e.n2 = n2;
    e.psi = psi;
#ifdef DEBUG3
    cout << "Psi = " << psi << endl;
#endif
    n1->edges.push_back(&e);
    n2->edges.push_back(&e);
}

// Disabled (#if 0), kept for reference. It resets nodes and messages, fixes
// the evidence variables, runs BP, and returns the product of the query
// variables' marginals, which equals the joint likelihood only if the query
// variables are independent given the evidence.
#if 0
double MarkovNet::getLikelihood(VarSet query, VarSet evidence,
                                double threshold, double damping)
{
    // DEBUG
    cout << "Resetting all nodes\n";
    // Set evidence
    resetAllNodes();
    // DEBUG
    cout << "Resetting all messages\n";
    resetAllMessages();
    // DEBUG
    cout << "Setting all evidence\n";
    for (int i = 0; i < evidence.getNumVars(); i++) {
        if (evidence.isTested(i)) {
            fixNodeValue(i, (int)evidence[i]);
        }
    }
    // DEBUG
    cout << "Running belief propagation\n";
    // Run belief propagation to convergence
    runBP(threshold, damping);

    // Compute product of marginal distributions
    double prob = 1.0;
    for (int i = 0; i < query.getNumVars(); i++) {
        if (query.isTested(i) && !evidence.isTested(i)) {
            prob *= getMarginal(i).get((int)query[i]);
        }
    }
    return prob;
}
#endif

// Runs loopy belief propagation to convergence.
//
// 1. Every fixed node (fixedValue >= 0) sends its marginal once to each
//    non-fixed neighbor; BPiter() skips fixed nodes afterwards.
// 2. Every non-fixed node whose neighbors are all fixed gets its exact
//    marginal (phi times all incoming messages, normalized) and is then
//    itself marked fixed by setting fixedValue = 0.
//    NOTE(review): this overwrites fixedValue, so such a node is afterwards
//    indistinguishable from evidence "value 0"; the disabled getLikelihood()
//    resets all nodes before each run -- confirm other callers do too.
// 3. BPiter() is repeated until the change it reports drops below
//    `convergenceThreshold`, or 1000 sweeps have run. BPiter() reports the
//    change as a max ratio that is never below 1.0, so a threshold <= 1.0
//    never stops early.
void MarkovNet::runBP(double convergenceThreshold, double dampingFactor)
{
    list<Edge*>::iterator neighbor;

    // For all fixed nodes, we need only send our messages once.
    for (int n = 0; n < numNodes; n++) {
        // Only consider fixed nodes at this point
        if (nodes[n].fixedValue < 0) {
            continue;
        }

        // Send messages to each neighbor
        for (neighbor = nodes[n].edges.begin();
             neighbor != nodes[n].edges.end(); ++neighbor) {
            Edge* currEdge = *neighbor;

            // Don't send messages to other fixed nodes
            if (currEdge->otherNode(&nodes[n])->fixedValue >= 0) {
                continue;
            }

            // Send the messages
            currEdge->sendMsg(&nodes[n], nodes[n].marginal);
        }
    }

    // For nodes with only fixed neighbors, compute marginals exactly once
    for (int n = 0; n < numNodes; n++) {
        // Already fixed
        if (nodes[n].fixedValue >= 0) {
            continue;
        }

        // Assume fixed until proven otherwise
        nodes[n].fixedValue = 0;
        for (neighbor = nodes[n].edges.begin();
             neighbor != nodes[n].edges.end(); ++neighbor) {
            // Neighbor not fixed: we guessed wrong
            if ((*neighbor)->otherNode(&nodes[n])->fixedValue < 0) {
                nodes[n].fixedValue = -1;
                break;
            }
        }

        // Compute marginal from all incoming messages
        if (nodes[n].fixedValue >= 0) {
            nodes[n].marginal = nodes[n].phi;
            for (neighbor = nodes[n].edges.begin();
                 neighbor != nodes[n].edges.end(); ++neighbor) {
                nodes[n].marginal *= (*neighbor)->msgTo(&nodes[n]);
#ifdef DEBUG2
                cout << "Pre-incoming: " << (*neighbor)->msgTo(&nodes[n]) << endl;
#endif
            }
#ifdef DEBUG2
            cout << "New marginal: " << nodes[n].marginal << endl;
#endif
            nodes[n].marginal.normalize();
        }
    }

    const int maxIters = 1000;
    double delta = 0.0;
    int i;
    for (i = 0; i < maxIters; i++) {
        delta = BPiter(dampingFactor);
        if (delta < convergenceThreshold) {
            break;
        }
#ifdef DEBUG
        cout << "Iteration " << i << ": delta = " << delta << endl;
#endif
    }

#ifdef DEBUG
    if (i == maxIters) {
        cout << "Did not converge after " << maxIters << " iterations.\n";
    } else {
        cout << "Successfully converged in " << i << " iterations.\n";
    }
    cout << "Final delta: " << delta << endl;
#endif
}

// Performs one BP sweep over all non-fixed nodes and returns the largest
// change of any node's marginal.
//
// For a node with incoming messages m_0..m_{d-1} (in edge-list order) and
// prior phi:
//   forward[k] = phi * m_0 * ... * m_{k-1}       (k = 0..d)
//   reverse[k] = m_{d-k} * ... * m_{d-1}         (k = 0..d-1)
// where reverse[0] = Distribution(dim) is presumably a neutral (uniform)
// factor -- TODO confirm. The new marginal is normalize(forward[d]), and the
// message sent on edge k is forward[k] * reverse[d-k-1] (forward[d-1] alone
// for the last edge): phi times every incoming message except edge k's.
// This takes O(d) products per node instead of O(d^2).
//
// Messages are not sent to fixed neighbors. The returned change is the
// maximum, over non-fixed nodes and values x, of max(r, 1/r) with
// r = new_marginal[x] / old_marginal[x]; it is never below 1.0. Entries where
// r is NaN (0/0) fail every comparison and therefore do not contribute.
//
// NOTE(review): `dampingFactor` is not used -- messages are sent undamped.
// Implementing damping needs the intended convention (weight of old vs. new
// message) and Distribution arithmetic not visible here; confirm first.
double MarkovNet::BPiter(double dampingFactor)
{
    (void)dampingFactor;  // unused; see NOTE(review) above

    double delta = 1.0;

    for (int i = 0; i < numNodes; i++) {
        // Skip fixed nodes; their messages have already been sent
        if (nodes[i].fixedValue >= 0) {
            continue;
        }

        int numNeighbors = static_cast<int>(nodes[i].edges.size());
        vector<Distribution> forwardDistribs(numNeighbors + 1);
        // reverseDistribs[0] is always written below, so keep at least one
        // slot even for an isolated node (numNeighbors == 0); sizing it with
        // numNeighbors alone would write past the end of an empty vector.
        vector<Distribution> reverseDistribs(numNeighbors > 0 ? numNeighbors : 1);

        // Compute products of first n messages and phi (the prior
        // distribution at this node), for all n
        forwardDistribs[0] = nodes[i].phi;
        int index = 1;
        for (list<Edge*>::iterator neighbor = nodes[i].edges.begin();
             neighbor != nodes[i].edges.end(); ++neighbor, ++index) {
            // Multiply previous distrib by the incoming message
            forwardDistribs[index] = forwardDistribs[index - 1];
            forwardDistribs[index] *= (*neighbor)->msgTo(&nodes[i]);
#ifdef DEBUG2
            cout << "Incoming message: " << (*neighbor)->msgTo(&nodes[i]) << endl;
#endif
        }

        // Compute products of last n messages, for all n
        // (Except for product of all messages, which is redundant.)
        reverseDistribs[0] = Distribution(nodes[i].phi.dim());
        index = 1;
        for (list<Edge*>::reverse_iterator rneighbor = nodes[i].edges.rbegin();
             index < numNeighbors; ++rneighbor, ++index) {
            // Multiply by the incoming message on each edge
            reverseDistribs[index] = reverseDistribs[index - 1];
            reverseDistribs[index] *= (*rneighbor)->msgTo(&nodes[i]);
        }

        // The marginal is just the product of all messages and phi,
        // which we've computed above.
        Distribution marginal = forwardDistribs[numNeighbors];
        marginal.normalize();
#ifdef DEBUG2
        cout << "Marginal: " << marginal << endl;
#endif

        // Compute messages to send to neighbors
        index = 0;
        for (list<Edge*>::iterator neighbor = nodes[i].edges.begin();
             neighbor != nodes[i].edges.end(); ++neighbor, ++index) {
            Edge* currEdge = *neighbor;

            // Don't send messages to fixed nodes
            if (currEdge->otherNode(&nodes[i])->fixedValue >= 0) {
                continue;
            }

            // Compute message: phi times every incoming message except
            // the one arriving on this edge.
            Distribution message = forwardDistribs[index];
            if (index != numNeighbors - 1) {
                message *= reverseDistribs[numNeighbors - index - 1];
            }

            // Distribute message
            currEdge->sendMsg(&nodes[i], message);
#if 0
            // HACK DEBUG -- Comparing messages!
            Distribution outgoing = nodes[i].phi;
            if (nodes[i].fixedValue < 0) {
                for (list<Edge*>::iterator n2 = nodes[i].edges.begin();
                     n2 != nodes[i].edges.end(); n2++) {
                    if (*n2 != currEdge) {
                        outgoing *= ((*n2)->n1 == &nodes[i])
                            ? (*n2)->message2to1 : (*n2)->message1to2;
                    }
                }
            } else {
                outgoing = nodes[i].marginal;
            }
            // DEBUG -- Report inconsistencies between old and new methods
            if (nodes[i].fixedValue < 0) {
                bool printStuff = false;
                for (int a = 0; a < message.dim(); a++) {
                    if (outgoing[a]/message[a] > 1.1 || message[a]/outgoing[a] > 1.1) {
                        printStuff = true;
                    }
                }
                if (printStuff) {
                    for (int a = 0; a < numNeighbors; a++) {
                        cout << "Forward " << a << ": " << forwardDistribs[a] << endl;
                    }
                    for (int a = 0; a < numNeighbors; a++) {
                        cout << "Reverse " << a << ": " << reverseDistribs[a] << endl;
                    }
                    cout << "Old: " << outgoing << endl;
                    cout << "New: " << message << endl;
                    cout << "Marginal: " << marginal << endl;
                    cout << "Index: " << index << endl;
                    cout << "Num neighbors: " << numNeighbors << endl;
                }
            }
            // HACK
            message = outgoing;
#endif
        }

        // Compare to previous marginal at this node
        // and store the new one.
        for (int x_i = 0; x_i < marginal.dim(); x_i++) {
            double delta_x_i = marginal[x_i] / nodes[i].marginal[x_i];
            if (delta_x_i < 1.0) {
                delta_x_i = 1.0 / delta_x_i;
            }
            if (delta < delta_x_i) {
                delta = delta_x_i;
            }
        }
        nodes[i].marginal = marginal;
    }
    return delta;
}
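/*
 * Source-viewer UI residue (keyboard-shortcut help panel), kept as a comment:
 *   快捷鍵說明 (Keyboard shortcuts)
 *   復制代碼 (Copy code): Ctrl + C
 *   搜索代碼 (Search code): Ctrl + F
 *   全屏模式 (Full-screen mode): F11
 *   切換主題 (Toggle theme): Ctrl + Shift + D
 *   顯示快捷鍵 (Show shortcuts): ?
 *   增大字號 (Increase font size): Ctrl + =
 *   減小字號 (Decrease font size): Ctrl + -
 */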