Optimizing a Neural Network's Predictions with Mean Squared Error

Author: shaneZhang  Category: Machine Learning in Practice  Published: 2018-12-24 13:53
# coding=utf-8
# Predict daily yogurt sales y; x1 and x2 are the two factors that influence daily sales.
# The data to collect in advance: daily values of x1, x2, and sales y_ over a period of
# time -- the more data, the better. In this example sales are used to forecast production,
# so the optimal production equals sales. Since no real dataset is available, a synthetic
# one is fabricated: x1 and x2 are generated randomly (with NumPy's RandomState below),
# the ground truth is set to y_ = x1 + x2, and, to make the data more realistic, random
# noise of +/-0.05 is added to the sum.

import tensorflow as tf  # TensorFlow 1.x API (tf.placeholder, tf.Session)
import numpy as np

BATCH_SIZE = 8
SEED = 23455

# Generate 32 samples with 2 features each, plus noisy labels y_ = x1 + x2 +/- 0.05
rdm = np.random.RandomState(SEED)
X = rdm.rand(32, 2)
Y_ = [[x1 + x2 + (rdm.rand() / 10 - 0.05)] for (x1, x2) in X]

# Define the network's input, parameters, and output
x = tf.placeholder(tf.float32, shape=(None, 2))
y_ = tf.placeholder(tf.float32, shape=(None, 1))
w1 = tf.Variable(tf.random_normal([2, 1], stddev=1, seed=1))
y = tf.matmul(x, w1)  # linear model: y = w1*x1 + w2*x2

# Define the loss function and the back-propagation method:
# MSE loss, minimized by plain gradient descent with learning rate 0.001
loss_mse = tf.reduce_mean(tf.square(y_ - y))
train_step = tf.train.GradientDescentOptimizer(0.001).minimize(loss_mse)

# Create a session and train for STEPS iterations
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    STEPS = 20000
    for i in range(STEPS):
        # Cycle through the 32 samples, BATCH_SIZE at a time
        start = (i * BATCH_SIZE) % 32
        end = start + BATCH_SIZE
        sess.run(train_step, feed_dict={x: X[start:end], y_: Y_[start:end]})
        if i % 500 == 0:
            print("After %d training steps,w1 is: " % i)
            print(sess.run(w1))
    print("Final w1 is:")
    print(sess.run(w1))
The output of this example is:
After 0 training steps,w1 is: 
[[-0.80974597]
 [ 1.4852903 ]]
After 500 training steps,w1 is: 
[[-0.46074435]
 [ 1.641878  ]]
After 1000 training steps,w1 is: 
[[-0.21939856]
 [ 1.6984766 ]]
After 1500 training steps,w1 is: 
[[-0.04415595]
 [ 1.7003176 ]]
After 2000 training steps,w1 is: 
[[0.08942621]
 [1.673328  ]]
After 2500 training steps,w1 is: 
[[0.19583555]
 [1.6322677 ]]
After 3000 training steps,w1 is: 
[[0.28375748]
 [1.5854434 ]]
After 3500 training steps,w1 is: 
[[0.35848638]
 [1.5374472 ]]
After 4000 training steps,w1 is: 
[[0.42332518]
 [1.4907393 ]]
After 4500 training steps,w1 is: 
[[0.48040026]
 [1.4465574 ]]
After 5000 training steps,w1 is: 
[[0.53113604]
 [1.4054536 ]]
After 5500 training steps,w1 is: 
[[0.5765325]
 [1.3675941]]
After 6000 training steps,w1 is: 
[[0.61732584]
 [1.3329403 ]]
After 6500 training steps,w1 is: 
[[0.6540846]
 [1.3013426]]
After 7000 training steps,w1 is: 
[[0.6872685]
 [1.272602 ]]
After 7500 training steps,w1 is: 
[[0.71725976]
 [1.2465005 ]]
After 8000 training steps,w1 is: 
[[0.7443861]
 [1.2228197]]
After 8500 training steps,w1 is: 
[[0.7689324]
 [1.2013483]]
After 9000 training steps,w1 is: 
[[0.79115134]
 [1.1818889 ]]
After 9500 training steps,w1 is: 
[[0.811267 ]
 [1.1642567]]
After 10000 training steps,w1 is: 
[[0.8294814]
 [1.1482829]]
After 10500 training steps,w1 is: 
[[0.84597576]
 [1.1338125 ]]
After 11000 training steps,w1 is: 
[[0.8609128]
 [1.1207061]]
After 11500 training steps,w1 is: 
[[0.87444043]
 [1.1088346 ]]
After 12000 training steps,w1 is: 
[[0.88669145]
 [1.0980824 ]]
After 12500 training steps,w1 is: 
[[0.8977863]
 [1.0883439]]
After 13000 training steps,w1 is: 
[[0.9078348]
 [1.0795243]]
After 13500 training steps,w1 is: 
[[0.91693527]
 [1.0715363 ]]
After 14000 training steps,w1 is: 
[[0.92517716]
 [1.0643018 ]]
After 14500 training steps,w1 is: 
[[0.93264157]
 [1.0577497 ]]
After 15000 training steps,w1 is: 
[[0.9394023]
 [1.0518153]]
After 15500 training steps,w1 is: 
[[0.9455251]
 [1.0464406]]
After 16000 training steps,w1 is: 
[[0.95107025]
 [1.0415728 ]]
After 16500 training steps,w1 is: 
[[0.9560928]
 [1.037164 ]]
After 17000 training steps,w1 is: 
[[0.96064115]
 [1.0331714 ]]
After 17500 training steps,w1 is: 
[[0.96476096]
 [1.0295546 ]]
After 18000 training steps,w1 is: 
[[0.9684917]
 [1.0262802]]
After 18500 training steps,w1 is: 
[[0.9718707]
 [1.0233142]]
After 19000 training steps,w1 is: 
[[0.974931 ]
 [1.0206276]]
After 19500 training steps,w1 is: 
[[0.9777026]
 [1.0181949]]
Final w1 is:
[[0.98019385]
 [1.0159807 ]]


In this example the neural network's prediction model is y = w1*x1 + w2*x2, and the loss function is the mean squared error, MSE = (1/n) Σ (y_ − y)². By driving the loss steadily downward, training converges to the parameters w1 = 0.98 and w2 = 1.02, giving the sales forecast y = 0.98*x1 + 1.02*x2. Since the dataset was generated with the ground-truth rule y = x1 + x2, the learned model is very close to the true one, which shows that the network predicts daily yogurt sales accurately.
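As a quick sanity check, one can compare the learned model against the data-generating rule on a few hand-picked points. A minimal NumPy sketch, plugging in the final weights printed above (these values come from this particular run and will vary slightly with the seed and number of steps; X_test is an arbitrary set of sample inputs chosen for illustration):

import numpy as np

w1 = np.array([[0.98019385], [1.0159807]])    # final weights from the run above
X_test = np.array([[0.2, 0.3], [0.5, 0.5], [0.9, 0.1]])

y_pred = X_test.dot(w1)                       # learned model: y = 0.98*x1 + 1.02*x2
y_true = X_test.sum(axis=1, keepdims=True)    # ground-truth rule: y = x1 + x2
print(np.hstack([y_pred, y_true]))            # predictions track the truth closely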

