// Java tutorial
/** * This file is part of FNLP (formerly FudanNLP). * * FNLP is free software: you can redistribute it and/or modify * it under the terms of the GNU Lesser General Public License as published by * the Free Software Foundation, either version 3 of the License, or * (at your option) any later version. * * FNLP is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU Lesser General Public License for more details. * * You should have received a copy of the GNU General Public License * along with FudanNLP. If not, see <http://www.gnu.org/licenses/>. * * Copyright 2009-2014 www.fnlp.org. All rights reserved. */ package org.fnlp.nlp.similarity.train; import java.io.BufferedInputStream; import java.io.BufferedOutputStream; import java.io.BufferedWriter; import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.OutputStreamWriter; import java.io.Serializable; import java.util.Date; import java.util.zip.GZIPInputStream; import java.util.zip.GZIPOutputStream; import org.apache.commons.cli.BasicParser; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.Options; import org.fnlp.data.reader.Reader; import org.fnlp.ml.types.alphabet.LabelAlphabet; import org.fnlp.nlp.similarity.Cluster; import org.fnlp.util.MyArrays; import org.fnlp.util.MyCollection; import org.fnlp.util.MyHashSparseArrays; import gnu.trove.iterator.TIntFloatIterator; import gnu.trove.iterator.TIntIterator; import gnu.trove.iterator.TIntObjectIterator; import gnu.trove.iterator.hash.TObjectHashIterator; import gnu.trove.map.hash.TIntFloatHashMap; import gnu.trove.map.hash.TIntIntHashMap; import gnu.trove.map.hash.TIntObjectHashMap; import gnu.trove.set.hash.TIntHashSet; import gnu.trove.set.hash.TLinkedHashSet; /** * Brown 
??? * @author xpqiu * */ public class WordCluster implements Serializable { private static final long serialVersionUID = 1632709924496094832L; private static float ENERGY = 0.999f; public int slotsize = 50; int lastid; LabelAlphabet alpahbet = new LabelAlphabet(); TIntObjectHashMap<TIntHashSet> leftnodes = new TIntObjectHashMap<TIntHashSet>(); TIntObjectHashMap<TIntHashSet> rightnodes = new TIntObjectHashMap<TIntHashSet>(); TIntObjectHashMap<Cluster> clusters = new TIntObjectHashMap<Cluster>(); /** * */ TIntIntHashMap heads = new TIntIntHashMap(200, 0.5f, -1, -1); TIntHashSet slots = new TIntHashSet(); /** * ? */ TIntObjectHashMap<TIntFloatHashMap> pcc = new TIntObjectHashMap<TIntFloatHashMap>(); /** * ? */ TIntObjectHashMap<TIntFloatHashMap> wcc = new TIntObjectHashMap<TIntFloatHashMap>(); TIntFloatHashMap wordProb = new TIntFloatHashMap(); public int totalword; /** * ??? */ private boolean meger = true; public WordCluster() { } /** * ? */ public void read(Reader reader) { totalword = 0; while (reader.hasNext()) { String content = (String) reader.next().getData(); int prechar = -1; wordProb.adjustOrPutValue(prechar, 1, 1); totalword += content.length() + 2; for (int i = 0; i < content.length() + 1; i++) { int idx; if (i < content.length()) { String c = String.valueOf(content.charAt(i)); idx = alpahbet.lookupIndex(c); } else { idx = -2; } wordProb.adjustOrPutValue(idx, 1, 1); TIntFloatHashMap map = pcc.get(prechar); if (map == null) { map = new TIntFloatHashMap(); pcc.put(prechar, map); } map.adjustOrPutValue(idx, 1, 1); TIntHashSet left = leftnodes.get(idx); if (left == null) { left = new TIntHashSet(); leftnodes.put(idx, left); } left.add(prechar); TIntHashSet right = rightnodes.get(prechar); if (right == null) { right = new TIntHashSet(); rightnodes.put(prechar, right); } right.add(idx); prechar = idx; } } lastid = alpahbet.size(); System.out.println("[]\t" + totalword); int size = alpahbet.size(); System.out.println("[?]\t" + size); statisticProb(); } /** * */ 
private void statisticProb() { System.out.println(""); TIntFloatIterator it = wordProb.iterator(); while (it.hasNext()) { it.advance(); float v = it.value() / totalword; it.setValue(v); int key = it.key(); if (key < 0) continue; Cluster cluster = new Cluster(key, v, alpahbet.lookupString(key)); clusters.put(key, cluster); } TIntObjectIterator<TIntFloatHashMap> it1 = pcc.iterator(); while (it1.hasNext()) { it1.advance(); TIntFloatHashMap map = it1.value(); TIntFloatIterator it2 = map.iterator(); while (it2.hasNext()) { it2.advance(); it2.setValue(it2.value() / totalword); } } } /** * total graph weight * * @param c1 * @param c2 * @param b * @return */ private float weight(int c1, int c2) { float w; float pc1 = wordProb.get(c1); float pc2 = wordProb.get(c2); if (c1 == c2) { float pcc = getProb(c1, c1); w = clacW(pcc, pc1, pc2); } else { float pcc1 = getProb(c1, c2); float p1 = clacW(pcc1, pc1, pc2); float pcc2 = getProb(c2, c1); float p2 = clacW(pcc2, pc2, pc1); w = p1 + p2; } setweight(c1, c2, w); return w; } /** * c1,c2??k?? 
* @param c1 * @param c2 * @param k * @return */ private float weight(int c1, int c2, int k) { float w; float pc1 = wordProb.get(c1); float pc2 = wordProb.get(c2); float pck = wordProb.get(k); // float pc = pc1 + pc2; if (c1 == k) { float pcc1 = getProb(c1, c1); float pcc2 = getProb(c2, c2); float pcc3 = getProb(c1, c2); float pcc4 = getProb(c2, c1); float pcc = pcc1 + pcc2 + pcc3 + pcc4; w = clacW(pcc, pc, pc); } else { float pcc1 = getProb(c1, k); float pcc2 = getProb(c2, k); float pcc12 = pcc1 + pcc2; float p1 = clacW(pcc12, pc, pck); float pcc3 = getProb(k, c1); float pcc4 = getProb(k, c2); float pcc34 = pcc3 + pcc4; float p2 = clacW(pcc34, pck, pc); w = p1 + p2; } return w; } private float clacW(float pcc, float pc1, float pc2) { float p = 0; if (pcc != 0f) p = pcc * (float) (Math.log(pcc) - Math.log(pc1) - Math.log(pc2)); // if(Float.isInfinite(p)||Float.isNaN(p)) // return p; return p; } private float getProb(int c1, int c2) { float p; TIntFloatHashMap map = pcc.get(c1); if (map == null) { p = 0f; } else { p = pcc.get(c1).get(c2); } return p; } /** * merge clusters */ public void mergeCluster() { int maxc1 = -1; int maxc2 = -1; float maxL = Float.NEGATIVE_INFINITY; TIntIterator it1 = slots.iterator(); while (it1.hasNext()) { int i = it1.next(); TIntIterator it2 = slots.iterator(); // System.out.print(i+": "); while (it2.hasNext()) { int j = it2.next(); if (i >= j) continue; // System.out.print(j+" "); float L = calcL(i, j); // System.out.print(L+" "); if (L > maxL) { maxL = L; maxc1 = i; maxc2 = j; } } // System.out.println(); } // if(maxL == Float.NEGATIVE_INFINITY ) // return; merge(maxc1, maxc2); } /** * ?c1c2 * @param c1 * @param c2 */ protected void merge(int c1, int c2) { int newid = lastid++; heads.put(c1, newid); heads.put(c2, newid); TIntFloatHashMap newpcc = new TIntFloatHashMap(); TIntFloatHashMap inewpcc = new TIntFloatHashMap(); TIntFloatHashMap newwcc = new TIntFloatHashMap(); float pc1 = wordProb.get(c1); float pc2 = wordProb.get(c2); // float 
pc = pc1 + pc2; float w; { float pcc1 = getProb(c1, c1); float pcc2 = getProb(c2, c2); float pcc3 = getProb(c1, c2); float pcc4 = getProb(c2, c1); float pcc = pcc1 + pcc2 + pcc3 + pcc4; if (pcc != 0.0f) newpcc.put(newid, pcc); w = clacW(pcc, pc, pc); if (w != 0.0f) newwcc.put(newid, w); } TIntIterator it = slots.iterator(); while (it.hasNext()) { int k = it.next(); float pck = wordProb.get(k); if (c1 == k || c2 == k) { continue; } else { float pcc1 = getProb(c1, k); float pcc2 = getProb(c2, k); float pcc12 = pcc1 + pcc2; if (pcc12 != 0.0f) newpcc.put(newid, pcc12); float p1 = clacW(pcc12, pc, pck); float pcc3 = getProb(k, c1); float pcc4 = getProb(k, c2); float pcc34 = pcc3 + pcc4; if (pcc34 != 0.0f) inewpcc.put(k, pcc34); float p2 = clacW(pcc34, pck, pc); w = p1 + p2; if (w != 0.0f) newwcc.put(newid, w); } } //slots slots.remove(c1); slots.remove(c2); slots.add(newid); pcc.put(newid, newpcc); pcc.remove(c1); pcc.remove(c2); TIntFloatIterator it2 = inewpcc.iterator(); while (it2.hasNext()) { it2.advance(); TIntFloatHashMap pmap = pcc.get(it2.key()); // if(pmap==null){ // pmap = new TIntFloatHashMap(); // pcc.put(it2.key(), pmap); // } pmap.put(newid, it2.value()); pmap.remove(c1); pmap.remove(c2); } // //newid it3.key; wcc.put(newid, new TIntFloatHashMap()); wcc.remove(c1); wcc.remove(c2); TIntFloatIterator it3 = newwcc.iterator(); while (it3.hasNext()) { it3.advance(); TIntFloatHashMap pmap = wcc.get(it3.key()); pmap.put(newid, it3.value()); pmap.remove(c1); pmap.remove(c2); } wordProb.remove(c1); wordProb.remove(c2); wordProb.put(newid, pc); //cluster Cluster cluster = new Cluster(newid, clusters.get(c1), clusters.get(c2), pc); clusters.put(newid, cluster); System.out.println("?" 
+ cluster.rep); } /** * calculate the value L * * @param c1 * @param c2 * @param window * @return */ public float calcL(int c1, int c2) { float L = 0; TIntIterator it = slots.iterator(); while (it.hasNext()) { int k = it.next(); if (k == c2) continue; L += weight(c1, c2, k); } it = slots.iterator(); while (it.hasNext()) { int k = it.next(); L -= getweight(c1, k); L -= getweight(c2, k); } return L; } private void setweight(int c1, int c2, float w) { if (w == 0.0f) return; int max, min; if (c1 <= c2) { max = c2; min = c1; } else { max = c1; min = c2; } TIntFloatHashMap map2 = wcc.get(min); if (map2 == null) { map2 = new TIntFloatHashMap(); wcc.put(min, map2); } map2.put(max, w); } private float getweight(int c1, int c2) { int max, min; if (c1 <= c2) { max = c2; min = c1; } else { max = c1; min = c2; } float w; TIntFloatHashMap map2 = wcc.get(min); if (map2 == null) { w = 0; } else w = map2.get(max); return w; } /** * start clustering */ public Cluster startClustering() { // int[] idx = MyCollection.sort(wordProb); wordProb.remove(-1); wordProb.remove(-2); int[] idx = MyHashSparseArrays.trim(wordProb, ENERGY); int mergeCount = idx.length; int remainCount = idx.length; System.out.println("[?]\t" + mergeCount); System.out.println("[]\t" + totalword); int round; for (round = 0; round < Math.min(slotsize, mergeCount); round++) { slots.add(idx[round]); System.out.println(round + "\t" + alpahbet.lookupString(idx[round]) + "\t" + slots.size()); } TIntIterator it1 = slots.iterator(); while (it1.hasNext()) { int i = it1.next(); TIntIterator it2 = slots.iterator(); while (it2.hasNext()) { int j = it2.next(); if (i > j) continue; weight(i, j); } } while (slots.size() > 1) { if (round < mergeCount) System.out.println(round + "\t" + alpahbet.lookupString(idx[round]) + "\tSize:\t" + slots.size()); else System.out.println(round + "\t" + "\tSize:\t" + slots.size()); System.out.println("[?]\t" + remainCount--); long starttime = System.currentTimeMillis(); mergeCluster(); long endtime 
= System.currentTimeMillis(); System.out.println("\tTime:\t" + (endtime - starttime) / 1000.0); if (round < mergeCount) { int id = idx[round]; slots.add(id); TIntIterator it = slots.iterator(); while (it.hasNext()) { int j = it.next(); weight(j, id); } } else { if (!meger) return null; } try { saveTxt("../tmp/res-" + round); } catch (Exception e) { // TODO Auto-generated catch block e.printStackTrace(); } round++; } return clusters.get(slots.toArray()[0]); } public String toString() { StringBuilder sb = new StringBuilder(); TIntObjectHashMap<TLinkedHashSet<String>> sets = new TIntObjectHashMap<TLinkedHashSet<String>>(); for (int i = 0; i < alpahbet.size(); i++) { int head = getHead(i); TLinkedHashSet<String> s = sets.get(head); if (s == null) { s = new TLinkedHashSet(); sets.put(head, s); } s.add(alpahbet.lookupString(i)); } TIntObjectIterator<TLinkedHashSet<String>> it = sets.iterator(); while (it.hasNext()) { it.advance(); if (it.value().size() < 2) continue; sb.append(wordProb.get(it.key())); sb.append(" "); TObjectHashIterator<String> itt = it.value().iterator(); while (itt.hasNext()) { String ss = itt.next(); sb.append(ss); sb.append(" "); } sb.append("\n"); } return sb.toString(); } private int getHead(int i) { int h = heads.get(i); if (h == -1) return i; else return getHead(h); } /** * * @param file * @throws IOException */ public void saveModel(String file) throws IOException { File f = new File(file); File path = f.getParentFile(); if (!path.exists()) { path.mkdirs(); } ObjectOutputStream out = new ObjectOutputStream( new GZIPOutputStream(new BufferedOutputStream(new FileOutputStream(file)))); out.writeObject(this); out.close(); } public static WordCluster loadFrom(String file) throws IOException, ClassNotFoundException { ObjectInputStream in = new ObjectInputStream( new GZIPInputStream(new BufferedInputStream(new FileInputStream(file)))); WordCluster cl = (WordCluster) in.readObject(); in.close(); return cl; } /** * ? 
* @param file * @throws Exception */ public void saveTxt(String file) throws Exception { FileOutputStream fos = new FileOutputStream(file); BufferedWriter bout = new BufferedWriter(new OutputStreamWriter(fos, "UTF8")); bout.write(this.toString()); bout.close(); } /** * @param args * @throws Exception */ public static void main(String[] args) throws Exception { /** * ?? */ Options opt = new Options(); opt.addOption("path", true, "?"); opt.addOption("res", true, "?"); opt.addOption("slot", true, "?"); BasicParser parser = new BasicParser(); CommandLine cl; try { cl = parser.parse(opt, args); } catch (Exception e) { System.err.println("Parameters format error"); return; } int slotsize = Integer.parseInt(cl.getOptionValue("slot", "50")); System.out.println("?:" + slotsize); String file = cl.getOptionValue("path", "./tmp/news.allsites.txt"); System.out.println("?:" + file); String resfile = cl.getOptionValue("res", "./tmp/res.txt"); System.out.println(":" + resfile); SougouCA sca = new SougouCA(file); WordCluster wc = new WordCluster(); wc.slotsize = slotsize; wc.read(sca); wc.startClustering(); wc.saveModel(resfile + ".m"); wc.saveTxt(resfile); wc = WordCluster.loadFrom(resfile + ".m"); wc.saveTxt(resfile + "1"); System.out.println(new Date().toString()); System.out.println("Done"); } }