TensorFlow Implementation of Handwritten Digit Recognition

```python
import tensorflow as tf
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
import matplotlib.pyplot as plt
```

```python
# Configure how TensorFlow uses the GPU
# Get the list of physical GPUs
gpus = tf.config.experimental.list_physical_devices('GPU')
print(gpus)
if gpus:
    try:
        # Enable memory growth so GPU memory is allocated on demand
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
    except RuntimeError as e:
        # Memory growth must be set before the GPUs are initialized
        print(e)
```

```
[]
```

The empty list means no GPU was detected here, so this run falls back to the CPU.
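Memory growth is one way to keep TensorFlow from grabbing all GPU memory up front; the other is a hard cap. A minimal sketch, assuming at least one GPU is present (the 1024 MB limit is an arbitrary illustrative value, not from the original post):

```python
import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Expose the first GPU as a single logical device capped at 1024 MB;
    # like memory growth, this must be set before the GPU is initialized
    tf.config.experimental.set_virtual_device_configuration(
        gpus[0],
        [tf.config.experimental.VirtualDeviceConfiguration(memory_limit=1024)])
```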

Loading the Dataset

```python
(xs, ys), _ = datasets.mnist.load_data()  # downloads MNIST automatically on first run
print('datasets:', xs.shape, ys.shape, xs.min(), xs.max())

batch_size = 32

# Scale pixel values from [0, 255] to [0, 1]
xs = tf.convert_to_tensor(xs, dtype=tf.float32) / 255.
db = tf.data.Dataset.from_tensor_slices((xs, ys))
# Split the dataset into batches of size batch_size; repeat(30) iterates over it for 30 epochs
db = db.batch(batch_size).repeat(30)
```

```
datasets: (60000, 28, 28) (60000,) 0 255
```
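Two things worth noting about this pipeline: the underscore throws away MNIST's 10,000-image test split, and the data is never shuffled, so every epoch sees the samples in the same order. A minimal sketch of a variant that keeps the test split and shuffles each epoch (the 10000-element shuffle buffer is an assumption, not a value from the original post):

```python
import tensorflow as tf
from tensorflow.keras import datasets

(xs, ys), (x_test, y_test) = datasets.mnist.load_data()
xs = tf.convert_to_tensor(xs, dtype=tf.float32) / 255.

db = tf.data.Dataset.from_tensor_slices((xs, ys))
# shuffle() re-orders samples through a 10,000-element buffer on each pass
db = db.shuffle(10000).batch(32).repeat(30)
```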

Building the Model

```python
# Dense: fully connected layers, 784 -> 256 -> 128 -> 10
model = Sequential([layers.Dense(256, activation='relu'),
                    layers.Dense(128, activation='relu'),
                    layers.Dense(10)])
# Build with a dummy batch size of 4 so the layer weights are created
model.build(input_shape=(4, 28*28))
model.summary()  # print the model architecture

optimizer = optimizers.SGD(learning_rate=0.01)
acc_meter = metrics.Accuracy()
```
```
Model: "sequential_3"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_9 (Dense)              multiple                  200960    
_________________________________________________________________
dense_10 (Dense)             multiple                  32896     
_________________________________________________________________
dense_11 (Dense)             multiple                  1290      
=================================================================
Total params: 235,146
Trainable params: 235,146
Non-trainable params: 0
_________________________________________________________________
```
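The parameter counts in the summary follow directly from the layer shapes: a Dense layer with `n` inputs and `m` units has `n * m` weights plus `m` biases. A quick sanity check of the numbers above:

```python
# 784 inputs -> 256 units
assert 784 * 256 + 256 == 200960
# 256 -> 128
assert 256 * 128 + 128 == 32896
# 128 -> 10
assert 128 * 10 + 10 == 1290
# Total trainable parameters
assert 200960 + 32896 + 1290 == 235146
```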

Training the Model

```python
total_acc = []
total_loss = []
for step, (x, y) in enumerate(db):

    with tf.GradientTape() as tape:
        # Flatten: [b, 28, 28] => [b, 784]
        x = tf.reshape(x, (-1, 28*28))
        # Step 1: forward pass, [b, 784] => [b, 10]
        out = model(x)
        # One-hot encode the labels: [b] => [b, 10]
        y_onehot = tf.one_hot(y, depth=10)
        # Element-wise squared error, [b, 10]
        loss = tf.square(out - y_onehot)
        # Average squared error per sample (a scalar)
        loss = tf.reduce_sum(loss) / x.shape[0]

    acc_meter.update_state(tf.argmax(out, axis=1), y)

    # Compute the gradients and update the parameters
    grads = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(grads, model.trainable_variables))

    if step % 200 == 0:
        total_acc.append(acc_meter.result().numpy())
        total_loss.append(float(loss))
        print(step, 'loss:', float(loss), 'acc:', acc_meter.result().numpy())
        acc_meter.reset_states()
```
```
0 loss: 0.15352573990821838 acc: 0.95457846
200 loss: 0.1597624123096466 acc: 0.95765626
400 loss: 0.1545962393283844 acc: 0.9475
...
55800 loss: 0.09356825053691864 acc: 0.98640627
56000 loss: 0.034789077937603 acc: 0.98484373
56200 loss: 0.03775575011968613 acc: 0.9878125
```

(Log truncated: over the 30 epochs, the training accuracy climbs from about 95% to nearly 99% while the loss falls from roughly 0.15 to below 0.05.)
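The loop trains with a plain squared error on the raw logits, which converges here but is an unusual loss for classification. A more conventional choice is softmax cross-entropy; the sketch below is a drop-in replacement for the two `loss = ...` lines inside the `GradientTape` block (an alternative, not what this post uses):

```python
# Softmax cross-entropy on the raw logits, averaged over the batch
loss = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y_onehot, logits=out))
```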
```python
# Plot the loss recorded every 200 steps
x = [i for i in range(0, len(total_loss))]
plt.plot(x, total_loss)
```
(Figure: training loss, recorded every 200 steps.)

```python
# Plot the accuracy recorded every 200 steps
x = [i for i in range(0, len(total_acc))]
plt.plot(x, total_acc)
```
(Figure: training accuracy, recorded every 200 steps.)
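The accuracy tracked above is measured on training batches only. A minimal sketch of an evaluation pass on the held-out test split, which the loading code discarded with `_` (this step is an addition, not part of the original post):

```python
import tensorflow as tf
from tensorflow.keras import datasets, metrics

# Reload the 10,000-image test split that the training code discarded
_, (x_test, y_test) = datasets.mnist.load_data()
x_test = tf.convert_to_tensor(x_test, dtype=tf.float32) / 255.
x_test = tf.reshape(x_test, (-1, 28*28))

# Forward pass, then compare the argmax prediction against the labels
out = model(x_test)  # `model` is the trained network from above
test_acc = metrics.Accuracy()
test_acc.update_state(tf.argmax(out, axis=1), y_test)
print('test acc:', test_acc.result().numpy())
```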
