?? learner.py
字號:
# -*- coding: utf-8 -*-
#
# This file is a part of the Bayes Blocks library
#
# Copyright (C) 2001-2006 Markus Harva, Antti Honkela, Alexander
# Ilin, Tapani Raiko, Harri Valpola and Tomas Östman.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation; either version 2, or (at your option)
# any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License (included in file License.txt in the
# program package) for more details.
#
# $Id: Learner.py 5 2006-10-26 09:44:54Z ah $
#

"""Learning drivers for Bayes Blocks networks.

``Learner`` repeatedly calls the network's batch updates (``UpdateAll`` or
``UpdateAllHookeJeeves``) and records a history of costs and timings.
``LearnerOL`` steps through time: it repeats ``UpdateTimeDep`` until the
cost stops improving by more than ``epsilon``, then calls ``UpdateTimeInd``
and a user supplied stepper function.
"""

import signal
import time

import Helpers

# Set to 1 by siginthandler while a LearnNet() call has redirected SIGINT.
# It is polled once per (time) step, so learning stops between network
# updates rather than in the middle of one.
sigintencountered = 0


def siginthandler(signum, frame):
    """SIGINT handler that only records that Ctrl-C was pressed."""
    global sigintencountered
    sigintencountered = 1


# Values for the ``raisekbd`` argument of LearnNet():
SIGNORE = -1  # do not touch SIGINT (Python's default KeyboardInterrupt)
SSTOP = 0     # on Ctrl-C, stop before the next step and return 1
SRAISE = 1    # on Ctrl-C, stop before the next step and raise KeyboardInterrupt


# Processor-time source.  time.clock() was removed in Python 3.8, so prefer
# time.process_time() where it exists and fall back to time.clock() on older
# interpreters (Python 2 has no process_time()).
_cpuclock = getattr(time, "process_time", None) or time.clock


class timewrap:
    """Processor clock corrected for counter wrap-around.

    Whenever a reading is smaller than the previous one, a wrap is counted and
    one wrap period is added to every later reading.  The period used is
    2**32 / 10**6 seconds, i.e. the correction assumes a clock backed by a
    32-bit count of microseconds.
    """

    # 2**32 / 10**6, kept as the value obtained from
    # math.exp(32*math.log(2) - 6*math.log(10)).
    _WRAP_PERIOD = 4294.9672959999916

    def __init__(self):
        self.lastclock = _cpuclock()
        self.wraps = 0

    def clock(self):
        """Return the wrap-corrected clock reading in seconds."""
        newtime = _cpuclock()
        if self.lastclock > newtime:
            self.wraps += 1
        self.lastclock = newtime
        return newtime + self.wraps * self._WRAP_PERIOD


class Learner:
    """Batch learning driver for a Bayes Blocks network.

    Keeps an iteration counter ``iter``, a ``history`` list of
    ``(record, iter, cost, ...)`` tuples and, in ``benchmark``, the time
    (measured with ``timewrap``) spent in each network update.  User
    functions can be run periodically (AddFunction) or once at given
    iterations (AddCall).
    """

    def __getstate__(self):
        """Return the picklable state of the learner.

        The bound cost method is not picklable under Python 2, so it is
        replaced by the method name and the object it is bound to.
        __setstate__ rebuilds it from them.
        """
        state = self.__dict__.copy()
        costfun = state.pop('costfun', None)
        if costfun is not None:
            try:
                owner = costfun.__self__   # Python 2.6+ and Python 3
            except AttributeError:
                owner = costfun.im_self    # older Python 2
            state['costfunname'] = costfun.__name__
            state['costfunself'] = owner
        return state

    def __setstate__(self, state):
        """Restore state produced by __getstate__.

        If the stored cost-method reference is missing, ``self.net.Cost`` is
        used as the cost function.
        """
        state = dict(state)  # work on a copy; do not mutate pickle's mapping
        if 'costfunname' in state and 'costfunself' in state:
            owner = state.pop('costfunself')
            name = state.pop('costfunname')
            self.__dict__.update(state)
            # getattr() performs the same attribute lookup the old eval() did,
            # without evaluating arbitrary text taken from the pickle.
            self.costfun = getattr(owner, name)
        else:
            self.__dict__.update(state)
            self.costfun = self.net.Cost

    def __init__(self, net, prunefunc=None):
        """Create a learner for ``net``.

        Parameters:
          net       - the network; it must provide Cost() plus the update
                      methods used by Iteration() and HookeJeeves().
          prunefunc - optional callable used by TryPruning() as
                      prunefunc(net, *args, **kws).
        """
        self.net = net
        self.printhooke = 1
        self.prunefunc = prunefunc
        self.history = []
        self.benchmark = []
        self.iter = 0
        self.time = timewrap()
        self.functions = {}
        self.function_timings = []
        self.calls = []
        self.costfun = self.net.Cost
        # Checked after every iteration of LearnNet(); a callback may set it
        # to True to end learning early.  LearnNet() resets it on entry.
        self.stopnow = False

    def AddFunction(self, name, func, timing):
        """Run ``func`` periodically during LearnNet().

        ``timing`` is a period ``p`` or a pair ``(p, offset)`` (offset
        defaults to 0).  After each step the function is called when
        ``counter % p == offset``.  Its non-false results are added to the
        history under ``name``.  ``func`` may be anything CallFunc() accepts.
        """
        if not Helpers.IsSequence(timing):
            timing = (timing, 0)
        self.function_timings.append((timing, name))
        self.functions[name] = func

    def AddCall(self, name, func, iter):
        """Schedule ``func`` to run once when the iteration counter reaches ``iter``.

        ``iter`` is an int or a sequence of ints.  Calls for iterations that
        have already passed when LearnNet() reaches them are discarded.

        Raises TypeError if an iteration number is not an int.
        """
        if not Helpers.IsSequence(iter):
            iter = (iter,)
        for i in iter:
            if not isinstance(i, int):
                raise TypeError(
                    "AddCall: iteration numbers must be ints, got %r" % (i,))
            self.calls.append((i, name, func))

    def SortCalls(self):
        """Sort self.calls by iteration number only.

        The sort is stable, so calls for the same iteration keep the order in
        which they were added.
        """
        self.calls.sort(key=lambda call: call[0])

    def CallFunc(self, func, name=None, verbose=0):
        """Call ``func`` and record its result.

        ``func`` is a callable, or a sequence ``(f,)``, ``(f, args)`` or
        ``(f, args, kwargs)``.  A non-false result is appended to the history
        under ``name`` when a name is given.  With ``verbose``, the result is
        also printed.

        Raises ValueError if ``func`` is a sequence of any other length.
        """
        if callable(func):
            val = func()
        elif len(func) == 1:
            val = func[0]()
        elif len(func) == 2:
            val = func[0](*func[1])
        elif len(func) == 3:
            val = func[0](*func[1], **func[2])
        else:
            raise ValueError("len(func) is not in (1,2,3)")
        if (name is not None) and val:
            self.HistoryAdd(name, val)
        if verbose:
            # "%s" avoids the TypeError that name + ":" raised for name=None.
            print("%s: %s" % (name, val))

    def HistoryAdd(self, record, data=()):
        """Append ``(record, iter, costfun()) + data`` to the history.

        A plain tuple ``data`` is concatenated.  Anything else is stored as a
        single extra element.
        """
        entry = (record, self.iter, self.costfun())
        if type(data) is tuple:
            self.history.append(entry + data)
        else:
            self.history.append(entry + (data,))

    def TryPruning(self, *args, **kws):
        """Call prunefunc(net, *args, **kws) and return its result.

        The result is also recorded in the history under "Pruning".
        """
        pruningout = self.prunefunc(self.net, *args, **kws)
        self.HistoryAdd("Pruning", pruningout)
        return pruningout

    def HookeJeeves(self):
        """Do one Hooke-Jeeves update of the net and advance the counter.

        Calls net.UpdateAllHookeJeeves(exploresteps=self.exploresteps,
        returncost=1) and returns its output.  The output is assumed to be a
        tuple, because it is concatenated to the history entry.
        ``self.exploresteps`` is set by LearnNet().
        """
        start = self.time.clock()
        hookeout = self.net.UpdateAllHookeJeeves(
            exploresteps=self.exploresteps, returncost=1)
        timeused = self.time.clock() - start
        self.benchmark.append(timeused)
        self.HistoryAdd("HookeJeeves",
                        (timeused, self.time.clock()) + hookeout)
        if self.printhooke:
            print("%s : %s" % (self.iter, hookeout))
        self.iter += 1
        return hookeout

    def Iteration(self, debug=0):
        """Do one plain update of the net and advance the counter.

        Uses net.UpdateAll(), or net.UpdateAllDebug() when ``debug`` is true,
        and records the time used in ``benchmark`` and the history.
        """
        start = self.time.clock()
        if debug:
            self.net.UpdateAllDebug()
        else:
            self.net.UpdateAll()
        timeused = self.time.clock() - start
        self.benchmark.append(timeused)
        self.HistoryAdd("Iteration", (timeused, self.time.clock()))
        self.iter += 1

    def CheckDoHooke(self, hooke):
        """Return true if a Hooke-Jeeves step should be taken at self.iter.

        ``hooke`` may be one of:
          - false (None, 0, []): never;
          - int n: when iter % n == n // 2;
          - list of iteration numbers sorted in descending order, as prepared
            by LearnNet(): when iter is in the list.  Entries up to and
            including iter are consumed from the end of the list;
          - callable: the value of hooke(iter).

        Raises TypeError for any other type.
        """
        if not hooke:
            return 0
        if isinstance(hooke, int):
            # // is the integer division that / performed on ints in Python 2.
            if self.iter % hooke == hooke // 2:
                return 1
            return 0
        if isinstance(hooke, list):
            # Fire only on an exact match.  Drop entries that are due or
            # already past, so duplicates or stale values cannot trigger
            # extra steps or linger at the end of the list.
            due = 0
            while hooke and hooke[-1] <= self.iter:
                if hooke[-1] == self.iter:
                    due = 1
                del hooke[-1]
            return due
        if callable(hooke):
            return hooke(self.iter)
        raise TypeError("Bad type for hooke parameter")

    def LearnNet(self, printcost=0, iters=200, hooke=None, exploresteps=1,
                 printhooke=None, raisekbd=SSTOP, debug=0, verbosecall=1):
        """Learn the net.

        Parameters:
          printcost=0     - if nonzero, print the cost every printcost
                            iterations (and once before starting).
          iters=200       - iteration number at which to stop.  Numbering
                            continues from where the previous LearnNet() call
                            stopped.  A negative value -n means "n more
                            iterations".
          hooke=None      - when to do a Hooke-Jeeves step instead of a plain
                            update:
                              None or 0: never;
                              integer n: when iter % n == n // 2;
                              list: when iter is in the list (the caller's
                                    list is not modified);
                              function: when function(iter) is true.
          exploresteps=1  - passed to UpdateAllHookeJeeves().
          printhooke=None - print Hooke-Jeeves results.  If not given, they
                            are printed exactly when printcost is nonzero.
          raisekbd=SSTOP  - Ctrl-C handling.  SIGNORE leaves SIGINT alone.
                            Otherwise Ctrl-C ends learning before the next
                            iteration: with SSTOP LearnNet returns 1, with
                            any other value it raises KeyboardInterrupt.
          debug=0         - use net.CostDebug / net.UpdateAllDebug.
          verbosecall=1   - print the results of calls added with AddCall().

        Returns 1 if learning was cut short by Ctrl-C with raisekbd=SSTOP,
        otherwise 0.  The previous SIGINT handler is restored on every exit
        path, including exceptions.
        """
        global sigintencountered
        self.stopnow = False
        if debug:
            self.costfun = self.net.CostDebug
        else:
            self.costfun = self.net.Cost
        if iters < 0:
            iters = self.iter - iters
        self.exploresteps = exploresteps
        if printhooke is not None:
            self.printhooke = printhooke
        elif printcost == 0:
            self.printhooke = 0
        else:
            self.printhooke = 1

        catchsigint = raisekbd != SIGNORE
        if catchsigint:
            sigintencountered = 0
            oldsiginthandler = signal.signal(signal.SIGINT, siginthandler)
        try:
            if printcost:
                print("%s : %s" % (self.iter, self.costfun()))
            if isinstance(hooke, list):
                # Private copy in descending order: the next due iteration is
                # hooke[-1], so CheckDoHooke() can consume it cheaply.
                hooke = sorted(hooke, reverse=True)
                # Drop iterations that have already passed.
                while hooke and hooke[-1] < self.iter:
                    del hooke[-1]
            self.SortCalls()

            while self.iter < iters:
                if catchsigint and sigintencountered:
                    print("Learning stopped with ctrl-C")
                    sigintencountered = 0
                    if raisekbd == SSTOP:
                        return 1
                    raise KeyboardInterrupt
                if self.CheckDoHooke(hooke):
                    self.HookeJeeves()
                else:
                    self.Iteration(debug)
                # Run calls scheduled for the iteration just reached.  Calls
                # for iterations already passed are discarded so they cannot
                # block the rest of the sorted queue.
                while self.calls and self.calls[0][0] <= self.iter:
                    when, name, func = self.calls[0]
                    if when == self.iter:
                        self.CallFunc(func, name, verbose=verbosecall)
                    del self.calls[0]
                for timing, name in self.function_timings:
                    if self.iter % timing[0] == timing[1]:
                        self.CallFunc(self.functions[name], name)
                if printcost and (self.iter % printcost == 0):
                    if debug:
                        debugstring = "(debug) "
                    else:
                        debugstring = ""
                    print("%s%d : %f" % (debugstring, self.iter,
                                         self.costfun()))
                if self.stopnow:
                    break
        finally:
            if catchsigint:
                signal.signal(signal.SIGINT, oldsiginthandler)

        if catchsigint and sigintencountered:
            # Ctrl-C arrived during the last iteration.
            print("Learning ready when ctrl-C was sent")
            sigintencountered = 0
            if raisekbd:
                raise KeyboardInterrupt
        return 0


class LearnerOL(Learner):
    """On-line (time-stepping) learner.

    Each time step repeats net.UpdateTimeDep() until the cost decreases by
    no more than ``epsilon``.  It then calls net.UpdateTimeInd() and
    ``stepperfunc(*stepperargs)``.
    """

    def __init__(self, net, stepperfunc, stepperargs=(), epsilon=1e-2):
        # Set up the common Learner state (history, timers, callback tables,
        # costfun, calls, ...) so inherited methods such as __getstate__,
        # AddCall and CallFunc find every attribute they use.
        Learner.__init__(self, net)
        self.timestep = 0
        self.epsilon = epsilon
        self.stepperfunc = stepperfunc
        self.stepperargs = stepperargs

    def HistoryAdd(self, record, data=()):
        """Append ``(record, timestep, iter, net.Cost()) + data`` to the history.

        A plain tuple ``data`` is concatenated.  Anything else is stored as a
        single extra element.
        """
        entry = (record, self.timestep, self.iter, self.net.Cost())
        if type(data) is tuple:
            self.history.append(entry + data)
        else:
            self.history.append(entry + (data,))

    def PrintCost(self):
        """Print ``timestep(iter): cost``."""
        print("%r(%r): %s" % (self.timestep, self.iter, self.net.Cost()))

    def Iteration(self, printiters=0):
        """Do one time-dependent update and advance the inner counter.

        Calls net.UpdateTimeDep().  If ``printiters`` is nonzero, the cost is
        printed every ``printiters`` inner iterations.
        """
        self.net.UpdateTimeDep()
        if printiters and (self.iter % printiters == 0):
            self.PrintCost()
        self.iter += 1

    def StepTime(self, printsteps=1):
        """Finish the current time step.

        Calls net.UpdateTimeInd() and optionally prints the cost.  It then
        resets the inner counter, advances ``timestep`` and calls the stepper.
        """
        self.net.UpdateTimeInd()
        if printsteps and (self.timestep % printsteps == 0):
            self.PrintCost()
        self.iter = 0
        self.timestep += 1
        self.stepperfunc(*self.stepperargs)

    def DoTimeStep(self, printiters=0, printsteps=1):
        """Iterate until the cost decreases by at most ``epsilon``, then step time."""
        oldcost = 1e300  # guarantees at least one iteration for ordinary costs
        cost = self.net.Cost()
        while (oldcost - cost) > self.epsilon:
            self.Iteration(printiters=printiters)
            oldcost = cost
            cost = self.net.Cost()
        self.StepTime(printsteps=printsteps)

    def LearnNet(self, stopstep, printiters=0, printsteps=1, raisekbd=SSTOP):
        """Run time steps until ``timestep`` reaches ``stopstep``.

        A negative ``stopstep`` -n means "n more time steps".  Functions added
        with AddFunction() are run after each time step whose number matches
        their (period, offset).  ``raisekbd`` is SIGNORE, SSTOP or SRAISE, as
        for Learner.LearnNet().

        Returns 1 if learning was cut short by Ctrl-C with raisekbd=SSTOP,
        otherwise 0.  The previous SIGINT handler is restored on every exit
        path.  Raises ValueError for an unknown ``raisekbd``.
        """
        global sigintencountered
        if stopstep < 0:
            stopstep = self.timestep - stopstep
        if raisekbd not in (SIGNORE, SSTOP, SRAISE):
            raise ValueError("Unknown raisekbd value")

        catchsigint = raisekbd != SIGNORE
        if catchsigint:
            sigintencountered = 0
            oldsiginthandler = signal.signal(signal.SIGINT, siginthandler)
        try:
            while self.timestep < stopstep:
                if catchsigint and sigintencountered:
                    print("Learning stopped with ctrl-C")
                    sigintencountered = 0
                    if raisekbd == SSTOP:
                        return 1
                    # raisekbd was validated above, so this is SRAISE.
                    raise KeyboardInterrupt
                self.DoTimeStep(printiters=printiters, printsteps=printsteps)
                for timing, name in self.function_timings:
                    if self.timestep % timing[0] == timing[1]:
                        self.CallFunc(self.functions[name], name)
        finally:
            if catchsigint:
                signal.signal(signal.SIGINT, oldsiginthandler)

        if catchsigint and sigintencountered:
            # Ctrl-C arrived during the last time step.
            print("Learning ready when ctrl-C was sent")
            sigintencountered = 0
            if raisekbd == SRAISE:
                raise KeyboardInterrupt
        return 0
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -