// EM.java
/**
*
* AgentAcademy - an open source Data Mining framework for
* training intelligent agents
*
* Copyright (C) 2001-2003 AA Consortium.
*
* This library is open source software; you can redistribute it
* and/or modify it under the terms of the GNU Lesser General
* Public License as published by the Free Software Foundation;
* either version 2.0 of the License, or (at your option) any later
* version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free
* Software Foundation, Inc., 59 Temple Place, Suite 330, Boston,
* MA 02111-1307 USA
*
*/
package org.agentacademy.modules.dataminer.clusterers;
/**
* <p>Title: The Data Miner prototype</p>
* <p>Description: A prototype for the DataMiner (DM), the Agent Academy (AA) module responsible for performing data mining on the contents of the Agent Use Repository (AUR). The extracted knowledge is to be sent back to the AUR in the form of a PMML document.</p>
* <p>Copyright: Copyright (c) 2002</p>
* <p>Company: CERTH</p>
* @author asymeon
* @version 0.3
*/
import java.util.Enumeration;
import java.util.Random;
import java.util.Vector;
import org.agentacademy.modules.dataminer.core.Instance;
import org.agentacademy.modules.dataminer.core.Instances;
import org.agentacademy.modules.dataminer.core.Option;
import org.agentacademy.modules.dataminer.core.OptionHandler;
import org.agentacademy.modules.dataminer.core.Utils;
import weka.estimators.DiscreteEstimator;
import weka.estimators.Estimator;
import org.apache.log4j.Logger;
/**
* Simple EM (expectation maximisation) class. <p>
*
* EM assigns a probability distribution to each instance which
* indicates the probability of it belonging to each of the clusters.
* EM can decide how many clusters to create by cross validation, or you
* may specify a priori how many clusters to generate. <p>
*
* Valid options are:<p>
*
* -V <br>
* Verbose. <p>
*
* -N <number of clusters> <br>
* Specify the number of clusters to generate. If omitted,
* EM will use cross validation to select the number of clusters
* automatically. <p>
*
* -I <max iterations> <br>
* Terminate after this many iterations if EM has not converged. <p>
*
* -S <seed> <br>
* Specify random number seed. <p>
*
* -M <num> <br>
* Set the minimum allowable standard deviation for normal density calculation.
* <p>
*
*/
public class EM
extends DistributionClusterer
implements OptionHandler
{
public static Logger log = Logger.getLogger(EM.class);
/** hold the discrete estimators for each cluster */
private Estimator m_model[][];
/** hold the normal estimators for each cluster */
private double m_modelNormal[][][];
/** default minimum standard deviation */
private double m_minStdDev = 1e-6;
/** hold the weights of each instance for each cluster */
private double m_weights[][];
/** the prior probabilities for clusters */
private double m_priors[];
/** the loglikelihood of the data */
private double m_loglikely;
/** training instances */
private Instances m_theInstances = null;
/** number of clusters selected by the user or cross validation */
private int m_num_clusters;
/** the initial number of clusters requested by the user--- -1 if
xval is to be used to find the number of clusters */
private int m_initialNumClusters;
/** number of attributes */
private int m_num_attribs;
/** number of training instances */
private int m_num_instances;
/** maximum iterations to perform */
private int m_max_iterations;
/** random numbers and seed */
private Random m_rr;
private int m_rseed;
/** Constant for normal distribution. */
private static double m_normConst = Math.sqrt(2*Math.PI);
/** Verbose? */
private boolean m_verbose;
/**
 * Gives a one-line summary of what this clusterer does.
 *
 * @return a short description suitable for display in the
 * explorer/experimenter GUI
 */
public String globalInfo() {
    final String summary = "Cluster data using expectation maximization";
    return summary;
}
/**
 * Returns an enumeration describing the available options. <p>
 *
 * Valid options are:<p>
 *
 * -V <br>
 * Verbose. <p>
 *
 * -N <number of clusters> <br>
 * Specify the number of clusters to generate. If omitted,
 * EM will use cross validation to select the number of clusters
 * automatically. <p>
 *
 * -I <max iterations> <br>
 * Terminate after this many iterations if EM has not converged. <p>
 *
 * -S <seed> <br>
 * Specify random number seed. <p>
 *
 * -M <num> <br>
 * Set the minimum allowable standard deviation for normal density
 * calculation. <p>
 *
 * @return an enumeration of all the available options.
 *
 **/
public Enumeration listOptions () {
    Vector newVector = new Vector(5);
    newVector.addElement(new Option("\tnumber of clusters. If omitted or"
                                    + "\n\t-1 specified, then cross "
                                    + "validation is used to\n\tselect the "
                                    + "number of clusters.", "N", 1
                                    , "-N <num>"));
    // Continuation lines start with "\n\t" to match the other options;
    // these two previously used "\n" alone, so the default was mis-indented.
    // NOTE(review): the quoted defaults should agree with resetOptions(),
    // which is not visible here; confirm.
    newVector.addElement(new Option("\tmax iterations.\n\t(default 100)", "I"
                                    , 1, "-I <num>"));
    newVector.addElement(new Option("\trandom number seed.\n\t(default 1)"
                                    , "S", 1, "-S <num>"));
    newVector.addElement(new Option("\tverbose.", "V", 0, "-V"));
    newVector.addElement(new Option("\tminimum allowable standard deviation "
                                    + "for normal density computation "
                                    + "\n\t(default 1e-6)"
                                    , "M", 1, "-M <num>"));
    return newVector.elements();
}
/**
 * Parses a given list of options and applies them to this clusterer.
 * <p>
 * All settings are first restored via resetOptions(). The -V, -I, -N,
 * -S and -M options are then read, in that order. An option whose value
 * string comes back empty from Utils.getOption keeps whatever value
 * resetOptions() gave it.
 *
 * @param options the list of options as an array of strings
 * @exception Exception if an option is not supported, a numeric value
 * cannot be parsed, or a setter rejects the value (for example -N 0 or
 * -I 0)
 **/
public void setOptions (String[] options)
    throws Exception
{
    resetOptions();
    setDebug(Utils.getFlag('V', options));
    String optionString = Utils.getOption('I', options);
    if (optionString.length() != 0) {
        setMaxIterations(Integer.parseInt(optionString));
    }
    optionString = Utils.getOption('N', options);
    if (optionString.length() != 0) {
        setNumClusters(Integer.parseInt(optionString));
    }
    optionString = Utils.getOption('S', options);
    if (optionString.length() != 0) {
        setSeed(Integer.parseInt(optionString));
    }
    optionString = Utils.getOption('M', options);
    if (optionString.length() != 0) {
        // Double.parseDouble parses exactly like the deprecated
        // new Double(String) constructor, without the boxing round-trip.
        setMinStdDev(Double.parseDouble(optionString));
    }
}
/**
 * Tool-tip text for the minStdDev property.
 *
 * @return a short hint suitable for display in the
 * explorer/experimenter GUI
 */
public String minStdDevTipText() {
    final String tip = "set minimum allowable standard deviation";
    return tip;
}
/**
 * Sets the lower bound applied to the standard deviation when computing
 * the normal density. Reducing this value can help prevent arithmetic
 * overflow caused by multiplying large densities (which arise from small
 * standard deviations) when there are many singleton or near-singleton
 * values.
 *
 * @param m the new minimum standard deviation
 */
public void setMinStdDev(double m) {
    this.m_minStdDev = m;
}
/**
 * Returns the minimum allowable standard deviation.
 *
 * @return the current minimum standard deviation
 */
public double getMinStdDev() {
    return this.m_minStdDev;
}
/**
 * Tool-tip text for the seed property.
 *
 * @return a short hint suitable for display in the
 * explorer/experimenter GUI
 */
public String seedTipText() {
    final String tip = "random number seed";
    return tip;
}
/**
 * Stores the seed for the random number generator.
 *
 * @param s the seed value
 */
public void setSeed (int s) {
    this.m_rseed = s;
}
/**
 * Returns the seed for the random number generator.
 *
 * @return the stored seed value
 */
public int getSeed () {
    return this.m_rseed;
}
/**
 * Tool-tip text for the numClusters property.
 *
 * @return a short hint suitable for display in the
 * explorer/experimenter GUI
 */
public String numClustersTipText() {
    final String tip = "set number of clusters. -1 to select number of clusters "
        + "automatically by cross validation.";
    return tip;
}
/**
 * Sets the number of clusters to generate.
 * <p>
 * Any negative value is stored as -1, meaning "select the number of
 * clusters by cross validation". The value is written to both the
 * working count and the user-requested count.
 *
 * @param n the number of clusters, or a negative value for cross validation
 * @exception Exception if n is 0
 */
public void setNumClusters (int n)
    throws Exception {
    if (n == 0) {
        throw new Exception("Number of clusters must be > 0. (or -1 to "
                            + "select by cross validation).");
    }
    // Collapse every negative request onto the single -1 sentinel.
    int requested = (n < 0) ? -1 : n;
    this.m_num_clusters = requested;
    this.m_initialNumClusters = requested;
}
/**
 * Returns the number of clusters the user requested.
 *
 * @return the requested cluster count, or -1 when it is chosen by cross
 * validation
 */
public int getNumClusters () {
    return this.m_initialNumClusters;
}
/**
 * Tool-tip text for the maxIterations property.
 *
 * @return a short hint suitable for display in the
 * explorer/experimenter GUI
 */
public String maxIterationsTipText() {
    final String tip = "maximum number of iterations";
    return tip;
}
/**
 * Sets the upper limit on the number of EM iterations.
 *
 * @param i the iteration limit; must be at least 1
 * @exception Exception if i is less than 1
 */
public void setMaxIterations (int i)
    throws Exception
{
    boolean valid = i > 0;
    if (!valid) {
        throw new Exception("Maximum number of iterations must be > 0!");
    }
    this.m_max_iterations = i;
}
/**
 * Returns the upper limit on the number of EM iterations.
 *
 * @return the iteration limit
 */
public int getMaxIterations () {
    return this.m_max_iterations;
}
/**
 * Turns verbose (debug) output on or off.
 *
 * @param v true to enable verbose output
 */
public void setDebug (boolean v) {
    this.m_verbose = v;
}
/**
 * Reports whether verbose (debug) output is enabled.
 *
 * @return true when verbose output is on
 */
public boolean getDebug () {
    return this.m_verbose;
}
/**
 * Builds the command-line form of the current EM settings.
 * <p>
 * The result always has 9 slots: an optional "-V", followed by the
 * -I, -N, -S and -M flag/value pairs. Any unused trailing slots are
 * filled with empty strings.
 *
 * @return an array of strings suitable for passing to setOptions()
 */
public String[] getOptions () {
    String[] opts = new String[9];
    int next = 0;
    if (m_verbose) {
        opts[next++] = "-V";
    }
    String[] flags = {"-I", "-N", "-S", "-M"};
    String[] values = {
        String.valueOf(m_max_iterations),
        String.valueOf(getNumClusters()),
        String.valueOf(m_rseed),
        String.valueOf(getMinStdDev())
    };
    for (int k = 0; k < flags.length; k++) {
        opts[next++] = flags[k];
        opts[next++] = values[k];
    }
    for (; next < opts.length; next++) {
        opts[next] = "";
    }
    return opts;
}
/**
 * Initialises the estimator storage and random starting weights for an
 * EM run with a fixed number of clusters.
 * <p>
 * Allocates m_weights as [number of instances][num_cl], m_model as
 * [num_cl][m_num_attribs], m_modelNormal as [num_cl][m_num_attribs][3]
 * and m_priors with num_cl entries. Each instance's weight row is filled
 * with values from m_rr and passed to Utils.normalize. The initial
 * cluster priors are then computed from those weights.
 * <p>
 * m_num_attribs and m_rr must already be set when this is called.
 *
 * @param inst the instances
 * @param num_cl the number of clusters
 * @exception Exception propagated from estimate_priors
 **/
private void EM_Init (Instances inst, int num_cl)
    throws Exception
{
    m_weights = new double[inst.numInstances()][num_cl];
    m_model = new Estimator[num_cl][m_num_attribs];
    m_modelNormal = new double[num_cl][m_num_attribs][3];
    m_priors = new double[num_cl];
    // Random initial cluster memberships for every instance.
    for (int i = 0; i < inst.numInstances(); i++) {
        for (int j = 0; j < num_cl; j++) {
            m_weights[i][j] = m_rr.nextDouble();
        }
        Utils.normalize(m_weights[i]);
    }
    // initial priors
    estimate_priors(inst, num_cl);
}
/**
 * Recomputes the prior probability of each cluster from the current
 * instance weights.
 * <p>
 * The raw prior of cluster c is the sum of m_weights[n][c] over all
 * instances n. The m_priors array is then passed to Utils.normalize.
 *
 * @param inst the instances
 * @param num_cl the number of clusters
 * @exception Exception if priors can't be calculated
 **/
private void estimate_priors (Instances inst, int num_cl)
    throws Exception
{
    int count = inst.numInstances();
    for (int c = 0; c < num_cl; c++) {
        // Accumulate in instance order so the floating-point sum is unchanged.
        double total = 0.0;
        for (int n = 0; n < count; n++) {
            total += m_weights[n][c];
        }
        m_priors[c] = total;
    }
    Utils.normalize(m_priors);
}
/**
 * Evaluates the normal (Gaussian) probability density function.
 * NOTE(review): there is no guard against stdDev == 0 here; presumably
 * callers clamp it to m_minStdDev. Confirm.
 *
 * @param x input value
 * @param mean mean of distribution
 * @param stdDev standard deviation of distribution
 * @return the density of N(mean, stdDev^2) at x
 */
private double normalDens (double x, double mean, double stdDev) {
    double offset = x - mean;
    double scale = 1 / (m_normConst * stdDev);
    double exponent = -(offset * offset / (2 * stdDev * stdDev));
    return scale * Math.exp(exponent);
}
/**
 * Resets the per-cluster estimators before an iteration.
 * <p>
 * For each cluster, every nominal attribute gets a fresh
 * DiscreteEstimator sized to that attribute's number of values. Every
 * other attribute has its three m_modelNormal entries zeroed.
 *
 * @param num_cl the number of clusters
 */
private void new_estimators (int num_cl) {
    for (int cluster = 0; cluster < num_cl; cluster++) {
        for (int att = 0; att < m_num_attribs; att++) {
            if (!m_theInstances.attribute(att).isNominal()) {
                double[] stats = m_modelNormal[cluster][att];
                stats[0] = 0.0;
                stats[1] = 0.0;
                stats[2] = 0.0;
                continue;
            }
            int numValues = m_theInstances.attribute(att).numValues();
            m_model[cluster][att] = new DiscreteEstimator(numValues, true);
        }
    }
}
/**
* The M step of the EM algorithm.
* @param inst the training instances
* @param num_cl the number of clusters
*/
private void M (Instances inst, int num_cl)
throws Exception
{
int i, j, l;
new_estimators(num_cl);
for (i = 0; i < num_cl; i++) {
for (j = 0; j < m_num_attribs; j++) {
for (l = 0; l < inst.numInstances(); l++) {
if (!inst.instance(l).isMissing(j)) {
?? 快捷鍵說明
復(fù)制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -