博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
Tensorflow的MNIST进阶,准确率提升情况,最终训练一万次,准确率达到99.28%,可以说比官方的效果还要好
阅读量:4281 次
发布时间:2019-05-27

本文共 3759 字,大约阅读时间需要 12 分钟。

补充:使用keras建模

增加步骤1:随机取样

增加步骤2:旋转训练样本图片角度(-25~+25)度
批量大小:64
训练50轮:每个角度各1轮
准确率提升至:99.48%

for i in range(0,25):		shuffle_idx = np.random.permutation(np.arange(len(x_train)))	xx = x_train[shuffle_idx,]	yy = y_train[shuffle_idx,]	 	rot = i	for j in range(xx.shape[0]):		scipy.ndimage.interpolation.rotate(xx[j].reshape(28, 28), rot, cval=0.01, reshape=False).reshape(1, -1)	model.fit(xx,yy,epochs=1,batch_size=64)	rot = -i	for j in range(xx.shape[0]):		scipy.ndimage.interpolation.rotate(xx[j].reshape(28, 28), rot, cval=0.01, reshape=False).reshape(1, -1)	model.fit(xx,yy,epochs=1,batch_size=64)

在这里插入图片描述

原博客:

1. 训练1000次,预测整个测试集的准确率约96.37%,耗时20分钟。(使用cpu版本tensorflow-1.2.1)

2. 从存档开始第2次训练1000次(共2000次),预测整个测试集的准确率约98.33%(嗯,不错的提升)

3. 从存档开始第3次训练1000次(共3000次),预测整个测试集的准确率约98.74%(提升不大)

4. 每次取200个样本,1000次需20分钟。改成每次取50个后,1000次(共4000次)训练耗时7分钟,远远少于20分钟。但是训练后准确率依然为98.74%

5. 继续改回每次取200个样本,训练1000次(共5000次),耗时20分钟,准确率提升到了99%

6. 增加到每次取500个样本,训练200次(共5200次),准确率为99.07%

7. 下班前挂机,每次取200个样本,训练5000次(共10200次),准确率为99.28%,官网demo示例给出的准确率大概是99.2%,见http://www.tensorfly.cn/tfdoc/tutorials/mnist_pros.html

下班了(3月20日),挂机训练5000次,目前看来效果不错。在这里插入图片描述

上班看效果(3月21日),到此为止共训练一万次,准确率提高到了99.28%,比官方(两万次,99.2%)的效果还好。

在这里插入图片描述

import tensorflow as tffrom tensorflow.examples.tutorials.mnist import input_datamnist = input_data.read_data_sets("./mnist_data/",one_hot=True)x = tf.placeholder("float",[None,784])def weight_variable(shape):	initial = tf.truncated_normal(shape,stddev=0.1)	return tf.Variable(initial)def bias_variable(shape):	initial = tf.constant(0.1,shape=shape)	return tf.Variable(initial)def conv2d(x,W):	return tf.nn.conv2d(x,W,strides=[1,1,1,1],padding="SAME")def max_pool_2x2(x):	return tf.nn.max_pool(x,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")# 第一次卷积w_conv1 = weight_variable([5,5,1,32])b_conv1 = bias_variable([32])x_image = tf.reshape(x,[-1,28,28,1])h_conv1 = tf.nn.relu(conv2d(x_image,w_conv1)+b_conv1)h_pool1 = max_pool_2x2(h_conv1)# 第二次卷积w_conv2 = weight_variable([3,3,32,64])b_conv2 = bias_variable([64])h_conv2 = tf.nn.relu(conv2d(h_pool1,w_conv2)+b_conv2)h_pool2 = max_pool_2x2(h_conv2)# 加入全连接层,密集连接层w_fc1 = weight_variable([7*7*64,1024])b_fc1 = bias_variable([1024])h_pool2_flat = tf.reshape(h_pool2,[-1,7*7*64])h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat,w_fc1)+b_fc1)# Dopout,减少过拟合keep_prob = tf.placeholder("float")h_fc1_drop = tf.nn.dropout(h_fc1,keep_prob)# 输出层w_fc2 = weight_variable([1024,10])b_fc2 = bias_variable([10])y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop,w_fc2)+b_fc2)y_ = tf.placeholder("float",[None,10])cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)correct_prediction = tf.equal(tf.argmax(y_conv,1),tf.argmax(y_,1))accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))#存档两个。saver = tf.train.Saver(max_to_keep=2)max_acc = 0is_train=False# 不备注为训练,备注为预测精准度。is_train=Truewith tf.Session() as sess:	init = tf.global_variables_initializer()	sess.run(init)	#使用上次的存档点,训练时可减少时间成本	model_file=tf.train.latest_checkpoint('./ckpt/')	saver.restore(sess,model_file)	if is_train:		for i in range(1001):			batch = mnist.train.next_batch(50)			#每20次训练,计算输出一次准确度。			if i%20 == 0:				train_accuracy = accuracy.eval(feed_dict={					x:batch[0], y_: batch[1], keep_prob: 1.0})				print("step %d, training accuracy %g"%(i, train_accuracy))				#准确度更高时进行存档,最后一次存档。(存档点可以在checkpoint文件中自由切换使用)				if train_accuracy>=max_acc or i==1000:					max_acc = train_accuracy					saver.save(sess,'./ckpt/mnist.ckpt',global_step=i)			#进行训练。			train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})		#结束训练时(或者is_train==False,使用存档),打印出预测整个测试集的准确度(也可以使用训练好的模型做其他事情)。	print("test accuracy %g"%accuracy.eval(feed_dict={		x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

使用画图工具,新建28*28像素的bmp位图,使用训练好的模型,测试预测是否正确。

转载地址:http://gnbgi.baihongyu.com/

你可能感兴趣的文章
静态链接库设计
查看>>
动态链接库设计
查看>>
文件编程之库函数方式
查看>>
与时间相关的函数编程
查看>>
Linux进程控制相关概念
查看>>
c标准中的预定义宏
查看>>
*(volatile unsigned long *) 语法
查看>>
Linux多进程程序设计
查看>>
Linux进程间通讯基础
查看>>
Linux信号通讯编程
查看>>
信号量互斥编程
查看>>
信号量同步编程
查看>>
共享内存通讯编程
查看>>
Linux消息队列通讯编程
查看>>
多进程与多线程的优缺点
查看>>
Linux多线程程序设计
查看>>
网络协议基础
查看>>
Linux网络编程模型
查看>>
TCP通讯程序设计
查看>>
UDP通讯程序设计
查看>>