Java tutorial
/* * Copyright 2015 Carnegie Mellon University * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ package edu.cmu.lti.oaqa.knn4qa.apps; import java.io.*; import java.util.*; import no.uib.cipr.matrix.DenseVector; import no.uib.cipr.matrix.sparse.SparseVector; import org.apache.commons.cli.CommandLine; import org.apache.commons.cli.CommandLineParser; import org.apache.commons.cli.HelpFormatter; import org.apache.commons.cli.Options; import org.apache.commons.cli.ParseException; import org.xml.sax.SAXException; import com.google.common.base.Splitter; import edu.cmu.lti.oaqa.annographix.solr.*; import edu.cmu.lti.oaqa.annographix.util.*; import edu.cmu.lti.oaqa.knn4qa.cand_providers.CandidateEntry; import edu.cmu.lti.oaqa.knn4qa.cand_providers.CandidateProvider; import edu.cmu.lti.oaqa.knn4qa.cand_providers.LuceneCandidateProvider; import edu.cmu.lti.oaqa.knn4qa.cand_providers.NmslibQueryGenerator; import edu.cmu.lti.oaqa.knn4qa.letor.*; import org.apache.thrift.TException; /* begin: Imports related to a KNN-service */ import org.apache.thrift.transport.*; import edu.cmu.lti.oaqa.similarity.*; import org.apache.thrift.protocol.*; /* end: Imports related to a KNN-service */ class DebugKNNServicImpl extends BaseQueryApp { public static final String LUCENE_INDEX_LOCATION_DESC = "Location of a Lucene index"; private static final int MAX_DIGITS_TO_COMPARE = 4; int mQtyPotentMismatch = 0; int mQtyComp = 0; String mKnnServiceURL; TTransport mKnnServiceTransp; QueryService.Client mKnnServiceClient; NmslibQueryGenerator mQueryGen; @Override void addOptions() { boolean onlyLucene = true; boolean multNumRetr = false; boolean useQRELs = false; boolean useThreadQty = false; addCandGenOpts(onlyLucene, multNumRetr, useQRELs, useThreadQty); boolean useHigHorderModels = false; addResourceOpts(useHigHorderModels); boolean useIntermModel = true, useFinalModel = false; addLetorOpts(useIntermModel, useFinalModel); mOptions.addOption(CommonParams.KNN_SERVICE_PARAM, null, true, CommonParams.KNN_SERVICE_DESC); } @Override void procCustomOptions() { mKnnServiceURL = mCmd.getOptionValue(CommonParams.KNN_SERVICE_PARAM); if (null == mKnnServiceURL) showUsageSpecify(CommonParams.KNN_SERVICE_DESC); } @Override void init() throws Exception { if (null == mInMemExtrInterm) showUsageSpecify(CommonParams.EXTRACTOR_TYPE_INTERM_DESC); if (null == mNmslibFields) showUsageSpecify(CommonParams.NMSLIB_FIELDS_PARAM); mQueryGen = new NmslibQueryGenerator(mNmslibFields, mMemIndexPref, mInMemExtrInterm); Splitter splitOnColon = Splitter.on(':'); String host = null; int port = -1; int part = 0; for (String s : splitOnColon.split(mKnnServiceURL)) { if (0 == part) { host = s; } else if (1 == part) { try { port = Integer.parseInt(s); } catch (NumberFormatException e) { showUsage("Invalid port in the service address in '" + CommonParams.KNN_SERVICE_PARAM + "'"); } } else { showUsage("Extra colon in the service address in '" + CommonParams.KNN_SERVICE_PARAM + "'"); } ++part; } if (part != 2) { showUsage("Invalid format of the service address in '" + CommonParams.KNN_SERVICE_PARAM + "'"); } mKnnServiceTransp = new TSocket(host, port); mKnnServiceTransp.open(); mKnnServiceClient = new QueryService.Client(new TBinaryProtocol(mKnnServiceTransp)); } @Override void fin() throws Exception { logger.info( String.format("# of comparisons %d, # of potential mismatches (see output before for details) %d", mQtyComp, mQtyPotentMismatch)); } static boolean compareApprox(double a, double b, int digits) { double maxMod = Math.max(Math.abs(a), Math.abs(b)); double scale = Math.pow(10, digits); double lead = Math.pow(10, Math.round(Math.log10(maxMod))); double minSign = Float.MIN_NORMAL * scale; // These guys are just too small for us to bother about their differences if (maxMod < minSign) return true; double delta = lead / scale; double diff = Math.abs(a - b); return diff <= delta; } @Override void procResults(String queryID, Map<String, String> docFields, CandidateEntry[] scoredDocs, int numRet, Map<String, DenseVector> docFeats) throws Exception { String queryObjStr = null; queryObjStr = mQueryGen.getStrObjForKNNService(docFields); logger.info("KNN Query string:"); logger.info(queryObjStr); logger.info("=========================="); for (CandidateEntry r : scoredDocs) { DenseVector feat = docFeats.get(r.mDocId); r.mScore = (float) feat.dot(mModelInterm); String docObjStr = mQueryGen.getStrObjForKNNService(r.mDocId); // Left queries double knnScore = -mKnnServiceClient.getDistance(docObjStr, queryObjStr); logger.info(String.format("docId=%s score=%f knn-Service dist=%f", r.mDocId, r.mScore, knnScore)); DenseVector v = docFeats.get(r.mDocId); for (int i = 0; i < v.size(); ++i) { if (i > 0) System.out.print(" "); System.out.print((i + 1) + ":" + v.get(i)); } System.out.println(); String s1 = cutExtraDigits(r.mScore, MAX_DIGITS_TO_COMPARE); String s2 = cutExtraDigits(knnScore, MAX_DIGITS_TO_COMPARE); mQtyComp++; if (!compareApprox(r.mScore, knnScore, MAX_DIGITS_TO_COMPARE)) { mQtyPotentMismatch++; logger.info("Potential score mismatch!: " + s1 + " (Java) vs " + s2 + " (NMSLIB)"); } } } /* * This function is not for display (because for integer numbers > 0, you will lose digits. * it is only for the purpose of approximate number comparison. */ public static String cutExtraDigits(double val, int digitsToKeep) { String valStr = (val + "").toLowerCase().trim(); String sign = ""; if (valStr.startsWith("-")) { sign = "-"; valStr = valStr.substring(1); } int pos = valStr.indexOf('e'); String exp = ""; String mant = valStr; if (pos >= 0) { exp = valStr.substring(pos); mant = valStr.substring(0, pos); } if (mant.length() > digitsToKeep + 1) mant = mant.substring(0, digitsToKeep + 1); return sign + mant + exp; } } public class DebugKNNService { public static void main(String[] args) { try { (new DebugKNNServicImpl()).run("Debug KNN-service application", args); } catch (Exception e) { e.printStackTrace(); System.err.println("Terminating due to an exception: " + e); System.exit(1); } } }