de.tudarmstadt.ukp.dkpro.core.api.datasets.Dataset.java Source code

Java tutorial

Introduction

Here is the source code for de.tudarmstadt.ukp.dkpro.core.api.datasets.Dataset.java

Source

/*
 * Copyright 2016
 * Ubiquitous Knowledge Processing (UKP) Lab
 * Technische Universität Darmstadt
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package de.tudarmstadt.ukp.dkpro.core.api.datasets;

import java.io.File;
import java.util.Arrays;

import org.apache.commons.lang.ArrayUtils;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;

import de.tudarmstadt.ukp.dkpro.core.api.datasets.internal.SplitImpl;

/**
 * A set of data files with accompanying metadata that can be partitioned into a training and a
 * test portion.
 * <p>
 * Implementations provide the metadata accessors and a default split; the ratio-based splits are
 * implemented here on top of {@link #getDataFiles()}.
 */
public interface Dataset {
    /**
     * @return the name of the dataset (exact semantics are defined by the implementation).
     */
    String getName();

    /**
     * @return the language of the dataset (format is defined by the implementation).
     */
    String getLanguage();

    /**
     * @return the encoding of the dataset (format is defined by the implementation).
     */
    String getEncoding();

    /**
     * @return the data files of the dataset. These are the files that
     *         {@link #getSplit(double, double)} distributes over the training and test sets.
     */
    File[] getDataFiles();

    /**
     * @return the license files of the dataset.
     */
    File[] getLicenseFiles();

    /**
     * @return the split that the implementation provides as its default.
     */
    Split getDefaultSplit();

    /**
     * Splits the data files into a training and a test set, using the remainder of the training
     * ratio as the test ratio.
     *
     * @param aTrainRatio
     *            fraction of the data files to assign to the training set.
     * @return the split produced by {@link #getSplit(double, double)} when called with
     *         {@code aTrainRatio} and {@code 1.0 - aTrainRatio}.
     */
    default Split getSplit(double aTrainRatio) {
        return getSplit(aTrainRatio, 1.0 - aTrainRatio);
    }

    /**
     * Splits the data files into a training and a test set.
     * <p>
     * The files are ordered by file name. The first {@code round(count * aTrainRatio)} files form
     * the training set and the following {@code round(count * aTestRatio)} files form the test
     * set. Index ranges that fall outside the file list are truncated to it, so rounding both
     * portions up never yields more files than exist.
     *
     * @param aTrainRatio
     *            fraction of the data files to assign to the training set.
     * @param aTestRatio
     *            fraction of the data files to assign to the test set.
     * @return a split holding the training and the test files; no third file set is supplied
     *         ({@code null} is passed to {@code SplitImpl} - presumably the development set,
     *         confirm against {@code SplitImpl}).
     */
    default Split getSplit(double aTrainRatio, double aTestRatio) {
        Log log = LogFactory.getLog(getClass());

        // Sort a copy: the array handed out by getDataFiles() may be owned by the implementation,
        // and computing a split must not reorder it as a side effect.
        File[] all = getDataFiles().clone();
        Arrays.sort(all, (a, b) -> a.getName().compareTo(b.getName()));
        log.info("Found " + all.length + " files");

        int trainPivot = (int) Math.round(all.length * aTrainRatio);
        int testPivot = (int) Math.round(all.length * aTestRatio) + trainPivot;

        // ArrayUtils.subarray truncates out-of-range indices to the array bounds, so the pivots
        // may overshoot (e.g. 3 files at 0.5/0.5 give pivots 2 and 4) without causing an error.
        File[] train = (File[]) ArrayUtils.subarray(all, 0, trainPivot);
        File[] test = (File[]) ArrayUtils.subarray(all, trainPivot, testPivot);

        log.debug("Assigned " + train.length + " files to training set");
        log.debug("Assigned " + test.length + " files to test set");

        // Count the leftover files from what was actually assigned rather than from the raw
        // pivot, which would report a negative number whenever the pivot overshoots.
        int unassigned = all.length - train.length - test.length;
        if (unassigned > 0) {
            log.info("Files missing from split: [" + unassigned + "]");
        }

        return new SplitImpl(train, test, null);
    }
}