Implementing a BP Neural Network in Java

  • Algorithm explanation
    • Forward propagation
    • Backpropagation
  • Code Design
    • Abstract class GeneralAnn
    • Concrete implementation class SimpleAnn

Algorithm explanation

BP (Back-propagation) neural networks are the most traditional kind of neural network. The BP process is divided into two stages. The first stage is the forward propagation of the signal: from the input layer, through the hidden layer, and finally to the output layer. The second stage is the backpropagation of the error: from the output layer, back through the hidden layer, and finally to the input layer, adjusting in turn the weights and biases from the hidden layer to the output layer, and then the weights and biases from the input layer to the hidden layer.

Forward propagation


Every value $h_j$ in the hidden layer is obtained by a combination of a linear operation and a nonlinear operation on the data of the input layer.

Forward propagation: the neural network passes the input data layer by layer through a series of weighted sums and activation functions, and finally generates a prediction result. The specific steps are as follows:

  • The input data is passed to the first layer (input layer), and each input is connected to the corresponding neuron.
  • For each layer, compute a weighted sum for that layer equal to the sum of the product of the previous layer’s output times the weights, plus a bias term.
  • Apply an activation function, such as Sigmoid or ReLU, to the weighted sum to obtain the output of this layer. This example uses Sigmoid (written out after this list).
  • The output of this layer is then used as the input of the next layer, and the weighted-sum and activation computation is repeated until the output layer is reached and the final prediction result is produced.
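
As a compact way to write one node of this computation (a sketch consistent with the SimpleAnn code below, where the bias is stored as one extra edge weight per node), the value of hidden node $j$ is

$h_j = \sigma\Big(b_j + \sum_i w_{ij} x_i\Big), \qquad \sigma(z) = \dfrac{1}{1 + e^{-z}}$

where $x_i$ are the outputs of the previous layer, $w_{ij}$ are the edge weights, $b_j$ is the bias, and $\sigma$ is the Sigmoid activation.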

Backpropagation

Backward propagation: the error is passed from the output layer back towards the input layer by computing the gradient of the loss function, in order to adjust the weights and bias terms in the network. The specific steps are as follows:

  • Calculate the prediction error at the output layer from the difference between the predicted result and the true label; this example uses the mean squared error. The error of each earlier layer is then obtained by propagating the output-layer error backwards with the chain rule (the resulting formulas are written out after this list).
  • Update the weights and bias terms in the network to minimize the loss function. This can be achieved with optimization algorithms such as gradient descent, where each weight and bias term is updated in the direction opposite to its corresponding gradient.

    $p \leftarrow p - \eta \dfrac{\partial L}{\partial p}$

    where $p$ is the variable to be updated and $\eta$ is the learning rate.
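
Written out to match the backPropagation() code below (a sketch assuming Sigmoid activation and mean squared error, with output value $o_j$, target $t_j$, learning rate $\eta$ = learningRate, and momentum coefficient mobp), the error terms and weight updates are:

$\delta_j^{(out)} = o_j (1 - o_j)(t_j - o_j)$ for an output node,

$\delta_j^{(l)} = h_j^{(l)} \big(1 - h_j^{(l)}\big) \sum_i w_{ji} \, \delta_i^{(l+1)}$ for a hidden node, and

$\Delta w_{ji} \leftarrow mobp \cdot \Delta w_{ji} + \eta \, \delta_i^{(l+1)} h_j^{(l)}, \qquad w_{ji} \leftarrow w_{ji} + \Delta w_{ji}$ for each edge weight (the bias is updated the same way with $h_j^{(l)}$ replaced by 1).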

Code Design

Abstract class GeneralAnn

This class implements the parts of the Ann workflow that are common and fixed: the constructor that reads the data file and assigns the attributes, the train() method for model training, the test() method for evaluation, and the argmax() method that returns the index of the largest output value (used to pick the predicted class). The core forward() and backPropagation() methods are abstract and left unimplemented.

package bp;

import weka.core.Instances;

import java.io.FileReader;
import java.util.Arrays;
import java.util.Random;
//Data reading and basic structure
public abstract class GeneralAnn {


    /**
     * The whole dataset.
     */
    Instances dataset;

    /**
     * Number of layers. It is counted according to nodes instead of edges.
     */
    int numLayers;

    /**
     * The number of nodes for each layer, e.g., [3, 4, 6, 2] means that there
     * are 3 input nodes (conditional attributes), 2 hidden layers with 4 and 6
     * nodes, respectively, and 2 class values (binary classification).
     */
    int[] layerNumNodes;

    /**
     * Momentum coefficient.
     */
    public double mobp;

    /**
     * Learning rate.
     */
    public double learningRate;

    /**
     * For random number generation.
     */
    Random random = new Random();

    /**
     *********************
     * The first constructor.
     * @param paraFilename
     * The arff filename.
     * @param paraLayerNumNodes
     * The number of nodes for each layer (may be different).
     * @param paraLearningRate
     * Learning rate.
     * @param paraMobp
     * Momentum coefficient.
     *********************
     */
    public GeneralAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
                      double paraMobp) {
        // Step 1. Read data.
        try {
            FileReader tempReader = new FileReader(paraFilename);
            dataset = new Instances(tempReader);
            // The last attribute is the decision class.
            dataset.setClassIndex(dataset.numAttributes() - 1);
            tempReader.close();
        } catch (Exception ee) {
            System.out.println("Error occurred while trying to read \'" + paraFilename
                    + "\' in GeneralAnn constructor.\r\n" + ee);
            System.exit(0);
        } // Of try

        // Step 2. Accept parameters.
        layerNumNodes = paraLayerNumNodes;
        numLayers = layerNumNodes.length;
        // Adjust if necessary.
        layerNumNodes[0] = dataset.numAttributes() - 1;
        layerNumNodes[numLayers - 1] = dataset.numClasses();
        learningRate = paraLearningRate;
        mobp = paraMobp;
    }// Of the first constructor

    /**
     *********************
     * Forward prediction.
     *
     * @param paraInput
     * The input data of one instance.
     * @return The data at the output end.
     *********************
     */
    public abstract double[] forward(double[] paraInput);

    /**
     *********************
     * Back propagation.
     *
     * @param paraTarget
     * For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
     *
     *********************
     */
    public abstract void backPropagation(double[] paraTarget);

    /**
     *********************
     * Train using the dataset.
     *********************
     */
    public void train() {
        double[] tempInput = new double[dataset.numAttributes() - 1];
        double[] tempTarget = new double[dataset.numClasses()];
        for (int i = 0; i < dataset.numInstances(); i++) {
            // Fill the data.
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            } // Of for j

            // Fill the class label.
            Arrays.fill(tempTarget, 0);
            tempTarget[(int) dataset.instance(i).classValue()] = 1;

            // Train with this instance.
            forward(tempInput);
            backPropagation(tempTarget);
        } // Of for i
    }// Of train

    /**
     *********************
     * Get the index corresponding to the max value of the array.
     *
     * @return the index.
     *********************
     */
    public static int argmax(double[] paraArray) {
        int resultIndex = -1;
        double tempMax = -1e10;
        for (int i = 0; i < paraArray.length; i++) {
            if (tempMax < paraArray[i]) {
                tempMax = paraArray[i];
                resultIndex = i;
            } // Of if
        } // Of for i

        return resultIndex;
    }// Of argmax

    /**
     *********************
     * Test using the dataset.
     *
     * @return The precision.
     *********************
     */
    public double test() {
        double[] tempInput = new double[dataset.numAttributes() - 1];
        double tempNumCorrect = 0;
        double[] tempPrediction;
        int tempPredictedClass = -1;

        for (int i = 0; i < dataset.numInstances(); i++) {
            // Fill the data.
            for (int j = 0; j < tempInput.length; j++) {
                tempInput[j] = dataset.instance(i).value(j);
            } // Of for j

            // Predict with this instance.
            tempPrediction = forward(tempInput);
            //System.out.println("prediction: " + Arrays.toString(tempPrediction));
            tempPredictedClass = argmax(tempPrediction);
            if (tempPredictedClass == (int) dataset.instance(i).classValue()) {
                tempNumCorrect++;
            } // Of if
        } // Of for i

        System.out.println("Correct: " + tempNumCorrect + " out of " + dataset.numInstances());

        return tempNumCorrect / dataset.numInstances();
    }// Of test



}//Of class GeneralAnn

Concrete implementation class SimpleAnn

The forward() and backPropagation() methods are implemented in this class.

package bp;


/**
 * Back-propagation neural networks. The code comes from
 * https://mp.weixin.qq.com/s?__biz=MjM5MjAwODM4MA==&mid=402665740&idx=1&sn=18d84d72934e59ca8bcd828782172667
 *
 * @author Peng Yuan revised by [email protected]
 */

public class SimpleAnn extends GeneralAnn {

    /**
     * The value of each node that changes during the forward process.
     * The first dimension stands for the layer, and the second stands for the node.
     */
    public double[][] layerNodeValues;

    /**
     * The error on each node that changes during the back-propagation process.
     * The first dimension stands for the layer, and the second stands for the
     * node.
     */
    public double[][] layerNodeErrors;

    /**
     * The weights of edges. The first dimension stands for the layer, the
     * second stands for the node index of the layer, and the third dimension
     * stands for the node index of the next layer.
     */
    public double[][][] edgeWeights;

    /**
     * The change of edge weights. It has the same size as edgeWeights.
     */
    public double[][][] edgeWeightsDelta;

    /**
     *********************
     * The first constructor.
     *
     * @param paraFilename
     * The arff filename.
     * @param paraLayerNumNodes
     * The number of nodes for each layer (may be different).
     * @param paraLearningRate
     * Learning rate.
     * @param paraMobp
     * Momentum coefficient.
     *********************
     */
    public SimpleAnn(String paraFilename, int[] paraLayerNumNodes, double paraLearningRate,
                     double paraMobp) {
        super(paraFilename, paraLayerNumNodes, paraLearningRate, paraMobp);

        // Step 1. Across layer initialization.
        layerNodeValues = new double[numLayers][];
        layerNodeErrors = new double[numLayers][];
        edgeWeights = new double[numLayers - 1][][];
        edgeWeightsDelta = new double[numLayers - 1][][];

        // Step 2. Inner layer initialization.
        for (int l = 0; l < numLayers; l++) {
            layerNodeValues[l] = new double[layerNumNodes[l]];
            layerNodeErrors[l] = new double[layerNumNodes[l]];

            // One less layer because each edge crosses two layers.
            if (l + 1 == numLayers) {
                break;
            } // Of if

            // In layerNumNodes[l] + 1, the last one is reserved for the offset.
            edgeWeights[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
            edgeWeightsDelta[l] = new double[layerNumNodes[l] + 1][layerNumNodes[l + 1]];
            for (int j = 0; j < layerNumNodes[l] + 1; j++) {
                for (int i = 0; i < layerNumNodes[l + 1]; i++) {
                    // Initialize weights.
                    edgeWeights[l][j][i] = random.nextDouble();
                } // Of for i
            } // Of for j
        } // Of for l
    }// Of the constructor

    /**
     *********************
     * Forward prediction.
     *
     * @param paraInput
     * The input data of one instance.
     * @return The data at the output end.
     *********************
     */
    public double[] forward(double[] paraInput) {
        // Initialize the input layer.
        // First put the input into the first layer.
        for (int i = 0; i < layerNodeValues[0].length; i++) {
            layerNodeValues[0][i] = paraInput[i];
        } // Of for i

        // Calculate the node values of each layer.
        // Loop numLayers - 1 times, propagating the values of one layer to the
        // next according to the weights, until the whole network is filled.
        double z;
        for (int l = 1; l < numLayers; l++) {
            for (int j = 0; j < layerNodeValues[l].length; j++) {
                // Initialize with the offset, whose input is always +1.
                // The variable z accumulates the weighted sum before it is
                // stored in layerNodeValues.
                z = edgeWeights[l - 1][layerNodeValues[l - 1].length][j];
                // Weighted sum on all edges for this node.
                for (int i = 0; i < layerNodeValues[l - 1].length; i++) {
                    z += edgeWeights[l - 1][i][j] * layerNodeValues[l - 1][i];
                } // Of for i

                // Sigmoid activation gives the output of this node.
                // This line should be changed for other activation functions.
                layerNodeValues[l][j] = 1 / (1 + Math.exp(-z));
            } // Of for j
        } // Of for l

        return layerNodeValues[numLayers - 1];
    }// Of forward

    /**
     *********************
     * Back propagation and change the edge weights.
     *
     * @param paraTarget
     * For 3-class data, it is [0, 0, 1], [0, 1, 0] or [1, 0, 0].
     *********************
     */
    public void backPropagation(double[] paraTarget) {
        // Step 1. Initialize the output layer error.
        int l = numLayers - 1;
        for (int j = 0; j < layerNodeErrors[l].length; j++) {
            layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j])
                    * (paraTarget[j] - layerNodeValues[l][j]);
        } // Of for j

        // Step 2. Back-propagation even for l == 0.
        while (l > 0) {
            l--;
            // Layer l, for each node.
            for (int j = 0; j < layerNumNodes[l]; j++) {
                double z = 0.0;
                // For each node of the next layer.
                for (int i = 0; i < layerNumNodes[l + 1]; i++) {
                    if (l > 0) {
                        z += layerNodeErrors[l + 1][i] * edgeWeights[l][j][i];
                    } // Of if

                    // Weight adjusting.
                    edgeWeightsDelta[l][j][i] = mobp * edgeWeightsDelta[l][j][i]
                            + learningRate * layerNodeErrors[l + 1][i] * layerNodeValues[l][j];
                    edgeWeights[l][j][i] += edgeWeightsDelta[l][j][i];
                    if (j == layerNumNodes[l] - 1) {
                        // Weight adjusting for the offset part.
                        edgeWeightsDelta[l][j + 1][i] = mobp * edgeWeightsDelta[l][j + 1][i]
                                + learningRate * layerNodeErrors[l + 1][i];
                        edgeWeights[l][j + 1][i] += edgeWeightsDelta[l][j + 1][i];
                    } // Of if
                } // Of for i

                // Record the error according to the differential of Sigmoid.
                // This line should be changed for other activation functions.
                layerNodeErrors[l][j] = layerNodeValues[l][j] * (1 - layerNodeValues[l][j]) * z;
            } // Of for j
        } // Of while
    }// Of backPropagation

    /**
     *********************
     * Test the algorithm.
     *********************
     */
    public static void main(String[] args) {
        int[] tempLayerNodes = { 4, 8, 8, 3 };
        SimpleAnn tempNetwork = new SimpleAnn(
                "C:\\Users\\hp\\Desktop\\deepLearning\\src\\main\\java\\resources\\iris.arff",
                tempLayerNodes, 0.01, 0.6);

        for (int round = 0; round < 5000; round++) {
            tempNetwork.train();
        } // Of for round

        double tempAccuracy = tempNetwork.test();
        System.out.println("The accuracy is: " + tempAccuracy);
    }// Of main
}// Of class SimpleAnn
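
The comments in forward() and backPropagation() note that the two Sigmoid lines must be changed together when another activation function is used. The following is only an illustrative sketch (the class ActivationSketch and the tanh variant are not part of the original code): it collects the Sigmoid used above and one possible replacement, with each derivative written in terms of the node's output value, which is the form the back-propagation code needs.

package bp;

// A minimal sketch, not part of the original article: the activation used by
// SimpleAnn and one possible alternative, each with its derivative expressed
// through the node value so it can plug into the error computation.
public class ActivationSketch {
    // Sigmoid, as used in SimpleAnn.forward().
    public static double sigmoid(double z) {
        return 1 / (1 + Math.exp(-z));
    }

    // Derivative of Sigmoid in terms of its output v = sigmoid(z),
    // i.e. v * (1 - v), the factor used in SimpleAnn.backPropagation().
    public static double sigmoidDerivative(double v) {
        return v * (1 - v);
    }

    // A possible alternative: tanh, whose derivative in terms of its output
    // v = tanh(z) is 1 - v * v.
    public static double tanhDerivative(double v) {
        return 1 - v * v;
    }

    public static void main(String[] args) {
        double z = 0.5;
        double v = sigmoid(z);
        System.out.println("sigmoid(0.5) = " + v + ", derivative = " + sigmoidDerivative(v));
        double t = Math.tanh(z);
        System.out.println("tanh(0.5) = " + t + ", derivative = " + tanhDerivative(t));
    }
}// Of class ActivationSketch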