node.h
字號:
bool MyClamp(const DV &m, const DV &v); void MyUpdate(); void MyPartialUpdate(IntV *indices); bool MySaveState(); bool MySaveStep(); bool MySaveRepeatedState(double alpha); void MyRepeatStep(double alpha); bool MyClearStateAndStep(); DVSet myval; double cost; bool exuptodate;private: DVSet *sstate, *sstep;};class SparseGaussV : public GaussianV{public: SparseGaussV(Net *_net, Label label, Node *m, Node *v) : GaussianV(_net, label, m, v) {}#ifndef BUILDING_SWIG_INTERFACE SparseGaussV(Net *_net, NetLoader *loader);#endif double Cost(); void Update(); void GradReal(DSSet &val, const Node *ptr); void GradRealV(DVSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "SparseGaussV"; } void SparseClampDV(const DV &mean, const IntV &mis); IntV& GetMissing() { return missing; } void SetMissing(IntV &mis);private: IntV missing;};class DelayGaussV : public Variable, public NParNode{public: DelayGaussV(Net *_net, Label label, Node *m, Node *v, Node *a, Node *m0, Node *v0);#ifndef BUILDING_SWIG_INTERFACE DelayGaussV(Net *_net, NetLoader *loader);#endif double Cost(); void GradReal(DSSet &val, const Node *ptr); bool GetRealV(DVH &val, DFlags req); void GradRealV(DVSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "DelayGaussV"; }protected: bool MyClamp(double m) { fill(myval.mean.begin(), myval.mean.end(), m); fill(myval.var.begin(), myval.var.end(), 0); exuptodate = false; return true; } bool MyClamp(const DV &m) { if (m.size() == myval.mean.size()) copy(m.begin(), m.end(), myval.mean.begin()); else { ostringstream msg; msg << "DelayGaussV::MyClamp: wrong vector size " << m.size() << " != " << myval.mean.size(); throw TypeException(msg.str()); } fill(myval.var.begin(), myval.var.end(), 0); return true; } bool MyClamp(const DV &m, const DV &v) { if (m.size() == myval.mean.size() && v.size() == myval.var.size()) { copy(m.begin(), m.end(), myval.mean.begin()); copy(v.begin(), v.end(), myval.var.begin()); } 
else { ostringstream msg; msg << "DelayGaussV::MyClamp: wrong vector size " << m.size() << " != " << myval.mean.size(); throw TypeException(msg.str()); } return true; } void MyUpdate(); bool MySaveState(); bool MySaveStep(); bool MySaveRepeatedState(double alpha); void MyRepeatStep(double alpha); bool MyClearStateAndStep();private: DVSet myval; DVSet *sstate, *sstep; double cost; bool exuptodate;};class GaussNonlin : public Variable, public BiParNode// nonlinearity: myval2 = exp(-myval1*myval1){public: GaussNonlin(Net *_net, Label label, Node *m, Node *v) : Variable(_net, label, m, v), BiParNode(m, v) { sstate = 0; sstep = 0; cost = 0; CheckParent(0, REAL_MV); CheckParent(1, REAL_ME); MyUpdate(); }#ifndef BUILDING_SWIG_INTERFACE GaussNonlin(Net *_net, NetLoader *loader);#endif ~GaussNonlin() { if (sstate) delete sstate; if (sstep) delete sstep; } double Cost(); bool GetReal(DSSet &val, DFlags req); void GradReal(DSSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "GaussNonlin"; }protected: bool MyClamp(double m) { myval1.mean = m; myval1.var = 0; meanuptodate = false; varuptodate = false; return true; } void MyUpdate(); void UpdateMean(); void UpdateVar(); bool MySaveState(); bool MySaveStep(); bool MySaveRepeatedState(double alpha); void MyRepeatStep(double alpha); bool MyClearStateAndStep();private: DSSet myval1, myval2; // 1 before nonlinearity, 2 after DSSet *sstate, *sstep; double cost; bool meanuptodate, varuptodate; // refer to myval2};class GaussNonlinV : public Variable, public BiParNode// nonlinearity: myval2 = exp(-myval1*myval1){public: GaussNonlinV(Net *_net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE GaussNonlinV(Net *_net, NetLoader *loader);#endif ~GaussNonlinV() { if (sstate) delete sstate; if (sstep) delete sstep; } double Cost(); void GradReal(DSSet &val, const Node *ptr); bool GetRealV(DVH &val, DFlags req); void GradRealV(DVSet &val, const Node *ptr); void Save(NetSaver *saver); string 
GetType() const { return "GaussNonlinV"; }protected: bool MyClamp(double m) { fill(myval1.mean.begin(), myval1.mean.end(), m); fill(myval1.var.begin(), myval1.var.end(), 0); meanuptodate = false; varuptodate = false; return true; } bool MyClamp(const DV &m) { if (m.size() == myval1.mean.size()) copy(m.begin(), m.end(), myval1.mean.begin()); else { ostringstream msg; msg << "GaussianV::MyClamp: wrong vector size " << m.size() << " != " << myval1.mean.size(); throw TypeException(msg.str()); } fill(myval1.var.begin(), myval1.var.end(), 0); meanuptodate = false; varuptodate = false; return true; } void MyUpdate(); void UpdateMean(); void UpdateVar(); bool MySaveState(); bool MySaveStep(); bool MySaveRepeatedState(double alpha); void MyRepeatStep(double alpha); bool MyClearStateAndStep();private: DVSet myval1, myval2; // 1 before nonlinearity, 2 after DVSet *sstate, *sstep; double cost; bool meanuptodate, varuptodate; // refer to myval2};class Discrete : public Variable, public NParNode{public: Discrete(Net *_net, Label label, Node *n=0) : Variable(_net, label, n), NParNode(n) { cost = 0; exsum = 0; if (n) { CheckParent(0, REAL_ME); exuptodate = false; MyUpdate(); } }#ifndef BUILDING_SWIG_INTERFACE Discrete(Net *_net, NetLoader *loader);#endif bool AddParent(Node *n) { Node::AddParent(n); CheckParent(NumParents()-1, REAL_ME); MyUpdate(); return true; } double Cost(); void GradReal(DSSet &val, const Node *ptr);#ifdef BUILDING_SWIG_INTERFACE BOOLASOBJ GetDiscrete(DD *&val);#else bool GetDiscrete(DD *&val);#endif void Save(NetSaver *saver); string GetType() const { return "Discrete"; }protected: bool MyClamp(double m) { return false; } bool MyClamp(const DD &m) { myval = m; return true; } bool MyClamp(int n) { if (n >= (int)NumParents()) { throw TypeException("Too large value for clamping a Discrete"); } myval.Resize(NumParents()); for (size_t j=NumParents(); j>0; j--) { myval[j-1] = 0; } myval[n] = 1; return true; } void MyUpdate(); void UpdateExpSum();private: DD myval; 
double cost, exsum; bool exuptodate;};class DiscreteV : public Variable, public NParNode{public: DiscreteV(Net *_net, Label label, Node *n=0);#ifndef BUILDING_SWIG_INTERFACE DiscreteV(Net *_net, NetLoader *loader);#endif bool AddParent(Node *n) { DVH tmp; if (! n->GetRealV(tmp, DFlags(true, false, true))) { ostringstream msg; msg << "Wrong type of parents in " << GetType() << " Node " << label << std::endl; msg << " Parent " << n->GetLabel() << ":" << n->GetType(); throw StructureException(msg.str()); } Node::AddParent(n); MyUpdate(); return true; } double Cost(); void GradReal(DSSet &val, const Node *ptr); void GradRealV(DVSet &val, const Node *ptr); bool GetDiscreteV(VDDH &val); void Save(NetSaver *saver); string GetType() const { return "DiscreteV"; }protected: bool MyClamp(double m) { return false; } bool MyClamp(const VDD &m) { myval = m; return true; } void MyUpdate(); void UpdateExpSum();private: VDD myval; double cost; DV exsum; bool exuptodate;};class Memory : public Variable, public UniParNode{public: Memory(Net *_net, Label label, Node *n) : Variable(_net, label, n), UniParNode(n) { if (n->TimeType()) { ostringstream msg; msg << GetIdent() << ": parent must be independent of time"; throw StructureException(msg.str()); } timetype = 2; oldcost = 0; cost = 0; }#ifndef BUILDING_SWIG_INTERFACE Memory(Net * net, NetLoader *loader);#endif void NotifyTimeType(int tt, int verbose=0) { if (GetParent(0)->TimeType()) { ostringstream msg; msg << GetIdent() << ": parent must be independent of time"; throw StructureException(msg.str()); } } double Cost(); void MyUpdate(); bool GetReal(DSSet &val, DFlags req) {return ParReal(0, val, req);} void GradReal(DSSet &grad, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "Memory"; } void Outdate(const Node *ptr) { costuptodate = false; OutdateChild(); } DSSet oldval; double oldcost; double cost;};class OnLineDelay : public Node{public: virtual void Save(NetSaver *saver); virtual void StepTime() = 
0; virtual void ResetTime() = 0;protected: OnLineDelay(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0);#ifndef BUILDING_SWIG_INTERFACE OnLineDelay(Net *ptr, NetLoader *loader);#endif};class OLDelayS : public OnLineDelay, public BiParNode{public: OLDelayS(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0) : OnLineDelay(ptr, label, n1, n2), BiParNode(n1, n2) { CheckParent(0, REAL_M); CheckParent(1, REAL_M); DSSet tmp; ParReal(0, tmp, DFlags(true)); oldmean = tmp.mean; exuptodate = false; }#ifndef BUILDING_SWIG_INTERFACE OLDelayS(Net *ptr, NetLoader *loader);#endif virtual void Save(NetSaver *saver); virtual void StepTime(); virtual void ResetTime(); virtual bool GetReal(DSSet &val, DFlags req); string GetType() const { return "OLDelayS"; }private: double oldmean; double oldexp; bool exuptodate;};class OLDelayD : public OnLineDelay, public BiParNode{public: OLDelayD(Net *ptr, Label label, Node *n1 = 0, Node *n2 = 0) : OnLineDelay(ptr, label, n1, n2), BiParNode(n1, n2) { CheckParent(0, DISCRETE); CheckParent(1, DISCRETE); DD *tmp; ParDiscrete(0, tmp); oldval = *tmp; }#ifndef BUILDING_SWIG_INTERFACE OLDelayD(Net *ptr, NetLoader *loader);#endif virtual void Save(NetSaver *saver); virtual void StepTime(); virtual void ResetTime();#ifdef BUILDING_SWIG_INTERFACE virtual BOOLASOBJ GetDiscrete(DD *&val);#else virtual bool GetDiscrete(DD *&val);#endif string GetType() const { return "OLDelayD"; }private: DD oldval;};class Proxy : public Node, public UniParNode{public: Proxy(Net *ptr, Label label, Label rlabel);#ifndef BUILDING_SWIG_INTERFACE Proxy(Net *ptr, NetLoader *loader);#endif void Save(NetSaver *saver); string GetType() const { return "Proxy"; } bool GetReal(DSSet &val, DFlags req); bool GetRealV(DVH &val, DFlags req);#ifdef BUILDING_SWIG_INTERFACE BOOLASOBJ GetDiscrete(DD *&val);#else bool GetDiscrete(DD *&val);#endif bool GetDiscreteV(VDDH &val); void GradReal(DSSet &val, const Node *ptr) { ChildGradReal(val); } void GradRealV(DVSet &val, const Node *ptr) { 
ChildGradRealV(val); } void GradDiscrete(DD &val, const Node *ptr) { ChildGradDiscrete(val); } void GradDiscreteV(VDD &val, const Node *ptr) { ChildGradDiscreteV(val); } bool CheckRef();private: string reflabel; bool req_discrete, req_discretev; DFlags real_flags, realv_flags;};class Evidence : public Variable, public Decayer, public UniParNode{public: Evidence(Net *ptr, Label label, Node *p) : Variable(ptr, label, p), Decayer(ptr), UniParNode(p) { alpha = 1e-10; decay = 0; myval = 0; cost = 0; }#ifndef BUILDING_SWIG_INTERFACE Evidence(Net *ptr, NetLoader *loader);#endif void Save(NetSaver *saver); string GetType() const { return "Evidence"; } void GradReal(DSSet &val, const Node *ptr); double Cost(); void SetDecayTime(double iters) { decay = alpha / iters; } virtual bool DoDecay(string hook);private: void MyUpdate() {} bool MyClamp(double mean, double var); double cost; double myval; double alpha; double decay;};class EvidenceV : public Variable, public Decayer, public UniParNode{public: EvidenceV(Net *ptr, Label label, Node *p);#ifndef BUILDING_SWIG_INTERFACE EvidenceV(Net *ptr, NetLoader *loader);#endif void Save(NetSaver *saver); string GetType() const { return "EvidenceV"; } void GradRealV(DVSet &val, const Node *ptr); double Cost(); void SetDecayTime(const DV &iters); virtual bool DoDecay(string hook);private: void MyUpdate() {} bool MyClamp(double mean, double var); bool MyClamp(const DV &mean, const DV &var); double cost; DV myval; DV alpha; DV decay;};#endif // NODE_H
快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -