// 神經網絡bp.txt: BP neural-network source listing (BpNet.h followed by BpNet.cpp).
#if !defined(AFX_BPNET_H__7ACF7725_EE66_11D6_AAF0_00E04F29491B__INCLUDED_)
#define AFX_BPNET_H__7ACF7725_EE66_11D6_AAF0_00E04F29491B__INCLUDED_
#if _MSC_VER > 1000
#pragma once
#endif // _MSC_VER > 1000
// BpNet.h : header file
//
/////////////////////////////////////////////////////////////////////////////
// CBpNet - back-propagation (BP) neural network with a single hidden layer
// (input -> hidden -> output), built on the Matcom matrix library (matlib.h).
// All matrix indices follow Matcom's 1-based convention.
#include<matlib.h>
class CBpNet : public CObject
{
// Construction
public:
	CBpNet();
// Attributes
public:
// Operations
public:
// Overrides
	// ClassWizard generated virtual function overrides
	//{{AFX_VIRTUAL(CBpNet)
	//}}AFX_VIRTUAL
// Implementation
public:
	// MFC serialization hook; called by LoadBpNet() and SaveBpNet().
	void Serialize( CArchive& ar );
	void display(Mm data);
	Mm scope(Mm mData);
	long lEpochs;       // number of training epochs completed by learn()
	double dblMse;      // error tolerance: learn() stops once dblError <= dblMse
	double dblError;    // total error summed over all samples in the last epoch
	// Used to draw initial weights/thresholds as randab(-1.0,1.0);
	// presumably a uniform random number in [a,b] -- TODO confirm.
	double randab(double a,double b);
	void stop();        // presumably requests learn() to stop via m_IsStop -- TODO confirm
	void learn();       // trains until dblError <= dblMse or a stop is requested
	// Writes the network to a file; false for an empty name or if the file cannot be created.
	bool SaveBpNet(CString &strNetName);
	// Restores a network written by SaveBpNet().
	void LoadBpNet(CString &strNetName);
	// Predicts outputs for every row of mData (iInput columns); result is rows x iOutput.
	Mm simulate(Mm mData);
	// Builds a new, randomly initialised network for the given training set and layer sizes.
	void Create(Mm mInputData,Mm mTarget,int iInput,int iHidden,int iOutput);
	virtual ~CBpNet();
// Generated message map functions
protected:
	//{{AFX_MSG(CBpNet)
		// NOTE - the ClassWizard will add and remove member functions here.
	//}}AFX_MSG
	DECLARE_SERIAL(CBpNet)
public:
	bool m_isOK;        // reset to false by Create(); where it becomes true is not shown here
	void LoadPattern(Mm mIn,Mm mOut);
	int iHidden;        // number of hidden-layer neurons
	int iInput;         // number of inputs
	int iOutput;        // number of outputs
protected:
	Mm mInput;          // input data of a single sample (1 x iInput)
	Mm mSampleInput;    // input data of all samples (one sample per row)
	Mm mSampleTarget;   // target data of all samples (one sample per row)
	Mm mHidden;         // computed hidden-layer values (1 x iHidden)
	Mm mOutput;         // computed output values (1 x iOutput)
	Mm mWeighti;        // input -> hidden weights (iInput x iHidden)
	Mm mWeighto;        // hidden -> output weights (iHidden x iOutput)
	Mm mChangei;        // input -> hidden weight changes (iInput x iHidden)
	Mm mChangeo;        // hidden -> output weight changes (iHidden x iOutput)
public:
	Mm mInputNormFactor;  // input normalisation factors, iInput x 2
	Mm mTargetNormFactor; // output normalisation factors, iOutput x 2
protected:
	Mm mThresholdi;     // hidden-layer thresholds (subtracted from the weighted sum)
	Mm mThresholdo;     // output-layer thresholds
	Mm mOutputDeltas;   // output-layer error terms
	Mm mHiddenDeltas;   // hidden-layer error terms
protected:
	bool m_IsStop;      // stop flag checked by learn() once per epoch
	double dblMomentumFactor; // momentum coefficient
	// Learning rates; learn() scales both by 0.7 or 1.05 after each epoch.
	// Which layer each applies to is decided in backward() (not shown here).
	double dblLearnRate1;
	double dblLearnRate2;
	void backward(int iSample);  // back-propagation step for sample row iSample (1-based)
	void forward(int iSample);   // forward pass for sample row iSample (1-based)
	void normalize();            // normalises the sample input and target data
private:
	double dblErr;      // per-sample error that learn() adds into the epoch total
};
/////////////////////////////////////////////////////////////////////////////
//{{AFX_INSERT_LOCATION}}
// Microsoft Visual C++ will insert additional declarations immediately before the previous line.
#endif // !defined(AFX_BPNET_H__7ACF7725_EE66_11D6_AAF0_00E04F29491B__INCLUDED_)
// BpNet.cpp : implementation file
////////////////////////////////////////////////////////////////////
/////////////////人工神經網絡BP算法/////////////////////////////////
//1、動態改變學習速率
//2、加入動量項
//3、運用了Matcom4.5的矩陣運算庫(可免費下載,頭文件matlib.h),
// 方便矩陣運算,當然,也可自己寫矩陣類
//4、可暫停運算
//5、可將網絡以文件的形式保存、恢復
///////////////作者:同濟大學材料學院 張純禹//////////////////////
///////////////email:chunyu_79@hotmail.com//////////////////////////
///////////////QQ:53806186//////////////////////////////////////////
///////////////歡迎不斷改進!歡迎討論其他實用的算法!/////////////////
#include "BpNet.h"
#ifdef _DEBUG
#define new DEBUG_NEW
#undef THIS_FILE
static char THIS_FILE[] = __FILE__;
#endif
/////////////////////////////////////////////////////////////////////////////
// CBpNet
// MFC run-time class and serialization support; pairs with DECLARE_SERIAL in
// the class declaration. Base class is CObject; the last argument is the
// archive schema number (1).
IMPLEMENT_SERIAL( CBpNet, CObject, 1 )
// Construct an empty network and start the Matcom matrix library.
// Scalar members get the same defaults that Create() assigns, so an object
// that is loaded or queried before Create() runs never reads uninitialized
// values. Matrix members stay default-constructed until Create() sizes them.
// Initializers are listed in member declaration order.
CBpNet::CBpNet()
	: lEpochs(0),
	  dblMse(1.0e-6),
	  dblError(1.0),
	  m_isOK(false),
	  iHidden(0),
	  iInput(0),
	  iOutput(0),
	  m_IsStop(false),
	  dblMomentumFactor(0.95),
	  dblLearnRate1(0.5),
	  dblLearnRate2(0.5),
	  dblErr(0.0)
{
	initM(MATCOM_VERSION); // enable the Matcom matrix library
}
// Destroy the network, calling exitM() to match the initM() in the constructor.
// The object's own storage belongs to whoever created it (delete, scope exit,
// or an enclosing object); a destructor must never free it again, so there is
// deliberately no `delete this` here.
// NOTE(review): exitM() runs before the Mm members are destroyed (members are
// destroyed after the destructor body); confirm Matcom tolerates that order.
CBpNet::~CBpNet()
{
	exitM();
}
/////////////////////////////////////////////////////////////////////////////
// CBpNet message handlers
//創建新網絡
// Create a new, untrained network.
//   mInputData : training inputs, one sample per row
//   mTarget    : training targets, one sample per row (same row order)
//   iInput / iHidden / iOutput : number of input, hidden and output neurons
// Weights and thresholds are initialised with randab(-1.0,1.0), and the
// training state (learning rates, momentum, error tolerance, current error,
// epoch counter, flags) is reset.
void CBpNet::Create(Mm mInputData, Mm mTarget, int iInput, int iHidden, int iOutput)
{
	int i,j;
	// Keep the complete training set.
	mSampleInput=mInputData;
	mSampleTarget=mTarget;
	// NOTE(review): nothing checks that mInputData has iInput columns, or that
	// mTarget has iOutput columns and the same number of rows.
	this->iInput=iInput;
	this->iHidden=iHidden;
	this->iOutput=iOutput;
	// Single-sample work matrices used during computation.
	mInput=zeros(1,this->iInput);
	mHidden=zeros(1,this->iHidden);
	mOutput=zeros(1,this->iOutput);
	// Weight matrices (input->hidden, hidden->output) with random initial values.
	mWeighti=zeros(this->iInput,this->iHidden);
	mWeighto=zeros(this->iHidden,this->iOutput);
	for(i=1;i<=this->iInput;i++)
		for(j=1;j<=this->iHidden;j++)
			mWeighti.r(i,j)=randab(-1.0,1.0);
	for(i=1;i<=this->iHidden;i++)
		for(j=1;j<=this->iOutput;j++)
			mWeighto.r(i,j)=randab(-1.0,1.0);
	// Threshold vectors for the hidden and output layers, also random.
	mThresholdi=zeros(1,this->iHidden);
	for(i=1;i<=this->iHidden;i++)
		mThresholdi.r(i)=randab(-1.0,1.0);
	mThresholdo=zeros(1,this->iOutput);
	for(i=1;i<=this->iOutput;i++)
		mThresholdo.r(i)=randab(-1.0,1.0);
	// Weight-change matrices (momentum terms start at zero).
	mChangei=zeros(this->iInput,this->iHidden);
	mChangeo=zeros(this->iHidden,this->iOutput);
	// Normalisation factors start at zero; presumably filled by normalize().
	mInputNormFactor=zeros(iInput,2);
	mTargetNormFactor=zeros(iOutput,2);
	// Error-term matrices.
	// NOTE(review): with MATLAB-style semantics zeros(n) is n x n, not 1 x n;
	// kept as-is because backward() (not shown) may rely on the current shape.
	mOutputDeltas=zeros(iOutput);
	mHiddenDeltas=zeros(iHidden);
	// Training parameters.
	dblLearnRate1=0.5;
	dblLearnRate2=0.5;
	dblMomentumFactor=0.95;
	m_isOK=false;
	m_IsStop=false;
	dblMse=1.0e-6;   // error tolerance
	dblError=1.0;
	lEpochs=0;
}
//根據已有的網絡進行預測
// Run the trained network on new data.
//   mData : one sample per row, iInput columns
// Each input value is normalised with mInputNormFactor and propagated through
// the sigmoid hidden and output layers. Each output is then mapped back with
// mTargetNormFactor.
// Returns a (mData.rows() x iOutput) matrix. If the column count does not
// match iInput, a message box is shown and an empty (default) Mm is returned.
Mm CBpNet::simulate(Mm mData)
{
	int i,j;
	Mm mResult;
	if(mData.cols()!=iInput)
	{
		::MessageBox(NULL,"輸入數據變量個數錯誤!","輸入數據變量個數錯誤!",MB_OK);
		return mResult;
	}
	mResult=zeros(mData.rows(),iOutput);

	// Work buffers holding one sample at a time.
	Mm mInputdata=zeros(1,iInput);
	Mm mHiddendata=zeros(1,iHidden);
	Mm mOutputdata=zeros(1,iOutput);
	double sum=0.0;
	for(int iSample=1;iSample<=mData.rows();iSample++){
		// Input layer: normalise with the training-time factors, which map
		// mInputNormFactor(i,1) to 0 and mInputNormFactor(i,2) to 1.
		// NOTE(review): a column whose two factors are equal divides by zero.
		for(i=1;i<=iInput;i++)
			mInputdata.r(i)=(mData.r(iSample,i)-mInputNormFactor.r(i,1))
				/(mInputNormFactor.r(i,2)-mInputNormFactor.r(i,1));
		// Hidden layer: weighted sum minus threshold, then the logistic sigmoid.
		for(j=1;j<=iHidden;j++){
			sum=0.0;
			for(i=1;i<=iInput;i++)
				sum+=mInputdata.r(i)*mWeighti.r(i,j);
			sum-=mThresholdi.r(j);
			mHiddendata.r(j)=1.0/(1.0+exp(-sum));
		}
		// Output layer: same computation on the hidden-layer values.
		for(j=1;j<=iOutput;j++){
			sum=0.0;
			for(i=1;i<=iHidden;i++)
				sum+=mHiddendata.r(i)*mWeighto.r(i,j);
			sum-=mThresholdo.r(j);
			mOutputdata.r(j)=1.0/(1.0+exp(-sum));
		}
		// Map each output back from the normalised range to target units.
		for(j=1;j<=iOutput;j++)
			mResult.r(iSample,j)=mOutputdata.r(j)
				*(mTargetNormFactor.r(j,2)-mTargetNormFactor.r(j,1))
				+mTargetNormFactor.r(j,1);
	}
	return mResult;
}
// Restore a network previously written by SaveBpNet(), via Serialize().
// If the file cannot be opened, a message box is shown and the object is left
// untouched.
void CBpNet::LoadBpNet(CString &strNetName)
{
	CFile file;
	if(!file.Open(strNetName,CFile::modeRead))
	{
		MessageBox(NULL,"無法打開文件!","錯誤",MB_OK);
		return;
	}
	{
		// The archive is closed and destroyed before the file it wraps.
		CArchive arLoad(&file,CArchive::load);
		Serialize(arLoad);
		arLoad.Close();
	}
	file.Close();
}
// Write the network to strNetName via Serialize().
// Returns false for an empty file name, or (after showing a message box) when
// the file cannot be created. Returns true once the archive has been written
// and both archive and file are closed.
bool CBpNet::SaveBpNet(CString &strNetName)
{
	if(strNetName.IsEmpty())
		return false;
	CFile file;
	if(!file.Open(strNetName,CFile::modeCreate|CFile::modeWrite))
	{
		MessageBox(NULL,"無法創建文件!","錯誤",MB_OK);
		return false;
	}
	{
		// The archive is closed and destroyed before the file it wraps.
		CArchive arStore(&file,CArchive::store);
		Serialize(arStore);
		arStore.Close();
	}
	file.Close();
	return true;
}
//網絡學習
void CBpNet::learn()
{ int iSample=1;
double dblTotal;
MSG msg;
if(m_IsStop)
m_IsStop=false;
//數據正規化處理
normalize();
while(dblError>dblMse&&!m_IsStop){
dblTotal=0.0;
for(iSample=1;iSample<=mSampleInput.rows();iSample++){
forward(iSample);
backward(iSample);
dblTotal+=dblErr;//總誤差
}
if(dblTotal/dblError>1.04){//動態改變學習速率
dblLearnRate1*=0.7;
dblLearnRate2*=0.7;
}
else{
dblLearnRate1*=1.05;
dblLearnRate2*=1.05;
}
lEpochs++;
dblError=dblTotal;
::PeekMessage(&msg,NULL,0,0,PM_REMOVE);
::DispatchMessage(&msg);
msg.message=-1;
::DispatchMessage(&msg);//這樣可以消除屏閃和假死機
}
if(dblError<=dblMse)
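// NOTE: the listing is truncated at this point. The rest of learn() is missing,
// as are the definitions of randab(), stop(), forward(), backward(), normalize(),
// Serialize(), display(), scope() and LoadPattern(), all declared in BpNet.h.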