?? node.h
字號:
// -*- C++ -*-//// This file is a part of the Bayes Blocks library//// Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander// Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.//// This program is free software; you can redistribute it and/or modify// it under the terms of the GNU General Public License as published by// the Free Software Foundation; either version 2, or (at your option)// any later version.//// This program is distributed in the hope that it will be useful,// but WITHOUT ANY WARRANTY; without even the implied warranty of// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the// GNU General Public License (included in file License.txt in the// program package) for more details.//// $Id: Node.h 7 2006-10-26 10:26:41Z ah $#ifndef NODE_H#define NODE_H#include <string>#include <map>#include "Templates.h"#include "Saver.h"#include "Loader.h"#include "Decay.h"#include "Net.h"class Node;#ifndef BUILDING_SWIG_INTERFACEtypedef bool BOOLASOBJ;#endifenum partype_e { REAL_MV, REAL_ME, REAL_M, REALV_MV, REALV_ME, REALV_M, DISCRETE, DISCRETEV};class NodeBase{public: virtual ~NodeBase() { } virtual int ParIdentity(const Node *ptr) = 0; virtual size_t NumParents() = 0; virtual Node *GetParent(size_t i) = 0; virtual int RemoveParent(const Node *ptr) = 0;protected: virtual void ReallyAddParent(Node *ptr) = 0; virtual bool ParReplacePtr(const Node *oldptr, Node *newptr) = 0;};class Node : public virtual NodeBase{public: friend Net::Net(NetLoader *loader); virtual ~Node() { } void NotifyDeath(Node *ptr, int verbose = 0); virtual void NotifyTimeType(int tt, int verbose = 0); void ReplacePtr(Node *oldptr, Node *newptr); void AddChild(Node *ptr) {children.push_back(ptr);}protected: Node(Net *ptr, Label label);#ifndef BUILDING_SWIG_INTERFACE Node(Net *ptr, NetLoader *loader, bool isproxy = 0);#endif void AddParent(Node *ptr, bool really=true);public: virtual bool GetReal(DSSet &val, DFlags req) { return false; } virtual void GradReal(DSSet &val, const 
Node *ptr) {} virtual bool GetRealV(DVH &val, DFlags req) { val.vec = 0; return GetReal(val.scalar, req); } virtual void GradRealV(DVSet &val, const Node *ptr) {}#ifdef BUILDING_SWIG_INTERFACE virtual BOOLASOBJ GetDiscrete(DD *&val) { return false; }#else virtual bool GetDiscrete(DD *&val) { return false; }#endif virtual void GradDiscrete(DD &val, const Node *ptr) {} virtual bool GetDiscreteV(VDDH &val) { val.vec = 0; return GetDiscrete(val.scalar); } virtual void GradDiscreteV(VDD &val, const Node *ptr) {} virtual void Outdate(const Node *ptr) { OutdateChild(); } void CheckParent(size_t parnum, partype_e partype); bool ParReal(int i, DSSet &val, const DFlags req) { return GetParent(i)->GetReal(val, req);} bool ParRealV(int i, DVH &val, const DFlags req) { return GetParent(i)->GetRealV(val, req);} bool ParDiscrete(int i, DD *&val) { return GetParent(i)->GetDiscrete(val);} bool ParDiscreteV(int i, VDDH &val) { return GetParent(i)->GetDiscreteV(val);} void ChildGradReal(DSSet &val); void ChildGradRealV(DVSet &val); void ChildGradDiscrete(DD &val); void ChildGradDiscreteV(VDD &val); Label GetLabel() const { return label; } string GetIdent() const { return GetType() + " node " + GetLabel(); } Net *GetNet() const { return net; } virtual string GetType() const = 0; int TimeType() { return timetype; } int GetDying() { return dying; } void Die(int verbose = 0); void OutdateChild(); virtual void Save(NetSaver *saver); size_t NumChildren() { return children.size(); } Node *GetChild(size_t i) {return i < children.size() ? 
children[i] : 0;} int GetPersist() { return persist; } void SetPersist(int p) { persist = p; }protected: vector<Node *> children; Net *net; Label label; int persist, timetype; bool dying;};class NullParNode : public virtual NodeBase{public: virtual int ParIdentity(const Node *ptr) {return -1;} virtual size_t NumParents() { return 0; } virtual Node *GetParent(size_t i) {return 0;} virtual int RemoveParent(const Node *ptr) {return 0;}protected: virtual void ReallyAddParent(Node *ptr) {return;} virtual bool ParReplacePtr(const Node *oldptr, Node *newptr) {return false;}};class UniParNode : public virtual NodeBase{private: Node *parent;public: UniParNode(Node *p) : parent(p) {} virtual int ParIdentity(const Node *ptr) { return ptr==parent ? 0 : -1;} virtual size_t NumParents() { return parent!=0; } virtual Node *GetParent(size_t i) {return i==0 ? parent : 0;} virtual int RemoveParent(const Node *ptr);protected: virtual void ReallyAddParent(Node *ptr) { parent = ptr; } virtual bool ParReplacePtr(const Node *oldptr, Node *newptr) { return (parent==oldptr) ? (parent=newptr) : false; }};class BiParNode : public virtual NodeBase{private: Node *parents[2];public: BiParNode(Node *p1, Node *p2) { parents[0]=p1; parents[1]=p2; } virtual int ParIdentity(const Node *ptr); virtual size_t NumParents() { return parents[0] == 0 ? 0 : (parents[1] == 0 ? 1 : 2); } virtual Node *GetParent(size_t i) {return i < 2 ? 
parents[i] : 0;} virtual int RemoveParent(const Node *ptr);protected: virtual void ReallyAddParent(Node *ptr); virtual bool ParReplacePtr(const Node *oldptr, Node *newptr);};class NParNode : public virtual NodeBase{private: vector<Node *> parents; map<const Node *, int> parent_inds;public: NParNode(Node *p1=0, Node *p2=0, Node *p3=0, Node *p4=0, Node *p5=0) { if (p1) parents.push_back(p1); if (p2) parents.push_back(p2); if (p3) parents.push_back(p3); if (p4) parents.push_back(p4); if (p5) parents.push_back(p5); } virtual int ParIdentity(const Node *ptr); virtual size_t NumParents() { return parents.size(); } virtual Node *GetParent(size_t i) {return i < parents.size() ? parents[i] : 0;} virtual int RemoveParent(const Node *ptr);protected: virtual void ReallyAddParent(Node *ptr) { parents.push_back(ptr); } virtual bool ParReplacePtr(const Node *oldptr, Node *newptr);};class Constant : public Node, public NullParNode{public: Constant(Net *net, Label label, double v) : Node(net, label) {cval = v;}#ifndef BUILDING_SWIG_INTERFACE Constant(Net *net, NetLoader *loader);#endif void NotifyTimeType(int tt, int verbose=0) {} bool GetReal(DSSet &val, DFlags req) { if (req.mean) {val.mean = cval; req.mean = false;} if (req.var) {val.var = 0; req.var = false;} if (req.ex) {val.ex = exp(cval); req.ex = false;} return req.AllFalse(); } void GradReal(DSSet &val, const Node *ptr) {} void Save(NetSaver *saver); string GetType() const { return "Constant"; }private: double cval;};class ConstantV : public Node, public NullParNode{public: ConstantV(Net *net, Label label, DV v);#ifndef BUILDING_SWIG_INTERFACE ConstantV(Net *net, NetLoader *loader);#endif void NotifyTimeType(int tt, int verbose=0) {} bool GetRealV(DVH &val, DFlags req) { val.vec = &myval; req.mean = false; req.var = false; req.ex = false; return req.AllFalse(); } void Save(NetSaver *saver); string GetType() const { return "ConstantV"; }private: DVSet myval;};class Function : public Node{public: void Outdate(const Node 
*ptr) { uptodate = DFlags(false,false,false); OutdateChild(); } virtual void Save(NetSaver *saver);protected: Function(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0);#ifndef BUILDING_SWIG_INTERFACE Function(Net *ptr, NetLoader *loader);#endif DFlags uptodate;};class Prod : public Function, public BiParNode{public: Prod(Net *ptr, Label label, Node *n1, Node *n2) : Function(ptr, label, n1, n2), BiParNode(n1, n2) {mean = 0.0; var = 0.0;}#ifndef BUILDING_SWIG_INTERFACE Prod(Net *ptr, NetLoader *loader);#endif bool GetReal(DSSet &val, DFlags req); void GradReal(DSSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "Prod"; }private: double mean, var;};class Sum2 : public Function, public BiParNode{public: Sum2(Net *ptr, Label label, Node *n1, Node *n2) : Function(ptr, label, n1, n2), BiParNode(n1, n2) { persist = 4 | 8; // Sum2 needs at least one child and cuts off if // there is only one parent }#ifndef BUILDING_SWIG_INTERFACE Sum2(Net *ptr, NetLoader *loader);#endif bool GetReal(DSSet &val, DFlags req); void GradReal(DSSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "Sum2"; }private: DSSet myval;};class SumN : public Function, public NParNode{public: SumN(Net *net, Label label) : Function(net, label) { persist = 4 | 8; // SumN needs at least one child and cuts off if // there is only one parent keepupdated = false; }#ifndef BUILDING_SWIG_INTERFACE SumN(Net *net, NetLoader *loader);#endif bool AddParent(Node *n); bool GetReal(DSSet &val, DFlags req); void GradReal(DSSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "SumN"; } void Outdate(const Node *ptr); void SetKeepUpdated(const bool _keepupdated);private: DSSet myval; vector<DSSet> parentval; bool keepupdated;};class Relay : public Function, public UniParNode{public: Relay(Net *ptr, Label label, Node *n) : Function(ptr, label, n), UniParNode(n) {}#ifndef BUILDING_SWIG_INTERFACE Relay(Net *ptr, 
NetLoader *loader);#endif bool GetReal(DSSet &val, DFlags req) {return ParReal(0, val, req);} void GradReal(DSSet &val, const Node *ptr) {ChildGradReal(val);} void Save(NetSaver *saver); string GetType() const { return "Relay"; }};class Variable : public Node{public: virtual double Cost() = 0; virtual void Update() { if (!clamped) { MyUpdate(); OutdateChild(); } } virtual void PartialUpdate(IntV *indices) { if (!clamped) { MyPartialUpdate(indices); OutdateChild(); } } void Clamp(double val) { if (!MyClamp(val)) { ostringstream msg; msg << GetIdent() << ": Double clamp not allowed"; throw TypeException(msg.str()); } clamped = true; costuptodate = false; OutdateChild(); } void Clamp(double mean, double var) { if (!MyClamp(mean, var)) { ostringstream msg; msg << GetIdent() << ": Double double clamp not allowed"; throw TypeException(msg.str()); } clamped = true; costuptodate = false; OutdateChild(); } void Clamp(const DV &val) { if (!MyClamp(val)) { ostringstream msg; msg << GetIdent() << ": DV clamp not allowed"; throw TypeException(msg.str()); } clamped = true; costuptodate = false; OutdateChild(); } void Clamp(const DV &mean, const DV &var) { if (!MyClamp(mean, var)) { ostringstream msg; msg << GetIdent() << ": Double DV clamp not allowed"; throw TypeException(msg.str()); } clamped = true; costuptodate = false; OutdateChild(); } void Clamp(const DD &val) { if (!MyClamp(val)) { ostringstream msg; msg << GetIdent() << ": DD clamp not allowed"; throw TypeException(msg.str()); } clamped = true; costuptodate = false; OutdateChild(); } void Clamp(int val) { if (!MyClamp(val)) { ostringstream msg; msg << GetIdent() << ": Int clamp not allowed"; throw TypeException(msg.str()); } clamped = true; costuptodate = false; OutdateChild(); } void Clamp(const VDD &val) { if (!MyClamp(val)) { ostringstream msg; msg << GetIdent() << ": VDD clamp not allowed"; throw TypeException(msg.str()); } clamped = true; costuptodate = false; OutdateChild(); } void Unclamp() {if (clamped) {clamped 
= false; MyUpdate(); OutdateChild();}} void SaveState(); void SaveStep(); void RepeatStep(double alpha); void SaveRepeatedState(double alpha); void ClearStateAndStep(); virtual void Outdate(const Node *ptr) {costuptodate = false;} virtual void Save(NetSaver *saver); int GetHookeFlags() { return hookeflags; } void SetHookeFlags(int h) { hookeflags = h; } bool IsClamped() { return clamped; } // These two methods are ment for copying things from one network // to another similar one // The allocation of the DV instance is left to user // so it can be done in the jurisdiction of Python's GC. // The DV is resized, so initially it can be of size zero, for example. virtual void GetState(DV *state, size_t t = 0) { ostringstream msg; msg << "GetState not supported by " << GetType(); throw TypeException(msg.str()); } virtual void SetState(DV *state, size_t t = 0) { ostringstream msg; msg << "SetState not supported by " << GetType(); throw TypeException(msg.str()); }protected: Variable(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0);#ifndef BUILDING_SWIG_INTERFACE Variable(Net *ptr, NetLoader *loader);#endif virtual bool MyClamp(double val) {return false;} virtual bool MyClamp(double mean, double var) {return false;} virtual bool MyClamp(const DV &val) {return false;} virtual bool MyClamp(const DV &mean, const DV &var) {return false;} virtual bool MyClamp(const DD &val) {return false;} virtual bool MyClamp(int val) {return false;} virtual bool MyClamp(const VDD &val) {return false;} virtual void MyUpdate() = 0; virtual bool MySaveState() {return false;} virtual bool MySaveStep() {return false;} virtual bool MySaveRepeatedState(double alpha) {return false;} virtual void MyRepeatStep(double alpha) {} virtual bool MyClearStateAndStep() {return false; } virtual void MyPartialUpdate(IntV *indices) { ostringstream msg; msg << "Partial updates not supported by " << GetType(); throw StructureException(msg.str()); } bool clamped, costuptodate; int hookeflags;};class Gaussian : 
public Variable, public BiParNode{public: Gaussian(Net *net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE Gaussian(Net *net, NetLoader *loader);#endif ~Gaussian() { if (sstate) delete sstate; if (sstep) delete sstep; } double Cost(); bool GetReal(DSSet &val, DFlags req); void GradReal(DSSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "Gaussian"; } void GetState(DV *state, size_t t); void SetState(DV *state, size_t t);protected: virtual bool MyClamp(double m); virtual bool MyClamp(double m, double v); virtual void MyUpdate(); bool MySaveState(); bool MySaveStep(); bool MySaveRepeatedState(double alpha); void MyRepeatStep(double alpha); bool MyClearStateAndStep(); void MyPartialUpdate(IntV *indices); DSSet myval;
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -