// Java tutorial
/* * (C) Copyright 2015 ILLC University of Amsterdam (http://www.illc.uva.nl) * * This work was supported by "STW Open Technologieprogramma" grant * under project name "Data-Powered Domain-Specific Translation Services On Demand" * * All rights reserved. This program and the accompanying materials * are made available under the terms of the GNU Lesser General Public License * (LGPL) version 2.1 which accompanies this distribution, and is available at * http://www.gnu.org/licenses/lgpl-2.1.html * * This program is distributed in the hope that it will be useful, * but WITHOUT ANY WARRANTY; without even the implied warranty of * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the * GNU General Public License for more details. * */ package nl.uva.illc.dataselection; import edu.berkeley.nlp.lm.ConfigOptions; import edu.berkeley.nlp.lm.NgramLanguageModel; import edu.berkeley.nlp.lm.StringWordIndexer; import edu.berkeley.nlp.lm.io.ArpaLmReader; import edu.berkeley.nlp.lm.io.LmReaders; import java.io.BufferedReader; import java.io.BufferedWriter; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.FileOutputStream; import java.io.FileReader; import java.io.IOException; import java.io.InputStreamReader; import java.io.LineNumberReader; import java.io.OutputStreamWriter; import java.io.PrintWriter; import java.nio.charset.Charset; import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.concurrent.CountDownLatch; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; import net.openhft.koloboke.collect.map.hash.HashIntFloatMap; import net.openhft.koloboke.collect.map.hash.HashIntFloatMaps; import net.openhft.koloboke.collect.map.hash.HashIntIntMap; import net.openhft.koloboke.collect.map.hash.HashIntIntMaps; import net.openhft.koloboke.collect.map.hash.HashIntObjMap; import 
net.openhft.koloboke.collect.map.hash.HashIntObjMaps; import net.openhft.koloboke.collect.map.hash.HashObjIntMap; import net.openhft.koloboke.collect.map.hash.HashObjIntMaps; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLineParser; import org.apache.commons.cli.GnuParser; import org.apache.commons.cli.HelpFormatter; import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; /** * Invitation based data selection approach exploits in-domain data (both * monolingual and bilingual) as prior to guide word alignment and phrase pair * estimates in the large mix-domain corpus. As a by-product, accurate estimates * for P(D|e,f) of the mixed-domain sentences are produced (with D being either * in-domain or out-of-domain), which can be used to rank the sentences in Dmix * according to their relevance to Din. * * For more information see: Hoang, Cuong and Sima'an, Khalil (2014): Latent * Domain Translation Models in Mix-of-Domains Haystack, Proceedings of COLING * 2014, the 25th International Conference on Computational Linguistics * http://www.aclweb.org/anthology/C14-1182.pdf * * @author Amir Kamran */ public class InvitationModel { private static Logger log = LogManager.getLogger(InvitationModel.class); static String IN = null; static String MIX = null; static String SRC = null; static String TRG = null; static int iMAX = 10; static int src_indomain[][] = null; static int trg_indomain[][] = null; static int src_mixdomain[][] = null; static int trg_mixdomain[][] = null; static int src_outdomain[][] = null; static int trg_outdomain[][] = null; static HashObjIntMap<String> src_codes = null; static HashObjIntMap<String> trg_codes = null; static float lm[][] = null; static float LOG_0_5 = (float) Math.log(0.5); // default confidence threshold: use to decide which sentences // will update the translation table static float CONF_THRESHOLD 
= (float) Math.log(0.5); // default convergence threshold: How much change in PD1 is significant // to continue to next iteration static float CONV_THRESHOLD = 0.00001f; static float PD1 = LOG_0_5; static float PD0 = LOG_0_5; static TranslationTable ttable[] = new TranslationTable[4]; public static CountDownLatch latch = null; public static ExecutorService jobs = Executors.newCachedThreadPool(); public static HashIntIntMap ignore = HashIntIntMaps.newMutableMap(); public static float n = 0.5f; public static float V = 500000f; public static float nV = n * V; public static float p = -(float) Math.log(V); public static void main(String args[]) throws IOException, InterruptedException { log.info("Start ..."); processCommandLineArguments(args); readFiles(); initialize(); burnIN(); createLM(); training(); jobs.shutdown(); jobs.awaitTermination(10, TimeUnit.MINUTES); log.info("END"); } public static void processCommandLineArguments(String args[]) { Options options = new Options(); options.addOption("cmix", "mix-domain-corpus", true, "Mix-domain corpus name"); options.addOption("cin", "in-domain-corpus", true, "In-domain corpus name"); options.addOption("src", "src-language", true, "Source Language"); options.addOption("trg", "trg-language", true, "Target Language"); options.addOption("i", "max-iterations", true, "Maximum Iterations"); options.addOption("th", "threshold", true, "This threshold deicdes which sentences updates translation tables. Default is 0.5"); options.addOption("cf", "conv_threshold", true, "This threshold decide if the convergence is reached. 
Default is 0.00001"); CommandLineParser parser = new GnuParser(); try { CommandLine cmd = parser.parse(options, args); if (cmd.hasOption("cmix") && cmd.hasOption("cin") && cmd.hasOption("src") && cmd.hasOption("trg")) { MIX = cmd.getOptionValue("cmix"); IN = cmd.getOptionValue("cin"); SRC = cmd.getOptionValue("src"); TRG = cmd.getOptionValue("trg"); if (cmd.hasOption("i")) { iMAX = Integer.parseInt(cmd.getOptionValue("i")); } if (cmd.hasOption("th")) { CONF_THRESHOLD = (float) Math.log(Double.parseDouble(cmd.getOptionValue("th"))); } if (cmd.hasOption("cf")) { CONV_THRESHOLD = (float) Float.parseFloat(cmd.getOptionValue("cf")); } } else { System.out.println("Missing required argumetns!"); printHelp(options); } } catch (ParseException e) { printHelp(options); } } private static void printHelp(Options options) { HelpFormatter formatter = new HelpFormatter(); formatter.printHelp("java " + InvitationModel.class.getName(), options); System.exit(1); } public static void initialize() throws InterruptedException { log.info("Initializing Translaiton Tables"); for (int i = 0; i < ttable.length; i++) { ttable[i] = new TranslationTable(); } latch = new CountDownLatch(4); initializeTranslationTable(src_indomain, trg_indomain, ttable[0]); initializeTranslationTable(trg_indomain, src_indomain, ttable[1]); initializeTranslationTable(src_mixdomain, trg_mixdomain, ttable[2]); initializeTranslationTable(trg_mixdomain, src_mixdomain, ttable[3]); latch.await(); log.info("DONE"); } public static void initializeTranslationTable(final int src[][], final int trg[][], final TranslationTable ttable) { jobs.execute(new Runnable() { @Override public void run() { HashIntFloatMap totals = HashIntFloatMaps.newMutableMap(); for (int sent = 0; sent < src.length; sent++) { if (sent % 100000 == 0) log.debug("Sentence " + sent); int ssent[] = src[sent]; int tsent[] = trg[sent]; for (int t = 1; t < tsent.length; t++) { int tw = tsent[t]; for (int s = 0; s < ssent.length; s++) { int sw = ssent[s]; 
ttable.increas(tw, sw, 1f); totals.addValue(sw, 1f, 0f); } } } // normalizing and smoothing for (int tw : ttable.ttable.keySet()) { HashIntFloatMap tMap = ttable.ttable.get(tw); for (int sw : tMap.keySet()) { float prob = (float) (Math.log(ttable.get(tw, sw) + n) - Math.log(totals.get(sw) + nV)); ttable.put(tw, sw, prob); } } log.info("."); InvitationModel.latch.countDown(); } }); } public static void createLM() throws InterruptedException { log.info("Creating Language Models ..."); lm = new float[4][]; latch = new CountDownLatch(4); createLM(IN + "." + SRC + ".encoded", lm, 0, src_mixdomain); createLM(IN + "." + TRG + ".encoded", lm, 1, trg_mixdomain); createLM("outdomain." + SRC + ".encoded", lm, 2, src_mixdomain); createLM("outdomain." + TRG + ".encoded", lm, 3, trg_mixdomain); latch.await(); log.info("DONE"); } public static void burnIN() throws IOException, InterruptedException { log.info("BurnIN started ... "); HashIntObjMap<Result> results = null; for (int i = 1; i <= 1; i++) { log.info("Iteration " + i); results = HashIntObjMaps.newMutableMap(); float sPD[][] = new float[2][src_mixdomain.length]; int split = (int) Math.ceil(src_mixdomain.length / 100000d); latch = new CountDownLatch(split); for (int sent = 0; sent < src_mixdomain.length; sent += 100000) { int end = sent + 100000; if (end > src_mixdomain.length) { end = src_mixdomain.length; } calcualteBurnInScore(sent, end, sPD); } latch.await(); float countPD[] = new float[2]; countPD[0] = Float.NEGATIVE_INFINITY; countPD[1] = Float.NEGATIVE_INFINITY; for (int sent = 0; sent < src_mixdomain.length; sent++) { if (ignore.containsKey(sent)) continue; if (Float.isNaN(sPD[0][sent]) || Float.isNaN(sPD[1][sent])) { ignore.put(sent, sent); log.info("Ignoring " + (sent + 1)); continue; } countPD[0] = logAdd(countPD[0], sPD[0][sent]); countPD[1] = logAdd(countPD[1], sPD[1][sent]); results.put(sent, new Result(sent, sPD[0][sent])); } } log.info("BurnIN DONE"); log.info("Writing outdomain corpus ... 
"); ArrayList<Result> sortedResult = new ArrayList<Result>(results.values()); Collections.sort(sortedResult); PrintWriter src_out = new PrintWriter("outdomain." + SRC + ".encoded"); PrintWriter trg_out = new PrintWriter("outdomain." + TRG + ".encoded"); PrintWriter out_score = new PrintWriter("outdomain.scores"); src_outdomain = new int[src_indomain.length][]; trg_outdomain = new int[trg_indomain.length][]; int j = 0; for (Result r : sortedResult) { int sentIndex = r.sentenceNumber - 1; int ssent[] = src_mixdomain[sentIndex]; int tsent[] = trg_mixdomain[sentIndex]; out_score.println(r.sentenceNumber + "\t" + r.score); src_outdomain[j] = ssent; trg_outdomain[j] = tsent; for (int w = 1; w < ssent.length; w++) { src_out.print(ssent[w]); src_out.print(" "); } src_out.println(); for (int w = 1; w < tsent.length; w++) { trg_out.print(tsent[w]); trg_out.print(" "); } trg_out.println(); j++; if (j == src_indomain.length) { break; } } out_score.close(); src_out.close(); trg_out.close(); log.info("DONE"); } public static void training() throws FileNotFoundException, InterruptedException { log.info("Starting Invitation EM ..."); latch = new CountDownLatch(2); ttable[2] = new TranslationTable(); ttable[3] = new TranslationTable(); initializeTranslationTable(src_outdomain, trg_outdomain, ttable[2]); initializeTranslationTable(trg_outdomain, src_outdomain, ttable[3]); latch.await(); for (int i = 1; i <= iMAX; i++) { log.info("Iteration " + i); HashIntObjMap<Result> results = HashIntObjMaps.newMutableMap(); float sPD[][] = new float[2][src_mixdomain.length]; int splits = 10; int split_size = src_mixdomain.length / splits; latch = new CountDownLatch(splits); for (int s = 0; s < splits; s++) { int start = s * split_size; int end = start + split_size; if (s == (splits - 1)) { end = src_mixdomain.length; } calcualteScore(start, end, sPD); } latch.await(); float countPD[] = new float[2]; countPD[0] = Float.NEGATIVE_INFINITY; countPD[1] = Float.NEGATIVE_INFINITY; for (int sent = 0; 
sent < src_mixdomain.length; sent++) { if (ignore.containsKey(sent)) continue; if (Float.isNaN(sPD[0][sent]) || Float.isNaN(sPD[1][sent])) { ignore.put(sent, sent); log.info("Ignoring " + (sent + 1)); continue; } countPD[0] = logAdd(countPD[0], sPD[0][sent]); countPD[1] = logAdd(countPD[1], sPD[1][sent]); float srcP = lm[0][sent]; float trgP = lm[1][sent]; results.put(sent, new Result(sent, sPD[1][sent], srcP + trgP)); } float newPD1 = countPD[1] - logAdd(countPD[0], countPD[1]); float newPD0 = countPD[0] - logAdd(countPD[0], countPD[1]); log.info("PD1 ~ PD0 " + Math.exp(newPD1) + " ~ " + Math.exp(newPD0)); writeResult(i, results); if (i > 1 && Math.abs(Math.exp(newPD1) - Math.exp(PD1)) <= CONV_THRESHOLD) { log.info("Convergence threshold reached."); break; } PD1 = newPD1; PD0 = newPD0; if (i < iMAX) { latch = new CountDownLatch(4); updateTranslationTable(src_mixdomain, trg_mixdomain, ttable[0], sPD[1]); updateTranslationTable(trg_mixdomain, src_mixdomain, ttable[1], sPD[1]); updateTranslationTable(src_mixdomain, trg_mixdomain, ttable[2], sPD[0]); updateTranslationTable(trg_mixdomain, src_mixdomain, ttable[3], sPD[0]); latch.await(); } } } public static void calcualteScore(final int start, final int end, final float sPD[][]) { jobs.execute(new Runnable() { @Override public void run() { for (int sent = start; sent < end; sent++) { if (ignore.containsKey(sent)) continue; int ssent[] = src_mixdomain[sent]; int tsent[] = trg_mixdomain[sent]; float sProb[] = new float[4]; sProb[0] = calculateProb(ssent, tsent, ttable[0]); sProb[1] = calculateProb(tsent, ssent, ttable[1]); sProb[2] = calculateProb(ssent, tsent, ttable[2]); sProb[3] = calculateProb(tsent, ssent, ttable[3]); float in_score = PD1 + logAdd(sProb[0] + lm[1][sent], sProb[1] + lm[0][sent]); float mix_score = PD0 + logAdd(sProb[2] + lm[3][sent], sProb[3] + lm[2][sent]); sPD[1][sent] = in_score - logAdd(in_score, mix_score); sPD[0][sent] = mix_score - logAdd(in_score, mix_score); } 
InvitationModel.latch.countDown(); } }); } public static void calcualteBurnInScore(final int start, final int end, final float sPD[][]) { jobs.execute(new Runnable() { @Override public void run() { for (int sent = start; sent < end; sent++) { if (ignore.containsKey(sent)) continue; int ssent[] = src_mixdomain[sent]; int tsent[] = trg_mixdomain[sent]; float sProb[] = new float[4]; sProb[0] = calculateProb(ssent, tsent, ttable[0]); sProb[1] = calculateProb(tsent, ssent, ttable[1]); sProb[2] = calculateProb(ssent, tsent, ttable[2]); sProb[3] = calculateProb(tsent, ssent, ttable[3]); float in_score = PD1 + logAdd(sProb[0], sProb[1]); float mix_score = PD0 + logAdd(sProb[2], sProb[3]); sPD[1][sent] = in_score - logAdd(in_score, mix_score); sPD[0][sent] = mix_score - logAdd(in_score, mix_score); } InvitationModel.latch.countDown(); } }); } public static void writeResult(final int iterationNumber, final HashIntObjMap<Result> results) { jobs.execute(new Runnable() { @Override public void run() { ArrayList<Result> sortedResult = new ArrayList<Result>(results.values()); Collections.sort(sortedResult); try { PrintWriter output = new PrintWriter("output_" + iterationNumber + ".txt"); for (Result r : sortedResult) { output.println(r.sentenceNumber + "\t" + Math.exp(r.score) + "\t" + Math.exp(r.lm_score)); } output.close(); } catch (FileNotFoundException e) { e.printStackTrace(); } } }); } public static float calculateProb(final int ssent[], final int tsent[], final TranslationTable ttable) { float prob = 0; for (int t = 1; t < tsent.length; t++) { int tw = tsent[t]; float sum = Float.NEGATIVE_INFINITY; for (int s = 0; s < ssent.length; s++) { int sw = ssent[s]; sum = logAdd(sum, ttable.get(tw, sw, p)); } prob += sum; } return prob - (float) Math.log(Math.pow(ssent.length, tsent.length - 1)); } public static void updateTranslationTable(final int src[][], final int trg[][], final TranslationTable ttable, final float sPD[]) { jobs.execute(new Runnable() { @Override public void 
run() { log.info("Updating translation table ... "); TranslationTable counts = new TranslationTable(); HashIntFloatMap totals = HashIntFloatMaps.newMutableMap(); for (int sent = 0; sent < src.length; sent++) { if (sent % 100000 == 0) log.debug("Sentence " + sent); if (ignore.containsKey(sent)) continue; if (sPD[sent] < CONF_THRESHOLD) continue; int ssent[] = src[sent]; int tsent[] = trg[sent]; HashIntFloatMap s_total = HashIntFloatMaps.newMutableMap(); // calculating normalization for (int t = 1; t < tsent.length; t++) { int tw = tsent[t]; for (int s = 0; s < ssent.length; s++) { int sw = ssent[s]; s_total.put(tw, logAdd(s_total.getOrDefault(tw, Float.NEGATIVE_INFINITY), ttable.get(tw, sw, p))); } } // collect counts for (int t = 1; t < tsent.length; t++) { int tw = tsent[t]; for (int s = 0; s < ssent.length; s++) { int sw = ssent[s]; float in_count = sPD[sent] + (ttable.get(tw, sw, p) - s_total.get(tw)); counts.put(tw, sw, logAdd(counts.get(tw, sw, Float.NEGATIVE_INFINITY), in_count)); totals.put(sw, logAdd(totals.getOrDefault(sw, Float.NEGATIVE_INFINITY), in_count)); } } } // maximization for (int tw : counts.ttable.keySet()) { HashIntFloatMap tMap = counts.ttable.get(tw); for (int sw : tMap.keySet()) { float newProb = counts.get(tw, sw) - totals.get(sw); ttable.put(tw, sw, newProb); } } log.info("Updating translation table DONE"); InvitationModel.latch.countDown(); } }); } public static void readFiles() throws IOException, InterruptedException { log.info("Reading files"); src_codes = HashObjIntMaps.newMutableMap(); trg_codes = HashObjIntMaps.newMutableMap(); src_codes.put(null, 0); trg_codes.put(null, 0); LineNumberReader lr = new LineNumberReader(new FileReader(IN + "." + SRC)); lr.skip(Long.MAX_VALUE); int indomain_size = lr.getLineNumber(); lr.close(); lr = new LineNumberReader(new FileReader(MIX + "." 
+ SRC)); lr.skip(Long.MAX_VALUE); int mixdomain_size = lr.getLineNumber(); lr.close(); src_indomain = new int[indomain_size][]; trg_indomain = new int[indomain_size][]; src_mixdomain = new int[mixdomain_size][]; trg_mixdomain = new int[mixdomain_size][]; latch = new CountDownLatch(2); readFile(IN + "." + SRC, src_codes, src_indomain); readFile(IN + "." + TRG, trg_codes, trg_indomain); latch.await(); latch = new CountDownLatch(2); readFile(MIX + "." + SRC, src_codes, src_mixdomain); readFile(MIX + "." + TRG, trg_codes, trg_mixdomain); latch.await(); } public static void readFile(final String fileName, final HashObjIntMap<String> codes, final int lines[][]) throws IOException { jobs.execute(new Runnable() { @Override public void run() { try { BufferedReader reader = new BufferedReader( new InputStreamReader(new FileInputStream(fileName), Charset.forName("UTF8"))); String line = null; int i = 0; while ((line = reader.readLine()) != null) { String words[] = line.split("\\s+"); lines[i] = new int[words.length + 1]; lines[i][0] = 0; int j = 1; for (String word : words) { int code = 0; if (!codes.containsKey(word)) { code = codes.size() + 1; codes.put(word, code); } else { code = codes.getInt(word); } lines[i][j++] = code; } i++; } reader.close(); } catch (IOException e) { e.printStackTrace(); System.exit(1); } writeEncodedFile(fileName + ".encoded", lines); log.info(fileName + " ... 
DONE"); InvitationModel.latch.countDown(); } }); } public static void writeEncodedFile(final String fileName, final int lines[][]) { jobs.execute(new Runnable() { @Override public void run() { try { BufferedWriter encodedWriter = new BufferedWriter( new OutputStreamWriter(new FileOutputStream(fileName), Charset.forName("UTF8"))); for (int i = 0; i < lines.length; i++) { for (int j = 1; j < lines[i].length; j++) { int word = lines[i][j]; encodedWriter.write("" + word); encodedWriter.write(" "); } encodedWriter.write("\n"); } encodedWriter.close(); } catch (IOException e) { e.printStackTrace(); } } }); } public static float getLMProb(NgramLanguageModel<String> lm, int sent[]) { List<String> words = new ArrayList<String>(); for (int i = 1; i < sent.length; i++) { words.add("" + sent[i]); } return lm.getLogProb(words); } public static void createLM(final String fileName, final float lm[][], final int index, final int corpus[][]) { jobs.execute(new Runnable() { @Override public void run() { log.info("Creating language model"); NgramLanguageModel<String> createdLM = null; final int lmOrder = 4; final List<String> inputFiles = new ArrayList<String>(); inputFiles.add(fileName); final StringWordIndexer wordIndexer = new StringWordIndexer(); wordIndexer.setStartSymbol(ArpaLmReader.START_SYMBOL); wordIndexer.setEndSymbol(ArpaLmReader.END_SYMBOL); wordIndexer.setUnkSymbol(ArpaLmReader.UNK_SYMBOL); createdLM = LmReaders.readContextEncodedKneserNeyLmFromTextFile(inputFiles, wordIndexer, lmOrder, new ConfigOptions(), new File(fileName + ".lm")); lm[index] = new float[corpus.length]; for (int i = 0; i < corpus.length; i++) { int sent[] = corpus[i]; lm[index][i] = getLMProb(createdLM, sent); } log.info("."); InvitationModel.latch.countDown(); } }); } public static float logAdd(float a, float b) { float max, negDiff; if (a > b) { max = a; negDiff = b - a; } else { max = b; negDiff = a - b; } if (max == Float.NEGATIVE_INFINITY) { return max; } else if (negDiff < -20.0f) { return max; 
} else { return max + (float) Math.log(1.0 + Math.exp(negDiff)); } } } class Result implements Comparable<Result> { int sentenceNumber; float score = 1; float lm_score = 1; public Result(int sentenceNumber, float score) { this.sentenceNumber = sentenceNumber + 1; this.score = score; } public Result(int sentenceNumber, float score, float lm_score) { this.sentenceNumber = sentenceNumber + 1; this.score = score; this.lm_score = lm_score; } @Override public int compareTo(Result result) { int cmp = Float.compare(result.score, this.score); if (cmp == 0) { cmp = Float.compare(result.lm_score, this.lm_score); } return cmp; } }