?? kbestparseforest.java
字號:
package mstparser;public class KBestParseForest { public static int rootType; public ParseForestItem[][][][][] chart; private String[] sent,pos; private int start,end; private int K; public KBestParseForest(int start, int end, DependencyInstance inst, int K) { this.K = K; chart = new ParseForestItem[end+1][end+1][2][2][K]; this.start = start; this.end = end; this.sent = inst.sentence; this.pos = inst.pos; } public boolean add(int s, int type, int dir, double score, FeatureVector fv) { boolean added = false; if(chart[s][s][dir][0][0] == null) { for(int i = 0; i < K; i++) chart[s][s][dir][0][i] = new ParseForestItem(s,type,dir,Double.NEGATIVE_INFINITY,null); } if(chart[s][s][dir][0][K-1].prob > score) return false; for(int i = 0; i < K; i++) { if(chart[s][s][dir][0][i].prob < score) { ParseForestItem tmp = chart[s][s][dir][0][i]; chart[s][s][dir][0][i] = new ParseForestItem(s,type,dir,score,fv); for(int j = i+1; j < K && tmp.prob != Double.NEGATIVE_INFINITY; j++) { ParseForestItem tmp1 = chart[s][s][dir][0][j]; chart[s][s][dir][0][j] = tmp; tmp = tmp1; } added = true; break; } } return added; } public boolean add(int s, int r, int t, int type, int dir, int comp, double score, FeatureVector fv, ParseForestItem p1, ParseForestItem p2) { boolean added = false; if(chart[s][t][dir][comp][0] == null) { for(int i = 0; i < K; i++) chart[s][t][dir][comp][i] = new ParseForestItem(s,r,t,type,dir,comp,Double.NEGATIVE_INFINITY,null,null,null); } if(chart[s][t][dir][comp][K-1].prob > score) return false; for(int i = 0; i < K; i++) { if(chart[s][t][dir][comp][i].prob < score) { ParseForestItem tmp = chart[s][t][dir][comp][i]; chart[s][t][dir][comp][i] = new ParseForestItem(s,r,t,type,dir,comp,score,fv,p1,p2); for(int j = i+1; j < K && tmp.prob != Double.NEGATIVE_INFINITY; j++) { ParseForestItem tmp1 = chart[s][t][dir][comp][j]; chart[s][t][dir][comp][j] = tmp; tmp = tmp1; } added = true; break; } } return added; } public double getProb(int s, int t, int dir, int comp) { return getProb(s,t,dir,comp,0); } public double getProb(int s, int t, int dir, int comp, int i) { if(chart[s][t][dir][comp][i] != null) return chart[s][t][dir][comp][i].prob; return Double.NEGATIVE_INFINITY; } public double[] getProbs(int s, int t, int dir, int comp) { double[] result = new double[K]; for(int i = 0; i < K; i++) result[i] = chart[s][t][dir][comp][i] != null ? chart[s][t][dir][comp][i].prob : Double.NEGATIVE_INFINITY; return result; } public ParseForestItem getItem(int s, int t, int dir, int comp) { return getItem(s,t,dir,comp,0); } public ParseForestItem getItem(int s, int t, int dir, int comp, int k) { if(chart[s][t][dir][comp][k] != null) return chart[s][t][dir][comp][k]; return null; } public ParseForestItem[] getItems(int s, int t, int dir, int comp) { if(chart[s][t][dir][comp][0] != null) return chart[s][t][dir][comp]; return null; } public Object[] getBestParse() { Object[] d = new Object[2]; d[0] = getFeatureVector(chart[0][end][0][0][0]); d[1] = getDepString(chart[0][end][0][0][0]); return d; } public Object[][] getBestParses() { Object[][] d = new Object[K][2]; for(int k = 0; k < K; k++) { if(chart[0][end][0][0][k].prob != Double.NEGATIVE_INFINITY) { d[k][0] = getFeatureVector(chart[0][end][0][0][k]); d[k][1] = getDepString(chart[0][end][0][0][k]); } else { d[k][0] = null; d[k][1] = null; } } return d; } public FeatureVector getFeatureVector(ParseForestItem pfi) { if(pfi.left == null) return pfi.fv; return cat(pfi.fv,cat(getFeatureVector(pfi.left),getFeatureVector(pfi.right))); } public String getDepString(ParseForestItem pfi) { if(pfi.left == null) return ""; if(pfi.comp == 0) { return (getDepString(pfi.left) + " " + getDepString(pfi.right)).trim(); } else if(pfi.dir == 0) { return ((getDepString(pfi.left)+" "+getDepString(pfi.right)).trim()+" " +pfi.s+"|"+pfi.t+":"+pfi.type).trim(); } else { return (pfi.t+"|"+pfi.s+":"+pfi.type +" " +(getDepString(pfi.left)+" "+getDepString(pfi.right)).trim()).trim(); } } public FeatureVector cat(FeatureVector fv1, FeatureVector fv2) { return FeatureVector.cat(fv1,fv2); } // returns pairs of indeces and -1,-1 if < K pairs public int[][] getKBestPairs(ParseForestItem[] items1, ParseForestItem[] items2) { // in this case K = items1.length boolean[][] beenPushed = new boolean[K][K]; int[][] result = new int[K][2]; for(int i = 0; i < K; i++) { result[i][0] = -1; result[i][1] = -1; } if(items1 == null || items2 == null || items1[0] == null || items2[0] == null) return result; BinaryHeap heap = new BinaryHeap(K+1); int n = 0; ValueIndexPair vip = new ValueIndexPair(items1[0].prob+items2[0].prob,0,0); heap.add(vip); beenPushed[0][0] = true; while(n < K) { vip = heap.removeMax(); if(vip.val == Double.NEGATIVE_INFINITY) break; result[n][0] = vip.i1; result[n][1] = vip.i2; n++; if(n >= K) break; if(!beenPushed[vip.i1+1][vip.i2]) { heap.add(new ValueIndexPair(items1[vip.i1+1].prob+items2[vip.i2].prob,vip.i1+1,vip.i2)); beenPushed[vip.i1+1][vip.i2] = true; } if(!beenPushed[vip.i1][vip.i2+1]) { heap.add(new ValueIndexPair(items1[vip.i1].prob+items2[vip.i2+1].prob,vip.i1,vip.i2+1)); beenPushed[vip.i1][vip.i2+1] = true; } } return result; }} class ValueIndexPair { public double val; public int i1, i2; public ValueIndexPair(double val, int i1, int i2) { this.val = val; this.i1 = i1; this.i2 = i2; } public int compareTo(ValueIndexPair other) { if(val < other.val) return -1; if(val > other.val) return 1; return 0; } }// Max Heap// We know that never more than K elements on Heapclass BinaryHeap { private int DEFAULT_CAPACITY; private int currentSize; private ValueIndexPair[] theArray; public BinaryHeap(int def_cap) { DEFAULT_CAPACITY = def_cap; theArray = new ValueIndexPair[DEFAULT_CAPACITY+1]; // theArray[0] serves as dummy parent for root (who is at 1) // "largest" is guaranteed to be larger than all keys in heap theArray[0] = new ValueIndexPair(Double.POSITIVE_INFINITY,-1,-1); currentSize = 0; } public ValueIndexPair getMax() { return theArray[1]; } private int parent(int i) { return i / 2; } private int leftChild(int i) { return 2 * i; } private int rightChild(int i) { return 2 * i + 1; } public void add(ValueIndexPair e) { // bubble up: int where = currentSize + 1; // new last place while ( e.compareTo(theArray[parent(where)]) > 0 ){ theArray[where] = theArray[parent(where)]; where = parent(where); } theArray[where] = e; currentSize++; } public ValueIndexPair removeMax() { ValueIndexPair min = theArray[1]; theArray[1] = theArray[currentSize]; currentSize--; boolean switched = true; // bubble down for ( int parent = 1; switched && parent < currentSize; ) { switched = false; int leftChild = leftChild(parent); int rightChild = rightChild(parent); if(leftChild <= currentSize) { // if there is a right child, see if we should bubble down there int largerChild = leftChild; if ((rightChild <= currentSize) && (theArray[rightChild].compareTo(theArray[leftChild])) > 0){ largerChild = rightChild; } if (theArray[largerChild].compareTo(theArray[parent]) > 0) { ValueIndexPair temp = theArray[largerChild]; theArray[largerChild] = theArray[parent]; theArray[parent] = temp; parent = largerChild; switched = true; } } } return min; } }
?? 快捷鍵說明
復制代碼
Ctrl + C
搜索代碼
Ctrl + F
全屏模式
F11
切換主題
Ctrl + Shift + D
顯示快捷鍵
?
增大字號
Ctrl + =
減小字號
Ctrl + -