識別MNIST已經(jīng)成了深度學(xué)習(xí)的hello world,所以每次例程基本都會用到這個數(shù)據(jù)集,這個數(shù)據(jù)集在tensorflow內(nèi)部用著很好的封裝,因此可以方便地使用。
這次我們用tensorflow搭建一個softmax多分類器,和之前搭建線性回歸差不多,第一步是通過確定變量建立圖模型,然后確定誤差函數(shù),最后調(diào)用優(yōu)化器優(yōu)化。
誤差函數(shù)與線性回歸不同,這里因為是多分類問題,所以使用了交叉熵。
另外,有一點值得注意的是,這里構(gòu)建模型時我試圖想拆分多個函數(shù),但是后來發(fā)現(xiàn)這樣做難度很大,因為圖是在規(guī)定變量就已經(jīng)定義好的,不能隨意拆分,也不能當(dāng)做變量傳來傳去,因此需要將他們寫在一起。
代碼如下:
#encoding=utf-8 __author__ = 'freedom' import tensorflow as tf def loadMNIST(): from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data',one_hot=True) return mnist def softmax(mnist,rate=0.01,batchSize=50,epoch=20): n = 784 # 向量的維度數(shù)目 m = None # 樣本數(shù),這里可以獲取,也可以不獲取 c = 10 # 類別數(shù)目 x = tf.placeholder(tf.float32,[m,n]) y = tf.placeholder(tf.float32,[m,c]) w = tf.Variable(tf.zeros([n,c])) b = tf.Variable(tf.zeros([c])) pred= tf.nn.softmax(tf.matmul(x,w)+b) loss = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred),reduction_indices=1)) opt = tf.train.GradientDescentOptimizer(rate).minimize(loss) init = tf.initialize_all_variables() sess = tf.Session() sess.run(init) for index in range(epoch): avgLoss = 0 batchNum = int(mnist.train.num_examples/batchSize) for batch in range(batchNum): batch_x,batch_y = mnist.train.next_batch(batchSize) _,Loss = sess.run([opt,loss],{x:batch_x,y:batch_y}) avgLoss += Loss avgLoss /= batchNum print 'every epoch average loss is ',avgLoss right = tf.equal(tf.argmax(pred,1),tf.argmax(y,1)) accuracy = tf.reduce_mean(tf.cast(right,tf.float32)) print 'Accracy is ',sess.run(accuracy,({x:mnist.test.images,y:mnist.test.labels})) if __name__ == "__main__": mnist = loadMNIST() softmax(mnist)
網(wǎng)頁題目:tensorflow實現(xiàn)softma識別MNIST-創(chuàng)新互聯(lián)
URL網(wǎng)址:http://m.rwnh.cn/article30/ggjso.html
成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供域名注冊、企業(yè)建站、網(wǎng)站改版、電子商務(wù)、品牌網(wǎng)站建設(shè)、服務(wù)器托管
聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請盡快告知,我們將會在第一時間刪除。文章觀點不代表本網(wǎng)站立場,如需處理請聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時需注明來源: 創(chuàng)新互聯(lián)
猜你還喜歡下面的內(nèi)容