// Controller.java (excerpt)
    long start, stop;
    Executor pe = ExecutorSinglet.getExecutor();
    for (int iter = 0; iter < iterNo && !m_learningTree.boosterIsFinished(); iter++) {
      if (Monitor.logLevel > 1) {
        m_monitor.logIteration(iter,
                               m_learningTree.getCombinedPredictor(),
                               m_learningTree.getLastBasePredictor());
      }
      start = System.currentTimeMillis();
      Vector candidates = m_learningTree.getCandidates();
      stop = System.currentTimeMillis();
      if (Monitor.logLevel > 3) {
        Monitor.log("Learning iteration " + iter + " candidates are:");
        Monitor.log(candidates.toString());
        Monitor.log("It took " + (stop - start) / 1000.0
                    + " seconds to generate candidates Vector for iteration " + iter);
      }

      // This piece should be replaced by a more general tool to
      // measure the goodness of a split.
      // Create a synchronization barrier that counts the number
      // of candidate evaluations still outstanding.
      CountDown candidateCount = new CountDown(candidates.size());
      // an array to record the loss of each candidate
      double[] losses = new double[candidates.size()];
      int i = 0;
      for (Iterator ci = candidates.iterator(); ci.hasNext();) {
        CandidateSplit candidate = (CandidateSplit) ci.next();
        SplitEvaluatorWorker sew = new SplitEvaluatorWorker(candidate, losses, i, candidateCount);
        try {
          pe.execute(sew);
        } catch (InterruptedException ie) {
          System.err.println("Exception occurred while handing off the candidate job to the pool: "
                             + ie.getMessage());
          ie.printStackTrace();
        }
        i++;
      }
      try {
        candidateCount.acquire();
      } catch (InterruptedException ie) {
        if (candidateCount.currentCount() != 0) {
          System.err.println("Interrupted exception occurred, but the candidateCount is "
                             + candidateCount.currentCount());
        }
      }

      // run through the losses to determine the best split
      int best = 0;
      if (losses.length == 0) {
        System.err.println("ERROR: There are no candidate weak hypotheses to add to the tree.");
        System.err.println("This is likely a bug in JBoost; please report to JBoost developers.");
        System.exit(2);
      }
      double bestLoss = losses[best];
      double tmpLoss;
      for (int li = 1; li < losses.length; li++) {
        if ((tmpLoss = losses[li]) < bestLoss) {
          bestLoss = tmpLoss;
          best = li;
        }
      }
      if (Monitor.logLevel > 3)
        Monitor.log("Best candidate is: " + (CandidateSplit) candidates.get(best) + "\n");

      // add the candidate with the lowest loss
      start = System.currentTimeMillis();
      m_learningTree.addCandidate((CandidateSplit) candidates.get(best));
      stop = System.currentTimeMillis();
      if (Monitor.logLevel > 3) {
        Monitor.log("It took " + (stop - start) / 1000.0
                    + " seconds to add candidate for iteration " + iter);
      }
      System.out.println("Finished learning iteration " + iter);
      if (m_booster instanceof BrownBoost) {
        iterNo++;
      }
    }
    System.out.println();
    if (Monitor.logLevel > 3)
      m_monitor.logIteration(iterNo,
                             m_learningTree.getCombinedPredictor(),
                             m_learningTree.getLastBasePredictor());
  }
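  // --------------------------------------------------------------------------
  // Illustrative sketch (not part of the original JBoost source).
  // The candidate-evaluation loop above fans work out to a thread pool and
  // waits on a CountDown barrier from the pre-JDK5 util.concurrent library
  // (Executor, CountDown, acquire()).  The same fan-out/join pattern written
  // against java.util.concurrent looks roughly like this; the method name and
  // the squaring step are placeholders for whatever loss computation
  // SplitEvaluatorWorker actually performs.
  private static double[] scoreInParallelSketch(final double[] inputs) throws InterruptedException {
    final double[] results = new double[inputs.length];
    java.util.concurrent.ExecutorService pool =
        java.util.concurrent.Executors.newFixedThreadPool(Runtime.getRuntime().availableProcessors());
    final java.util.concurrent.CountDownLatch done =
        new java.util.concurrent.CountDownLatch(inputs.length);
    for (int j = 0; j < inputs.length; j++) {
      final int slot = j;
      pool.execute(new Runnable() {
        public void run() {
          try {
            results[slot] = inputs[slot] * inputs[slot];  // placeholder "loss"
          } finally {
            done.countDown();  // release the barrier even if the task fails
          }
        }
      });
    }
    done.await();   // plays the role of candidateCount.acquire() above
    pool.shutdown();
    return results;
  }
  // --------------------------------------------------------------------------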
  /**
   * @param cp a predictor
   * @throws IncompAttException
   */
  private void test(Predictor cp) throws NotSupportedException, InstrumentException {
    int size = m_config.getTestSet().getExampleNo();
    int i = 0;
    Example ex = null;
    Monitor.log("Testing rule.");
    LabelDescription labelDescription = m_exampleDescription.getLabelDescription();
    try {
      for (i = 0; i < size; i++) {
        ex = m_config.getTestSet().getExample(i);
        Prediction prediction = cp.predict(ex.getInstance());
        Label label = ex.getLabel();
        if (!prediction.getBestClass().equals(label)) {
          Monitor.log("Test Example " + i + " -----------------------------");
          Monitor.log(ex);
          Monitor.log("------------------------------------------");
          Monitor.log(prediction);
          Monitor.log("Explanation: " + ((AlternatingTree) cp).explain(ex.getInstance()));
        }
      }
    } catch (IncompAttException e) {
      // TODO add something here?
    }
  }

  /**
   * Build the array of splitterBuilders.
   * @throws IncompAttException
   */
  private void buildSplitterBuilderArray() throws IncompAttException {
    Vector sbf = SplitterBuilderFamily.factory(m_config);
    m_splitterBuilderVector = new Vector();
    for (int i = 0; i < sbf.size(); i++) {
      m_splitterBuilderVector.addAll(((SplitterBuilderFamily) sbf.get(i)).build(m_exampleDescription,
                                                                                m_config,
                                                                                m_booster));
    }
    if (Monitor.logLevel > 3) {
      Monitor.log("The initial array of splitter Builders is:");
      for (int i = 0; i < m_splitterBuilderVector.size(); i++) {
        Monitor.log("builder " + i + m_splitterBuilderVector.get(i));
      }
    }
  }

  /**
   * Determine the private weight for this example.
   * If the margin weight of the example is less than the threshold,
   * then we accept the example with probability margin/threshold and
   * set the private weight to 1/margin.
   * Otherwise we accept the example and set the private weight to 1/threshold.
   * @param data
   * @param threshold
   * @return the private weight, or -1 if the example was rejected
   */
  private double calculateExampleWeight(Example data, double threshold) {
    double weight = -1;
    Label label = data.getLabel();
    Instance instance = data.getInstance();
    double[] weights = m_serializedTree.predict(instance).getMargins(label);
    double margin = 0;
    // for now, just handle the binary prediction case
    if (weights.length == 1) {
      margin = m_booster.calculateWeight(weights[0]);
    }
    // else handle multi-label
    if (margin < threshold) {
      double prob = margin / threshold;
      if (Math.random() <= prob) {
        weight = 1 / margin;
      }
    }
    else {
      weight = 1 / threshold;
    }
    return weight;
  }

  /**
   * Add example to booster and training set.
   * Update the SplitterBuilder vector for this example.
   * @param counter
   * @param example
   * @param exampleWeight
   */
  private void addTrainingExample(int counter, Example example, double exampleWeight) {
    if (Monitor.logLevel > 0) {
      m_config.getTrainSet().addExample(counter, example);
    }
    m_booster.addExample(counter, example.getLabel(), exampleWeight);
    for (int i = 0; i < m_splitterBuilderVector.size(); i++) {
      if (Monitor.logLevel > 5) {
        Monitor.log("the class of splitterBuilder " + i + " is "
                    + m_splitterBuilderVector.get(i).getClass());
      }
      ((SplitterBuilder) m_splitterBuilderVector.get(i)).addExample(counter, example);
    }
  }
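  // --------------------------------------------------------------------------
  // Worked example (illustrative, not part of the original source).
  // With the default SAMPLE_THRESHOLD_WEIGHT of 0.5, calculateExampleWeight()
  // above behaves as follows:
  //   margin weight 0.2  -> kept with probability 0.2 / 0.5 = 0.4;
  //                         private weight 1 / 0.2 = 5.0 if kept, -1 if rejected
  //   margin weight 0.8  -> always kept; private weight 1 / 0.5 = 2.0
  // readTrainData() below treats any negative return value as "rejected".
  // --------------------------------------------------------------------------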
  /**
   * Read the training file and initialize the booster and the splitterBuilders
   * with its content.
   * @throws BadLabelException
   * @throws IncompAttException
   * @throws ParseException
   */
  private void readTrainData() throws IncompAttException, ParseException {
    long start, stop;
    m_config.setTrainSet(new ExampleSet(m_exampleDescription));
    Example example = null;
    int counter = 0;
    start = System.currentTimeMillis();
    boolean sampling = m_config.getBool(ControllerConfiguration.SAMPLE_TRAINING_DATA, false);
    double threshold = m_config.getDouble(ControllerConfiguration.SAMPLE_THRESHOLD_WEIGHT, 0.500);
    while ((example = m_trainStream.getExample()) != null) {
      double exampleWeight = example.getWeight();
      boolean accepted = true;
      // If we are sampling, recompute the weight of this example.
      // A negative weight means the example was not selected.
      if (sampling) {
        exampleWeight = calculateExampleWeight(example, threshold);
        if (exampleWeight < 0) {
          accepted = false;
        }
      }
      // The default behavior is to accept each example.  An example is only
      // refused if we are sampling and its weight comes back negative.
      if (accepted) {
        addTrainingExample(counter, example, exampleWeight);
        counter++;
      }
      if ((counter % 100) == 0) {
        System.out.print("Read " + counter + " training examples\n");
      }
    }
    stop = System.currentTimeMillis();
    System.out.println("Read " + counter + " training examples in "
                       + (stop - start) / 1000.0 + " seconds.");
    m_config.getTrainSet().finalizeData();
    m_booster.finalizeData();
    for (int i = 0; i < m_splitterBuilderVector.size(); i++) {
      ((SplitterBuilder) m_splitterBuilderVector.get(i)).finalizeData();
    }
    m_trainSetIndices = new int[counter];
    for (int i = 0; i < counter; i++)
      m_trainSetIndices[i] = i;
  }

  /**
   * Initialize the tokenizer.
   * @throws Exception
   */
  private void startTokenizer() throws Exception {
    DataStream ds = null;
    ds = new jboost_DataStream(m_config.getSpecFileName(), m_config.getTrainFileName());
    m_trainStream = new ExampleStream(ds);
    ds = new jboost_DataStream(m_config.getSpecFileName(), m_config.getTestFileName());
    m_testStream = new ExampleStream(ds);
  }

  /** Read the test data file. */
  private void readTestData() throws BadLabelException, IncompAttException, ParseException {
    long start, stop;
    m_config.setTestSet(new ExampleSet(m_exampleDescription));
    Example example = null;
    int counter = 0;
    start = System.currentTimeMillis();
    while ((example = m_testStream.getExample()) != null) {
      m_config.getTestSet().addExample(counter, example);
      counter++;
      if ((counter % 100) == 0) {
        System.out.print("Read " + counter + " test examples\n");
      }
    }
    stop = System.currentTimeMillis();
    System.out.println("Read " + counter + " test examples in "
                       + (stop - start) / 1000.0 + " seconds.");
    m_config.getTestSet().finalizeData();
  }

  private void reportResults() {
    try {
      PrintWriter resultOutputStream =
          new PrintWriter(new BufferedWriter(new FileWriter(m_config.getResultOutputFileName())));
      resultOutputStream.println(m_learningTree);
      resultOutputStream.close();
    } catch (Exception e) {
      System.err.println("Exception occurred while attempting to write result");
      e.printStackTrace();
    }
  }

  private void generateCode(WritablePredictor predictor,
                            String language,
                            String codeOutputFileName,
                            String procedureName) {
    try {
      String code = null;
      if (language.equals("C"))
        code = predictor.toC(procedureName, m_exampleDescription);
      else if (language.equals("MatLab"))
        code = predictor.toMatlab(procedureName, m_exampleDescription);
      else if (language.equals("java"))
        code = predictor.toJava(procedureName,
                                m_config.getString("javaOutputMethod", "predict"),
                                (m_config.getBool("javaStandAlone", false) ? null : m_config.getSpecFileName()),
                                m_exampleDescription);
      else
        throw new RuntimeException("Controller.generateCode: Unrecognized language: " + language);
      PrintWriter codeOutputStream =
          new PrintWriter(new BufferedWriter(new FileWriter(codeOutputFileName)));
      codeOutputStream.println(code);
      codeOutputStream.close();
    } catch (Exception e) {
      System.err.println("Exception occurred while attempting to write " + language + " code");
      System.err.println("Message: " + e);
      e.printStackTrace();
    }
  }
}
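// ----------------------------------------------------------------------------
// Overall flow (a sketch inferred from the private methods above; the driver
// that actually calls them is not part of this excerpt):
//   startTokenizer();                 // open the spec / train / test streams
//   buildSplitterBuilderArray();      // one SplitterBuilder per attribute family
//   readTrainData();                  // feed examples to the booster and builders
//   <boosting loop>                   // add the lowest-loss CandidateSplit each iteration
//   readTestData();  test(predictor); // log misclassified test examples
//   reportResults();                  // write m_learningTree to the result file
//   generateCode(predictor, language, codeOutputFileName, procedureName);
// ----------------------------------------------------------------------------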