?? node.h
字號:
DSSet *sstate, *sstep; double cost; bool exuptodate;};class RectifiedGaussian : public Variable, public BiParNode{public: RectifiedGaussian(Net *net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE RectifiedGaussian(Net *net, NetLoader *loader);#endif double Cost(); bool GetReal(DSSet &val, DFlags req); /* Returns the actual posterior parameters. */ bool GetMyval(DSSet &val); void GradReal(DSSet &val, const Node *ptr); string GetType() const; void Save(NetSaver *saver); void GetState(DV *state, size_t t); void SetState(DV *state, size_t t);protected: virtual void MyUpdate(); void MyPartialUpdate(IntV *indices); void UpdateExpectations(); /* Parameters of the rectified Gaussian posterior approximation. For debug purposes. */ DSSet myval; /* Expectations (stored to gain speed). Note that the posterior mean- or variance parameter is not the same as the mean or variance because the posterior is approximated with a rectified Gaussian. */ DSSet expectations; double cost;};class RectifiedGaussianV : public Variable, public BiParNode{public: RectifiedGaussianV(Net *net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE RectifiedGaussianV(Net *net, NetLoader *loader);#endif double Cost(); bool GetRealV(DVH &val, DFlags req); bool GetMyvalV(DVH &val); void GradReal(DSSet &val, const Node *ptr); void GradRealV(DVSet &val, const Node *ptr); string GetType() const; void Save(NetSaver *saver);protected: void MyUpdate(); void UpdateExpectations(); DVSet myval; DVSet expectations; double cost;};class GaussRect : public Variable, public BiParNode{public: GaussRect(Net *net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE GaussRect(Net *net, NetLoader *loader);#endif double Cost(); bool GetReal(DSSet &val, DFlags req); bool GetRectReal(DSSet &val, DFlags req); void GradReal(DSSet &val, const Node *ptr); string GetType() const; void Save(NetSaver *saver);protected: void MyUpdate(); void UpdateMoments(); void UpdateExpectations(); void 
ChildGradients(DSSet &norm, DSSet &rect); DSSet posval; DSSet negval; double posweight; double negweight; vector<double> posmoments; vector<double> negmoments; DSSet expts; DSSet rectexpts; double cost;};class GaussRectV : public Variable, public BiParNode{public: friend class GaussRectVState; GaussRectV(Net *net, Label label, Node *m, Node *v);#ifndef BUILDING_SWIG_INTERFACE GaussRectV(Net *net, NetLoader *loader);#endif double Cost(); bool GetRealV(DVH &val, DFlags req); bool GetRectRealV(DVH &val, DFlags req); void GradReal(DSSet &val, const Node *ptr); void GradRealV(DVSet &val, const Node *ptr); string GetType() const; void Save(NetSaver *saver); void GetState(DV *state, size_t t); void SetState(DV *state, size_t t);protected: void MyUpdate(); void MyPartialUpdate(IntV *indices); void UpdateMoments(); void UpdateExpectations(); void ChildGradients(DVSet &norm, DVSet &rect); DVSet posval; DVSet negval; DV posweights; DV negweights; vector<DV> posmoments; vector<DV> negmoments; DVSet expts; DVSet rectexpts; double cost;};// Making the internals of GaussRectV public does not// seem temptating but writing unittests without them// is impossible. 
Hence, GaussRectVState (a friend of GaussRectV)// provides access to the internals of GaussRectV without// cluttering the interface of GaussRectV.class GaussRectVState{public: GaussRectVState(GaussRectV *n); DVSet &GetPosVal(); DVSet &GetNegVal(); DV &GetPosWeights(); DV &GetNegWeights(); DV &GetPosMoment(int i); DV &GetNegMoment(int i);private: GaussRectV *node;};class MoG : public Variable, public NParNode{public: MoG(Net *net, Label label, Node *d);#ifndef BUILDING_SWIG_INTERFACE MoG(Net *net, NetLoader *loader);#endif double Cost(); bool GetReal(DSSet &val, DFlags req); void GradReal(DSSet &val, const Node *ptr); void GradDiscrete(DD &val, const Node *ptr); string GetType() const; void Save(NetSaver *saver); void AddComponent(Node *m, Node *v); size_t NumComponents();protected: void MyUpdate(); vector<DSSet*> myval; vector<Node*> means; vector<Node*> vars;private: bool IsMeanParent(const Node *ptr); bool IsVarParent(const Node *ptr); int WhichMeanParent(const Node *ptr); int WhichVarParent(const Node *ptr); int WhichParent(const Node *ptr, const vector<Node*> &parents); void ComputeExpectations(); DSSet expts; size_t numComponents; double cost;};class MoGV : public Variable, public NParNode{public: MoGV(Net *net, Label label, Node *d);#ifndef BUILDING_SWIG_INTERFACE MoGV(Net *net, NetLoader *loader);#endif double Cost(); bool GetRealV(DVH &val, DFlags req); void GetMyvalV(DVH &val, int k); void GradReal(DSSet &val, const Node *ptr); void GradRealV(DVSet &val, const Node *ptr); void GradDiscreteV(VDD &val, const Node *ptr); string GetType() const; void Save(NetSaver *saver); /* Parents MUST be added with this method. 
*/ void AddComponent(Node *m, Node *v); size_t NumComponents();protected: void MyUpdate(); bool MyClamp(const DV &m); /* Posterior parameters (weights are got from Categorical) */ vector<DVSet*> myval; vector<Node*> means; vector<Node*> vars;private: bool IsMeanParent(const Node *ptr); bool IsVarParent(const Node *ptr); int WhichMeanParent(const Node *ptr); int WhichVarParent(const Node *ptr); int WhichParent(const Node *ptr, const vector<Node*> &parents); /* Updates expts. */ void ComputeExpectations(); /* Expectations calculated from the posterior. */ DVSet expts; /* Number of mixture components. */ size_t numComponents; double cost;};class Dirichlet : public Variable, public NParNode{public: Dirichlet(Net *net, Label label, ConstantV *n);#ifndef BUILDING_SWIG_INTERFACE Dirichlet(Net *net, NetLoader *loader);#endif double Cost(); /* Returns expectations of different components. <log c_i> is in ex field, naturally. */ bool GetRealV(DVH &val, DFlags req); string GetType() const; void Save(NetSaver *saver);protected: void MyUpdate();private: /* Updates expts. */ void ComputeExpectations(); /* Posterior parameters. */ DV myval; /* Expectations calculated from the posterior. */ DVSet expts; /* Number of components. 
*/ size_t numComponents; double cost;};class DiscreteDirichlet : public Variable, public NParNode{public: DiscreteDirichlet(Net *net, Label label, Dirichlet *n);#ifndef BUILDING_SWIG_INTERFACE DiscreteDirichlet(Net *net, NetLoader *loader);#endif double Cost(); bool GetDiscrete(DD *&val); void GradRealV(DVSet &val, const Node *ptr); string GetType() const; void Save(NetSaver *saver);protected: void MyUpdate(); bool MyClamp(const DD &v); DD myval; double cost;};/* A discrete variable with dirichlet prior for its prior weights */class DiscreteDirichletV : public Variable, public NParNode{public: DiscreteDirichletV(Net *net, Label label, Dirichlet *n);#ifndef BUILDING_SWIG_INTERFACE DiscreteDirichletV(Net *net, NetLoader *loader);#endif double Cost(); bool GetDiscreteV(VDDH &val); void GradRealV(DVSet &val, const Node *ptr); string GetType() const; void Save(NetSaver *saver);protected: void MyUpdate(); bool MyClamp(const VDD &v); VDD myval; double cost;};class Rectification : public Function, public UniParNode{public: Rectification(Net *net, Label label, Node *n);#ifndef BUILDING_SWIG_INTERFACE Rectification(Net *net, NetLoader *loader);#endif bool GetReal(DSSet &val, DFlags req); void GradReal(DSSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const;};class RectificationV : public Function, public UniParNode{public: RectificationV(Net *net, Label label, Node *n);#ifndef BUILDING_SWIG_INTERFACE RectificationV(Net *net, NetLoader *loader);#endif bool GetRealV(DVH &val, DFlags req); void GradRealV(DVSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const;};class ProdV : public Function, public BiParNode{public: ProdV(Net *ptr, Label label, Node *n1, Node *n2) : Function(ptr, label, n1, n2), BiParNode(n1, n2) {}#ifndef BUILDING_SWIG_INTERFACE ProdV(Net *ptr, NetLoader *loader);#endif void GradReal(DSSet &val, const Node *ptr); bool GetRealV(DVH &val, DFlags req); void GradRealV(DVSet &val, const Node *ptr); void 
Save(NetSaver *saver); string GetType() const { return "ProdV"; }private: DVSet myval;};class Sum2V : public Function, public BiParNode{public: Sum2V(Net *ptr, Label label, Node *n1, Node *n2) : Function(ptr, label, n1, n2), BiParNode(n1, n2) { persist = 4 | 8; // Sum2V needs at least one child and cuts off if // there is only one parent }#ifndef BUILDING_SWIG_INTERFACE Sum2V(Net *ptr, NetLoader *loader);#endif void GradReal(DSSet &val, const Node *ptr); bool GetRealV(DVH &val, DFlags req); void GradRealV(DVSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "Sum2V"; }private: DVSet myval;};class SumNV : public Function, public NParNode{public: SumNV(Net *net, Label label) : Function(net, label) { persist = 4 | 8; // SumN needs at least one child and cuts off if // there is only one parent keepupdated = false; }#ifndef BUILDING_SWIG_INTERFACE SumNV(Net *net, NetLoader *loader);#endif bool AddParent(Node *n); bool GetRealV(DVH &val, DFlags req); void GradReal(DSSet &val, const Node *ptr); void GradRealV(DVSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "SumNV"; } void Outdate(const Node *ptr); void SetKeepUpdated(const bool _keepupdated);private: void UpdateFromScratch(DFlags req); DVSet myval; vector<DVSet> parentval; bool keepupdated;};class DelayV : public Function, public BiParNode{public: DelayV(Net *ptr, Label label, Node *n1, Node *n2) : Function(ptr, label, n1, n2), BiParNode(n1, n2) { lendelay = 1; }#ifndef BUILDING_SWIG_INTERFACE DelayV(Net *ptr, NetLoader *loader);#endif void GradReal(DSSet &val, const Node *ptr); bool GetRealV(DVH &val, DFlags req); void GradRealV(DVSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "DelayV"; } int GetDelayLength(); void SetDelayLength(int len);private: DVSet myval; int lendelay;};class GaussianV : public Variable, public BiParNode{public: GaussianV(Net *_net, Label label, Node *m, Node *v);#ifndef 
BUILDING_SWIG_INTERFACE GaussianV(Net *_net, NetLoader *loader);#endif ~GaussianV() { if (sstate) delete sstate; if (sstep) delete sstep; } double Cost(); void GradReal(DSSet &val, const Node *ptr); bool GetRealV(DVH &val, DFlags req); void GradRealV(DVSet &val, const Node *ptr); void Save(NetSaver *saver); string GetType() const { return "GaussianV"; } void GetState(DV *state, size_t t); void SetState(DV *state, size_t t);protected: bool MyClamp(double m); bool MyClamp(double m, double v); bool MyClamp(const DV &m);
?? 快捷鍵說明
複製代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -