?? markovnet.h
字號:
#ifndef MARKOVNET_H#define MARKOVNET_H/* Definition for the MarkovNet class, representing a pairwise Markov * random field. Includes an implementation of belief propagation. * * Daniel Lowd <lowd@cs.washington.edu> */#include "Potential.h"#include "Distribution.h"#include "VarSet.h"#include "BayesNet.h"#include <list>class Edge;class Node{public: int index; list<Edge*> edges; int fixedValue; Distribution marginal; Distribution phi; Node() { /* NOP */ }};class Edge{public: Node* n1; Node* n2; Distribution message1to2; Distribution message2to1; Potential psi; Edge() { /* NOP */ } Distribution& msgTo(Node* n) { return (n == n1) ? message2to1 : message1to2; } Distribution& msgFrom(Node* n) { return (n == n2) ? message2to1 : message1to2; } Node* otherNode(Node* n) { return (n == n1) ? n2 : n1; } void sendMsg(Node* n, Distribution& message) { if (n == n1) { message1to2 = message * psi; message1to2.normalize(); } else { message2to1 = psi * message; message2to1.normalize(); } }};class MarkovNet{ // HACK: maybe this ought to be a vector? Node* nodes; int numNodes; list<Edge> edges;public: // Convert a Bayesian network to a pairwise Markov random field MarkovNet(const BayesNet& bn); ~MarkovNet() { delete [] nodes; } void fixNodeValue(int nodeIndex, int value) { nodes[nodeIndex].fixedValue = value; for (int i = 0; i < nodes[nodeIndex].marginal.dim(); i++) { nodes[nodeIndex].marginal[i] = ((i == value) ? 1 : 0); } } void resetNode(int nodeIndex) { nodes[nodeIndex].fixedValue = -1; } void resetAllNodes() { for (int i = 0; i < numNodes; i++) { resetNode(i); } } void resetAllMessages() { for (int i = 0; i < numNodes; i++) { // Reset all messages in one direction to uniform list<Edge*>::iterator e; for (e = nodes[i].edges.begin(); e != nodes[i].edges.end(); e++) { (*e)->message1to2 = Distribution((*e)->n2->marginal.dim()); (*e)->message2to1 = Distribution((*e)->n2->marginal.dim()); } } } Distribution getMarginal(int nodeIndex) const { return nodes[nodeIndex].marginal; } double getLikelihood(VarSet query, VarSet evidence, double threshold, double damping); void runBP(double convergenceThreshold, double dampingFactor);private: double BPiter(double dampingFactor); void addEdge(Node* n1, Node* n2, Potential& psi);};#endif // ndef MARKOVNET_H
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -