org.deeplearning4j.examples.tictactoe.TicTacToeData.java Source code

Java tutorial

Introduction

Here is the source code for org.deeplearning4j.examples.tictactoe.TicTacToeData.java

Source

/*******************************************************************************
 * Copyright (c) 2015-2019 Skymind, Inc.
 *
 * This program and the accompanying materials are made available under the
 * terms of the Apache License, Version 2.0 which is available at
 * https://www.apache.org/licenses/LICENSE-2.0.
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 *
 * SPDX-License-Identifier: Apache-2.0
 ******************************************************************************/

package org.deeplearning4j.examples.tictactoe;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

import java.io.BufferedWriter;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/**
 * This program generates the basic data used by the training program.
 * It performs the following major steps:
 * - generates all possible game states;
 * - rewards all game states generated in the step above: it finds each winning state, assigns it the value 1, then goes back
 * through all the steps up to the first one and calculates, for each step of the game, the probability that making that move wins the game in the last state;
 * - writes the data of all states along with the win probability of each state that was calculated in the step above.
 * Note:
 * - The page <b>http://www.se16.info/hgb/tictactoe.htm</b> was used to understand the number of possible moves in the Tic-Tac-Toe game.
 * - Refer to ReadMe.txt for a detailed explanation of each step.
 * <p>
 * <b>Developed by KIT Solutions Pvt. Ltd. (www.kitsol.com), 19-Jan-2017.</b>
 */

public class TicTacToeData {

    private static final Log log = LogFactory.getLog(TicTacToeData.class);

    /** Number of cells on the board; boards are stored as 1 x 9 row vectors. */
    private static final int BOARD_SIZE = 9;

    /** Cell indices (0-based, row-major) of the eight lines that win the game: rows, columns, diagonals. */
    private static final int[][] WIN_LINES = {
        { 0, 1, 2 }, { 3, 4, 5 }, { 6, 7, 8 },
        { 0, 3, 6 }, { 1, 4, 7 }, { 2, 5, 8 },
        { 0, 4, 8 }, { 2, 4, 6 }
    };

    /**
     * Generates every game state together with its reward and writes the result to
     * {@code AllMoveWithReward.txt} in the user's home directory.
     * <p>
     * Any failure during generation or writing is logged (with its stack trace) and does not
     * propagate; the elapsed time in milliseconds is logged in every case.
     *
     * @param args ignored
     * @throws Exception declared for backward compatibility; failures are caught and logged
     */
    public static void main(String[] args) throws Exception {
        long start = System.nanoTime();
        try {
            TicTacToeData data = new TicTacToeData();
            log.info("Data Processing Started");
            final String allMoves = data.generatePossibleGames();
            log.info("All possible game state sequence generated, Finished");

            final Path dataFile = Paths.get(System.getProperty("user.home") + "/AllMoveWithReward.txt");
            // Remove any stale output first; newBufferedWriter then creates the file itself.
            Files.deleteIfExists(dataFile);
            try (BufferedWriter writer = Files.newBufferedWriter(dataFile)) {
                writer.write(allMoves);
            }
        } catch (Exception e) {
            // Pass the exception as the Throwable argument so that the stack trace is logged too.
            log.error("Failed to generate or write the Tic-Tac-Toe training data", e);
        }
        log.info("Total time = " + (System.nanoTime() - start) / 1_000_000);
    }

    /**
     * Drives the generation of all game states, move number by move number (1 to 9).
     * Refer to ReadMe.txt for a detailed explanation.
     *
     * @return the distinct generated records joined with {@code "\r\n"}
     */
    private String generatePossibleGames() {
        List<String> values = new ArrayList<>();
        List<INDArray> moveSequenceList = new ArrayList<>();
        for (int moveNumber = 1; moveNumber <= BOARD_SIZE; moveNumber++) {
            generateStateBasedOnMoveNumber(moveNumber, moveSequenceList, values);
        }
        return values.stream().distinct().collect(Collectors.joining("\r\n"));
    }

    /**
     * Extends every unfinished game by one move. Called by {@link #generatePossibleGames()}.
     * <p>
     * A board is a 1 x 9 array in which each cell holds the number of the move that filled it
     * (0 = still empty). From move 5 onwards every new board is checked for a win; finished games
     * are converted to reward records and appended to {@code values}, unfinished ones are kept in
     * {@code moveSequenceList} for the next move number. After move 9 the boards that are still
     * unfinished are converted to records as well.
     *
     * @param moveNumber       number of the move to play (1 to 9)
     * @param moveSequenceList in: unfinished boards after the previous move; out: unfinished boards
     *                         after this move (the list is cleared and refilled in place)
     * @param values           accumulator for the generated records
     */
    private void generateStateBasedOnMoveNumber(int moveNumber, List<INDArray> moveSequenceList,
            List<String> values) {

        List<INDArray> previousBoards = new ArrayList<>(moveSequenceList);
        moveSequenceList.clear();

        if (moveNumber == 1) {
            // First move: one board per possible opening cell.
            for (int cell = 0; cell < BOARD_SIZE; cell++) {
                INDArray board = Nd4j.zeros(1, BOARD_SIZE);
                board.putScalar(new int[] { 0, cell }, 1);
                moveSequenceList.add(board);
            }
        } else {
            final boolean isOddMoveNumber = (moveNumber % 2) != 0;
            for (INDArray previousBoard : previousBoards) {
                IntStream.range(0, BOARD_SIZE)
                        .filter(cell -> previousBoard.getInt(cell) == 0)
                        .forEach(cell -> {
                            INDArray board = Nd4j.zeros(1, BOARD_SIZE);
                            Nd4j.copy(previousBoard, board);
                            board.putScalar(new int[] { 0, cell }, moveNumber);
                            // Nobody can own three cells before the fifth move.
                            if (moveNumber > 4 && checkWin(board, isOddMoveNumber)) {
                                values.addAll(generateGameStatesAndRewardToIt(board, isOddMoveNumber ? 0 : 1));
                            } else {
                                moveSequenceList.add(board);
                            }
                        });
            }
        }
        if (moveNumber == BOARD_SIZE) {
            // Full boards without a winner.
            // NOTE(review): these are passed with moveType 0, i.e. rewarded the same way as a win
            // of the odd (first) player - confirm that this is intended for draws.
            values.addAll(
                    moveSequenceList.stream().flatMap(board -> generateGameStatesAndRewardToIt(board, 0).stream())
                            .collect(Collectors.toList()));
        }
    }

    /**
     * Tells whether the given player owns a complete row, column or diagonal.
     * <p>
     * A cell belongs to the odd player when it holds an odd move number and to the even player when
     * it holds a non-zero even move number; empty cells (0) belong to nobody. All nine cells are
     * tested with the same rule, so wins of the even player that pass through the first cell are
     * recognised as well.
     *
     * @param sequence board whose cells hold move numbers (0 = empty)
     * @param isOdd    {@code true} to test the odd (first) player, {@code false} for the even player
     * @return {@code true} if that player has three cells in a line
     */
    private boolean checkWin(INDArray sequence, boolean isOdd) {
        boolean[] ownedByPlayer = new boolean[BOARD_SIZE];
        for (int cell = 0; cell < BOARD_SIZE; cell++) {
            double moveInCell = sequence.getDouble(cell);
            boolean moveIsOdd = moveInCell % 2.0 != 0;
            ownedByPlayer[cell] = moveInCell != 0 && moveIsOdd == isOdd;
        }
        for (int[] line : WIN_LINES) {
            if (ownedByPlayer[line[0]] && ownedByPlayer[line[1]] && ownedByPlayer[line[2]]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Replays a finished game move by move and produces one record per intermediate state.
     * <p>
     * The replay board marks the first player's cells with 1 and the second player's with -1.
     * Before each move {@code i} the current board is snapshotted: when {@code i % 2 == moveType}
     * the snapshot joins the sequence that is rewarded by {@link #calculateReward(List, Map)},
     * otherwise it is stored with the neutral value 0.5. The final board is always appended to the
     * rewarded sequence. Refer to ReadMe.txt for a detailed explanation.
     *
     * @param output   finished board whose cells hold move numbers (0 = empty)
     * @param moveType remainder (0 or 1) selecting which pre-move snapshots are rewarded; callers
     *                 pass 0 when the odd player made the last move and 1 when the even player won
     * @return the distinct records produced by {@link #generateStringList(INDArray, double)}
     */
    private List<String> generateGameStatesAndRewardToIt(INDArray output, int moveType) {
        Map<INDArray, Double> valueMap = new HashMap<>();
        // The highest move number on the board is the number of moves played.
        double maxNumber = Nd4j.max(output).getDouble(0);

        List<INDArray> sequenceList = new ArrayList<>();
        INDArray sequenceArray = Nd4j.zeros(1, BOARD_SIZE);

        int move = 1;
        for (int i = 1; i <= maxNumber; i++) {
            // Snapshot of the board as it was before move i.
            INDArray snapshot = Nd4j.zeros(1, BOARD_SIZE);
            Nd4j.copy(sequenceArray, snapshot);
            if (i % 2 == moveType) {
                sequenceList.add(snapshot);
            } else {
                valueMap.put(snapshot, 0.5);
            }
            sequenceArray.putScalar(new int[] { 0, getPosition(output, i) }, move);
            move = -move;
        }

        // The finished board closes the rewarded sequence.
        INDArray finalState = Nd4j.zeros(1, BOARD_SIZE);
        Nd4j.copy(sequenceArray, finalState);
        sequenceList.add(finalState);

        calculateReward(sequenceList, valueMap);
        // The map holds only a handful of entries, so a sequential stream is sufficient.
        return valueMap.entrySet().stream()
                .map(entry -> generateStringList(entry.getKey(), entry.getValue())).distinct()
                .collect(Collectors.toList());
    }

    /**
     * Finds the cell that holds the given move number.
     *
     * @param array  board whose cells hold move numbers
     * @param number move number to look for
     * @return index of the first cell equal to {@code number}, or 0 when no cell matches
     */
    private int getPosition(INDArray array, double number) {
        for (int i = 0; i < array.length(); i++) {
            if (array.getDouble(i) == number) {
                return i;
            }
        }
        return 0;
    }

    /**
     * Assigns a value to every state of the sequence, walking backwards from the final state.
     * <p>
     * The last state gets 1.0; each earlier state gets {@code 0.5 + 0.1 * (next - 0.5)}, where
     * {@code next} is the value of the state that follows it. When a state is already present in
     * the map, the larger of the two values is kept. Refer to ReadMe.txt for a detailed explanation.
     *
     * @param arrayList states of one game in playing order, the final state last
     * @param valueMap  state-to-value map that is updated in place
     */
    private void calculateReward(List<INDArray> arrayList, Map<INDArray, Double> valueMap) {
        final int last = arrayList.size() - 1;
        double probabilityValue = 1.0;
        for (int p = last; p >= 0; p--) {
            if (p < last) {
                probabilityValue = 0.5 + 0.1 * (probabilityValue - 0.5);
            }
            valueMap.merge(arrayList.get(p), probabilityValue, Double::max);
        }
    }

    /**
     * Formats one state and its value as two text lines separated by {@code "\r\n"}.
     * <p>
     * The board text is taken from {@code INDArray.toString()} with brackets and whitespace removed
     * and commas turned into colons. The first line rewrites {@code -1} as {@code 2}; the second
     * line swaps the two players' marks ({@code 1} becomes {@code 2}, {@code -1} becomes {@code 1}).
     * Both lines end with a space and the value.
     *
     * @param arrayFromInputList board using 1 / -1 as player marks
     * @param rewardValue        value appended to both lines
     * @return the two formatted lines
     */
    private String generateStringList(INDArray arrayFromInputList, double rewardValue) {
        String tempString = arrayFromInputList.toString().replace('[', ' ').replace(']', ' ').replace(',', ':')
                .replaceAll("\\s", "");
        StringBuilder stringBuilder = new StringBuilder();
        stringBuilder.append(tempString.replace("-1", "2"));
        stringBuilder.append(" ");
        stringBuilder.append(rewardValue);
        stringBuilder.append("\r\n");
        // "-1" first turns into "-2" and is then mapped to "1", which swaps the two players.
        stringBuilder.append(tempString.replace("1", "2").replace("-2", "1"));
        stringBuilder.append(" ");
        stringBuilder.append(rewardValue);
        return stringBuilder.toString();
    }
}