gibbssampler.cpp
#include "GibbsSampler.h"#include "VarConfig.h"#include <algorithm>#include "Prob.h"#define CHECKMEM(x) \{ \ cout << "Pre-" << x << endl; \ int* foo = new int[1000]; \ cout << "Post-" << x << endl; \}#define SQR(x) ((x) * (x))#if 0// Unused -- we use the VarConfig class instead.int GibbsSampler::configIndex(VarSet& allVars, list<int>& testIndices){ int rangeProduct = 1; int ret = 0; list<int>::iterator i; for (i = testIndices.begin(); i != testIndices.end(); i++) { // Test vars must be discrete. ret += rangeProduct * (int)allVars[*i]; rangeProduct *= model.getRange(*i); } return ret;}#endif// Old way of burning in -- run a constant number of iterations.void GibbsSampler::burnInChain(VarSet &chain, const VarSet& evidence, int burnInIters) const{ for (int iter = 0; iter < burnInIters; iter++) { for (int v = 0; v < evidence.getNumVars(); v++) { if (!evidence.isTested(v)) { chain[v] = model.MBsample(v, chain); } } }}unsigned int GibbsSampler::sampleFromDist(const vector<double>& dist) const{ double p = (double)rand()/RAND_MAX; for (unsigned int v = 0; v < dist.size(); v++) { p -= dist[v]; if (p < 0.0) { return v; } } // DEBUG cout << "Error: returning last value in sampleFromDist()\n"; return (dist.size() - 1);}double GibbsSampler::testConvergence(vector<vector<double> > summaries, vector<vector<double> > sqSummaries, int n) const{ double maxR = 0.0; int numChains = summaries.size(); int numSummaries = sqSummaries[0].size(); vector<double> avgSummaryVariance(numSummaries); // Iterate through all chains for (int c = 0; c < numChains; c++) { // Compute within-chain variance for each summary statistic for (int s = 0; s < numSummaries; s++) {#if 0 // HACK: We require that every single state receive at least a // fractional count in some chain before we converge. if (summaries[c][s] == 1.0) { // DEBUG cout << "c = " << c << "; s = " << s << endl; return 1000.0; }#endif double chainVar = (sqSummaries[c][s] - SQR(summaries[c][s])/n)/(n-1); avgSummaryVariance[s] += chainVar/numChains; } } for (int s = 0; s < numSummaries; s++) { // Compute between-chain variance for each summary statistic double squareSum = 0.0; double sum = 0.0; for (int c = 0; c < numChains; c++) { squareSum += SQR(summaries[c][s]/n); sum += summaries[c][s]/n; } double betweenChainVariance = (squareSum - SQR(sum)/numChains)/(numChains-1); // Compute convergence criteria R double R = ((n-1.0)/n*avgSummaryVariance[s] + betweenChainVariance)/avgSummaryVariance[s]; // Report largest convergence statistic if (R > maxR) { maxR = R;#if 0 // DEBUG cout << "Counts:"; for (int c = 0; c < numChains; c++) { cout << " " << summaries[c][s]; cout << " (" << sqSummaries[c][s] << ")"; } cout << "\n"; cout << "Within: " << avgSummaryVariance[s]; cout << "; Between: " << betweenChainVariance; cout << "; R: " << sqrt(R) << endl; // END DEBUG#endif } } // DEBUG //cout << "sqrt(R) = " << sqrt(maxR) << endl; return sqrt(maxR);}double GibbsSampler::predictIters(vector<vector<double> > summaries, vector<vector<double> > sqSummaries, int n) const{ double maxV = 0.0; int numChains = summaries.size(); int numSummaries = sqSummaries[0].size(); vector<double> avgSummaryVariance(numSummaries); // Iterate through all summary statistics for (int s = 0; s < numSummaries; s++) { // Consider the chains as independent estimates of each summary // statistic, and compute their standard deviation. 
        double squareSum = 0.0;
        double sum = 0.0;
        for (int c = 0; c < numChains; c++)
        {
            squareSum += SQR(summaries[c][s]/n);
            sum += summaries[c][s]/n;
        }

        // See page 740 of DeGroot and Schervish, 3rd ed.
        double S = sqrt((squareSum - SQR(sum)/numChains)/numChains);
        double sigma_hat = sqrt((double)n) * S;

        // Compute number of expected iterations (with 95% certainty)
        // to get the estimate correct within 5%.
        // (See page 707, eqn 11.1.5 of DeGroot and Schervish, 3rd ed.)
        double epsilon = 0.05 * sum/numChains;
        double v = SQR(1.96 * sigma_hat/epsilon);

        // Report largest number of iterations to run
        if (v > maxV)
        {
            maxV = v;
#if 0
            // DEBUG
            cout << "epsilon = " << epsilon << endl;
            cout << "sigma_hat = " << sigma_hat << endl;
            cout << "n = " << n << endl;
            cout << "sum = " << sum << endl;
            cout << "squareSum = " << squareSum << endl;
#endif
#if 0
            double mean = sum/numChains;
            double maxRatio = 1.0;
            for (int c = 0; c < numChains; c++)
            {
                double currStat = summaries[c][s]/n;
                if (currStat/mean > maxRatio)
                {
                    maxRatio = currStat/mean;
                }
                if (mean/currStat > maxRatio)
                {
                    maxRatio = mean/currStat;
                }
            }
            cout << "Max ratio: " << maxRatio << endl;
#endif
        }
    }
    return maxV;
}

void GibbsSampler::runMarginalInference(const VarSet& evidence)
{
    // We use this vector to convert var/value pairs into summary
    // statistic indices.
    vector<vector<int> > index(model.getNumVars());
    int numSummaries = 0;
    for (int v = 0; v < model.getNumVars(); v++)
    {
        for (int val = 0; val < model.getRange(v); val++)
        {
            index[v].push_back(numSummaries++);
        }
    }

    // Keep track of marginal counts for all test variables
    vector<vector<double> > counts(numChains, vector<double>(numSummaries));
    vector<vector<double> > sqCounts(numChains, vector<double>(numSummaries));
    for (int c = 0; c < numChains; c++)
    {
        for (int s = 0; s < numSummaries; s++)
        {
            counts[c][s] = 0.0;
            sqCounts[c][s] = 0.0;
            //counts[c][s] = 1.0;
            //sqCounts[c][s] = 1.0;
        }
    }

    // Initialize and burn-in all chains
    vector<VarSet> chains(numChains);
    for (int c = 0; c < numChains; c++)
    {
        chains[c] = evidence;
        model.wholeSample(chains[c]);

        // Use a fixed number of burn-in iters, if appropriate
        if (fixedIters)
        {
            burnInChain(chains[c], evidence, burnInIters);
        }
    }

    // Sample, sample, sample until convergence
    double burnin_iter = 0;
    double sampling_iter = 0;
    double predicted_iters = minIters;
    bool burnin_done = fixedIters;
    while (1)
    {
        if (burnin_done)
        {
            sampling_iter++;
        }
        else
        {
            burnin_iter++;
        }

        // Sample all variables and increase counts
        for (int c = 0; c < numChains; c++)
        {
            for (int v = 0; v < model.getNumVars(); v++)
            {
                // Don't resample evidence variables
                if (evidence.isTested(v))
                {
                    continue;
                }

                // Sample
                vector<double> dist = model.MBdist(v, chains[c]);
                chains[c][v] = sampleFromDist(dist);

#define RAOBLACKWELL
#ifdef RAOBLACKWELL
                // Update (Rao-Blackwellized) counts
                for (int val = 0; val < model.getRange(v); val++)
                {
                    // Update counts using the distribution
                    int s = index[v][val];
                    counts[c][s] += dist[val];
                    sqCounts[c][s] += dist[val]*dist[val];
                }
#else
                // Update counts
                int s = index[v][(int)chains[c][v]];
                // DEBUG
                //cout << "s = " << s << endl;
                counts[c][s]++;
                sqCounts[c][s]++;
#endif
            }
        }

        // After completing some minimum number of iterations,
        // check for convergence of our burn-in period
        if (!burnin_done && burnin_iter >= burnInIters
            && ((int)burnin_iter % 100 == 0)
            && (testConvergence(counts, sqCounts, (int)burnin_iter)
                < convergenceRatio))
        {
            // Stop burn-in
            burnin_done = true;

            // Throw away counts for burn-in period
            for (int c = 0; c < numChains; c++)
            {
                for (int s = 0; s < numSummaries; s++)
                {
                    counts[c][s] = 0.0;
                    sqCounts[c][s] = 0.0;
                    //counts[c][s] = 1.0;
                    //sqCounts[c][s] = 1.0;
                }
            }
        }

        // Test for convergence of the sampling.
        // Go until our standard error among the different chains is
        // less than 5% of the predicted value.
        if (burnin_done && sampling_iter >= predicted_iters)
        {
            // Stop, if we're only running a fixed number of iterations
            if (fixedIters)
            {
                break;
            }

            predicted_iters = predictIters(counts, sqCounts, (int)sampling_iter);
            // DEBUG
            cout << "Predicted iters = " << predicted_iters << endl;
            if (predicted_iters <= sampling_iter)
            {
                break;
            }
        }
    }

    // Save distributions
    for (int v = 0; v < model.getNumVars(); v++)
    {
        if (evidence.isTested(v))
        {
            continue;
        }
        Distribution m(model.getRange(v));
        for (int val = 0; val < model.getRange(v); val++)
        {
            m[val] = 0;
            for (int c = 0; c < numChains; c++)
            {
                m[val] += counts[c][index[v][val]];
            }
        }
        m.normalize();
        marginals[v] = m;
#ifdef DEBUG
        if (!evidence.isTested(v))
        {
            cout << v << ": " << m << endl;
        }
#endif
    }

    // DEBUG
    if (!fixedIters)
    {
        cout << burnin_iter << "; " << sampling_iter << endl;
    }
}

void GibbsSampler::runJointInference(const list<int>& queryVars,
                                     const VarSet& evidence)
{
    VarSchema schema = model.getSchema();
    VarConfig query(evidence, queryVars, schema);
    int numSummaries = query.getMaxIndex() + 1;

    // Keep track of counts for all test configurations
    vector<vector<double> > counts(numChains, vector<double>(numSummaries));
    vector<vector<double> > sqCounts(numChains, vector<double>(numSummaries));
    for (int c = 0; c < numChains; c++)
    {
        for (int s = 0; s < numSummaries; s++)
        {
            counts[c][s] = 1.0/numSummaries;
            sqCounts[c][s] = 1.0/numSummaries;
            // (The source listing ends here; the remainder of
            // runJointInference is not included on this page.)
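For reference, testConvergence appears to implement a Gelman-Rubin-style potential scale reduction check over the running count summaries. Written out under that reading (the code works with raw sums and sums of squares of the per-iteration values $x_{c,t}$ accumulated by chain $c$ over $n$ iterations, with $m$ chains), the quantities it computes are:

\[
W = \frac{1}{m}\sum_{c=1}^{m}\frac{\sum_{t} x_{c,t}^{2} - \bigl(\sum_{t} x_{c,t}\bigr)^{2}/n}{n-1},
\qquad
B = \frac{\sum_{c}\bar{x}_{c}^{2} - \bigl(\sum_{c}\bar{x}_{c}\bigr)^{2}/m}{m-1},
\quad \bar{x}_{c} = \frac{1}{n}\sum_{t} x_{c,t},
\]

\[
\sqrt{R} = \sqrt{\frac{\tfrac{n-1}{n}\,W + B}{W}}.
\]

The function returns the largest $\sqrt{R}$ over all summary statistics, and runMarginalInference ends burn-in once this value falls below convergenceRatio.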
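Similarly, the stopping rule in predictIters follows the normal-approximation sample-size bound cited in its comments (DeGroot and Schervish, eqn 11.1.5). With $S$ the standard deviation of the chain means, $\bar{x}$ their average, and the tolerance set to 5% of the current estimate, the code computes

\[
\hat\sigma = \sqrt{n}\,S, \qquad
\varepsilon = 0.05\,\bar{x}, \qquad
n_{\mathrm{req}} = \left(\frac{1.96\,\hat\sigma}{\varepsilon}\right)^{2},
\]

where 1.96 is the two-sided 95% normal quantile. Sampling continues until the number of post-burn-in sweeps reaches the largest $n_{\mathrm{req}}$ over all summary statistics.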
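Finally, the RAOBLACKWELL branch in runMarginalInference accumulates the full conditional distribution returned by MBdist (presumably the conditional given the variable's Markov blanket) rather than an indicator for the single sampled value. In equation form, the Rao-Blackwellized marginal estimate after $n$ sweeps is

\[
\hat{P}(X_v = x \mid e) \;=\; \frac{1}{n}\sum_{t=1}^{n} P\bigl(X_v = x \,\big|\, \mathrm{MB}(X_v)^{(t)}, e\bigr),
\]

which averages the Markov-blanket conditionals over the sampled states and typically has lower variance than the plain counts used in the #else branch.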