Code Practice: A First Fully Connected Neural Network

Author: shaneZhang  Category: Machine Learning in Practice  Published: 2019-01-19 20:29

The network is split across three files: mnist_forward.py, mnist_backward.py, and mnist_test.py. mnist_forward.py defines the forward-propagation graph, mnist_backward.py implements the training (back-propagation) loop, and mnist_test.py periodically evaluates the recognition accuracy of the saved model. The code is listed below.

#coding=utf-8
#mnist_forward.py

import tensorflow as tf

INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE=500

def get_weight(shape,regularizer):
    # Initialize weights from a truncated normal distribution (stddev = 0.1).
    w = tf.Variable(tf.truncated_normal(shape,stddev=0.1))
    # If a regularization coefficient is given, add the L2 loss of w to the
    # 'losses' collection, which mnist_backward.py later sums into the total loss.
    if regularizer is not None:
        tf.add_to_collection('losses', tf.contrib.layers.l2_regularizer(regularizer)(w))
    return w

def get_bias(shape):
    b = tf.Variable(tf.zeros(shape))
    return b


def forward(x,regularizer):
    w1 = get_weight([INPUT_NODE,LAYER1_NODE],regularizer)
    b1 = get_bias([LAYER1_NODE])
    y1 = tf.nn.relu(tf.matmul(x,w1) + b1)

    w2 = get_weight([LAYER1_NODE,OUTPUT_NODE],regularizer)
    b2 = get_bias([OUTPUT_NODE])
    y = tf.matmul(y1,w2) + b2
    return y


As the code above shows, the forward-propagation graph has 784 input nodes (the number of pixels in each input image), 500 hidden-layer nodes, and 10 output nodes (the ten classes for digits 0-9). The weight matrix w1 from the input layer to the hidden layer has shape [784,500], and w2 from the hidden layer to the output layer has shape [500,10]; both are initialized from a truncated normal distribution and regularized, with each weight's L2 regularization loss added to the total loss. The bias b1 is a one-dimensional array of length 500 and b2 a one-dimensional array of length 10, both initialized to zeros. In the first layer, the input x is multiplied by w1, the bias b1 is added, and the result is passed through the ReLU activation to produce the hidden-layer output y1. In the second layer, y1 is multiplied by w2 and b2 is added to produce the output y. Because y will later be passed through softmax to turn it into a probability distribution, it is not passed through ReLU here.
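
As a quick sanity check of the shapes described above, a minimal sketch like the following can build the forward graph and run it on dummy data. It is not one of the three files of this post; the script name check_forward.py, the batch size of 8, and the random input are purely illustrative assumptions.

#coding=utf-8
#check_forward.py (hypothetical helper, not part of the original three files)

import numpy as np
import tensorflow as tf
import mnist_forward

# Build the forward graph for a batch of unknown size.
x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
y = mnist_forward.forward(x, regularizer=None)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Feed 8 random "images" just to check the output shape.
    fake_batch = np.random.rand(8, mnist_forward.INPUT_NODE)
    logits = sess.run(y, feed_dict={x: fake_batch})
    print(logits.shape)  # expected: (8, 10), one logit per digit class
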
#coding=utf-8
#mnist_backward.py

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import os

BATCH_SIZE = 200
LEARNING_RATE_BASE = 0.1
LEARNING_RATE_DECAY = 0.99
REGULARIZER = 0.0001
STEPS = 50000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "./model/"
MODEL_NAME = "mnist_model"


def backward(mnist):
    x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
    y_ = tf.placeholder(tf.float32, [None, mnist_forward.OUTPUT_NODE])
    y = mnist_forward.forward(x, REGULARIZER)
    global_step = tf.Variable(0, trainable=False)

    # Cross-entropy loss plus the L2 regularization terms collected in mnist_forward.py.
    ce = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y, labels=tf.argmax(y_, 1))
    cem = tf.reduce_mean(ce)
    loss = cem + tf.add_n(tf.get_collection('losses'))

    # Exponentially decaying learning rate: decays once per epoch (staircase).
    learning_rate = tf.train.exponential_decay(
        LEARNING_RATE_BASE,
        global_step,
        mnist.train.num_examples / BATCH_SIZE,
        LEARNING_RATE_DECAY,
        staircase=True)

    train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step)

    # Maintain exponential moving averages of all trainable variables;
    # mnist_test.py restores these shadow values at evaluation time.
    ema = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY, global_step)
    ema_op = ema.apply(tf.trainable_variables())
    with tf.control_dependencies([train_step, ema_op]):
        train_op = tf.no_op(name='train')

    saver = tf.train.Saver()

    with tf.Session() as sess:
        init_op = tf.global_variables_initializer()
        sess.run(init_op)

        for i in range(STEPS):
            xs, ys = mnist.train.next_batch(BATCH_SIZE)
            _, loss_value, step = sess.run([train_op, loss, global_step], feed_dict={x: xs, y_: ys})
            if i % 1000 == 0:
                print("After %d training step(s), loss on training batch is %g." % (step, loss_value))
                saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME), global_step=global_step)


def main():
    mnist = input_data.read_data_sets("./data/", one_hot=True)
    backward(mnist)


if __name__ == '__main__':
    main()
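
For reference, the staircase form of the exponentially decaying learning rate used above can be reproduced by hand. The small sketch below is a standalone illustration, not one of the three files; the hyperparameter values are assumed to match mnist_backward.py, and 55000 is the size of the MNIST training set.

#coding=utf-8
# Hypothetical illustration of the staircase exponential decay used in mnist_backward.py.

LEARNING_RATE_BASE = 0.1    # values assumed to match mnist_backward.py
LEARNING_RATE_DECAY = 0.99
BATCH_SIZE = 200
TRAIN_EXAMPLES = 55000      # number of images in the MNIST training set

decay_steps = TRAIN_EXAMPLES // BATCH_SIZE  # steps per epoch = 275

for global_step in [0, 275, 550, 2750]:
    # staircase=True: the exponent is the integer number of completed epochs.
    lr = LEARNING_RATE_BASE * LEARNING_RATE_DECAY ** (global_step // decay_steps)
    print("step %5d -> learning rate %.6f" % (global_step, lr))
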
#coding=utf-8
#mnist_test.py

import  time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward

TEST_INTERVAL_SECS=5

def test(mnist):
    with tf.Graph().as_default() as g:
        x = tf.placeholder(tf.float32,[None,mnist_forward.INPUT_NODE])
        y_ = tf.placeholder(tf.float32,[None,mnist_forward.OUTPUT_NODE])
        y = mnist_forward.forward(x,None)

        # Restore the exponential-moving-average (shadow) values of the weights,
        # which mnist_backward.py maintains during training.
        ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        ema_restore = ema.variables_to_restore()
        saver = tf.train.Saver(ema_restore)

        # Fraction of test images whose predicted class matches the label.
        correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(y_,1))
        accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

        while True:
            with tf.Session() as sess:
                ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
                if ckpt and ckpt.model_checkpoint_path:
                    saver.restore(sess,ckpt.model_checkpoint_path)
                    # The training step count is embedded in the checkpoint file name.
                    global_step = ckpt.model_checkpoint_path.split('/')[-1].split('-')[-1]
                    accuracy_score = sess.run(accuracy,feed_dict={x:mnist.test.images,y_:mnist.test.labels})
                    print("After %s training step(s), test accuracy = %g" % (global_step, accuracy_score))
                else:
                    print "No checkpoint file found"
                    return
            time.sleep(TEST_INTERVAL_SECS)


def main():
    mnist = input_data.read_data_sets("./data/",one_hot=True)
    test(mnist)


if __name__ == "__main__":
    main()
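
Once mnist_backward.py has written a checkpoint, the same restore pattern can be reused to classify a single image. The sketch below is an illustrative extra, not one of the original three files; the script name predict_one.py and the choice of the first test image are assumptions.

#coding=utf-8
#predict_one.py (hypothetical helper, not part of the original article)

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_forward
import mnist_backward

def predict_first_test_image():
    mnist = input_data.read_data_sets("./data/", one_hot=True)
    with tf.Graph().as_default():
        x = tf.placeholder(tf.float32, [None, mnist_forward.INPUT_NODE])
        y = mnist_forward.forward(x, None)
        pred = tf.argmax(y, 1)  # index of the largest logit = predicted digit

        # Restore the moving-average weights, exactly as mnist_test.py does.
        ema = tf.train.ExponentialMovingAverage(mnist_backward.MOVING_AVERAGE_DECAY)
        saver = tf.train.Saver(ema.variables_to_restore())

        with tf.Session() as sess:
            ckpt = tf.train.get_checkpoint_state(mnist_backward.MODEL_SAVE_PATH)
            if ckpt and ckpt.model_checkpoint_path:
                saver.restore(sess, ckpt.model_checkpoint_path)
                image = mnist.test.images[0:1]          # shape (1, 784)
                label = mnist.test.labels[0].argmax()   # true digit
                digit = sess.run(pred, feed_dict={x: image})[0]
                print("predicted: %d, actual: %d" % (digit, label))
            else:
                print("No checkpoint file found")

if __name__ == '__main__':
    predict_first_test_image()
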

