【Python】TensorFlow學習筆記(五):存檔 & 讀檔

經歷了不少風雨,做完 TFRecord 檔案,Model 也如期完工,
擋在前面的妖魔鬼怪都消滅得乾乾淨淨,最後只剩下該如何重複使用現有的模型。

tf.train.Saver

訓練好一個神經網路模型後,我們希望能夠用來預測資料。
tf.train.Saver 就是 TensorFlow 所提供的存檔工具。

若您以這個函數作為關鍵字搜尋,即可看到堆積如山的教學。
在此,夏恩挑一個最簡單,最討喜的方法,來和大家分享。

訓練的方式夏恩就以上一篇最後的方法為例:

# -*- coding: utf-8 -*-
"""
作者:Shayne
程式簡介:加入存檔功能,且有點難的簡單的建模程式
"""

import tensorflow as tf

def read_and_decode(filename, batch_size): 
    # 建立文件名隊列
    filename_queue = tf.train.string_input_producer([filename], num_epochs=None)
    
    # 數據讀取器
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    
    # 數據解析
    img_features = tf.parse_single_example(
            serialized_example,
            features={ 'Label'    : tf.FixedLenFeature([], tf.int64),
                       'image_raw': tf.FixedLenFeature([], tf.string), })
    
    image = tf.decode_raw(img_features['image_raw'], tf.uint8)
    image = tf.reshape(image, [42, 42])
    
    label = tf.cast(img_features['Label'], tf.int64)

    # 依序批次輸出 / 隨機批次輸出
    # tf.train.batch / tf.train.shuffle_batch
    image_batch, label_batch =tf.train.shuffle_batch(
                                 [image, label],
                                 batch_size=batch_size,
                                 capacity=10000 + 3 * batch_size,
                                 min_after_dequeue=1000)

    return image_batch, label_batch

###############
# 以下為主程式 #

# tfrecord 檔案位置
filename = './py_Train.tfrecords'

# batch 可以自由設定
batch_size = 256

# 0-9共10個類別,請根據自己的資料修改
Label_size = 10

# 調用 read_and_decode 函數
image_batch, label_batch = read_and_decode(filename, batch_size)

# 轉換陣列的形狀
image_batch_train = tf.reshape(image_batch, [-1, 42*42])

# 把 Label 轉換成獨熱編碼
label_batch_train = tf.one_hot(label_batch, Label_size)

# W 和 b 就是我們要訓練的對象
W = tf.Variable(tf.zeros([42*42, Label_size]))
b = tf.Variable(tf.zeros([Label_size]))

# 我們的影像資料,會透過 x 變數來輸入 
x = tf.placeholder(tf.float32, [None, 42*42])

# 這是參數預測的結果
y = tf.nn.softmax(tf.matmul(x, W) + b)

# 這是每張影像的正確標籤
y_ = tf.placeholder(tf.float32, [None, 10])

# 計算最小交叉熵
cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=y_, logits=y))

# 使用梯度下降法來找最佳解
train_step = tf.train.GradientDescentOptimizer(0.05).minimize(cross_entropy)

# 計算預測正確率
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

####################################################
####################################################
# 新增的內容在這邊 #

# 計算 y 向量的最大值
y_pred = tf.argmax(y, 1)

# 建立 tf.train.Saver 物件
saver = tf.train.Saver()

# 將輸入與輸出值加入集合
tf.add_to_collection('input' , x)
tf.add_to_collection('output', y_pred)

####################################################
####################################################
    
with tf.Session() as sess:
    # 初始化是必要的動作
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    
    # 建立執行緒協調器
    coord = tf.train.Coordinator()
    
    # 啟動文件隊列,開始讀取文件
    threads = tf.train.start_queue_runners(coord=coord)
    
    # 迭代 100 次,看看訓練的成果
    for count in range(100):     
        # 這邊開始讀取資料
        image_data, label_data = sess.run([image_batch_train, label_batch_train])
   
        # 送資料進去訓練
        sess.run(train_step, feed_dict={x: image_data, y_: label_data})
        
        # 這裡是結果展示區,每 10 次迭代後,把最新的正確率顯示出來
        if count % 10 == 0:
            train_accuracy = accuracy.eval(feed_dict={x: image_data, y_: label_data})
            print('Iter %d, accuracy %4.2f%%' % (count, train_accuracy*100))

    # 結束後記得把文件名隊列關掉
    coord.request_stop() 
    coord.join(threads)
    
    ####################################################
    # 這裡也是新增的內容 #

    # 存檔路徑 #
    save_path = './model/test_model'

    # 把整張計算圖存檔
    spath = saver.save(sess, save_path)
    print("Model saved in file: %s" % spath)
    ####################################################

這支程式在夏恩的電腦上是可以直接執行的,若您無法順利執行...
應該是不會啦。

相信您也發現了多出來的程式碼,我們來仔細瞧瞧:

####################################################
####################################################
# 新增的內容在這邊 #

# 計算 y 向量的最大值
y_pred = tf.argmax(y, 1)

# 建立 tf.train.Saver 物件
saver = tf.train.Saver()

# 將輸入與輸出值加入集合
tf.add_to_collection('input' , x)
tf.add_to_collection('output', y_pred)

####################################################
####################################################

首先,我們必須另外給定一個 y_pred 變數。
當然,您也可以直接把 y 存下來,但請別忘記:
在程式中,y 是個一維向量,其值類似:[0.1, 0.2, 0.6, 0.1],是不同類別所代表的機率值。
因此我們在這邊調用 tf.argmax 函數,以求得向量中最大值所在的編號,並作為預測的結果。

接著是建立 tf.train.Saver 物件,這個物件會把我們之前所定義的變數都保存下來。
在本程式中,就是權重 W 以及誤差 b,它們的定義都是 tf.Variable。

另外,請注意位於 tf.train.Saver 之後的變數不會被儲存。

最後一個就是使用 add_to_collection 函數,將我們的輸入值與輸出值包裝成集合。
下圖是從官網上擷取下來的說明:

第一個輸入參數是集合名稱,可以自行定義,自己明白是什麼就好;
第二個輸入參數就是我們要存檔的對象。

在上面的程式也可以改成:

# 分別存到 x 集合與 y 集合
tf.add_to_collection('x', x)
tf.add_to_collection('y', y_pred)

###

# 通通存到 A 集合
tf.add_to_collection('A', x)
tf.add_to_collection('A', y_pred)

總之,集合的名稱設定請您自由發揮。

在完成上述的設定後,我們會在計算圖之中啟動 Saver,才算是儲存完畢。
啟動的語法就是在程式中新增內容的第二個部分:

####################################################
# 這裡也是新增的內容 #

# 存檔路徑 #
save_path = './model/test_model'

# 把整張計算圖存檔
spath = saver.save(sess, save_path)
print("Model saved in file: %s" % spath)

####################################################

使用 saver.save 把整張計算圖存到指定位置,就完成了這個章節。

tf.train.restore

程式執行之後應該會看到結果長這樣:

在夏恩的設定中,這些檔案會統一存在名為 model 的資料夾內。

這時候我們要重新寫一支程式,用來讀取並重建模型:

# -*- coding: utf-8 -*-
"""
作者:Shayne
程式簡介:讀取並重建模型程式
"""

import cv2
import tensorflow as tf

with tf.Session() as sess:

    ##################################################
    # load model #
    save_path = "./model/test_model.meta"

    # 使用 import_meta_graph 載入計算圖
    saver = tf.train.import_meta_graph(save_path)

    # 使用 restore 重建計算圖
    saver.restore(sess, "./model/test_model")
    
    # 取出集合內的值
    x = tf.get_collection("input")[0]
    y = tf.get_collection("output")[0]

    ##################################################
    
    # 讀一張影像
    img = cv2.imread('img1.jpg', 0);

    # 辨識影像,並印出結果
    result = sess.run(y, feed_dict = {x: img.reshape((-1, 42*42))})
    print(result)

這裡夏恩補充兩點:

1. 若存檔時使用同一個集合名稱,範例如下:

# 通通存到 A 集合
tf.add_to_collection('A', x)
tf.add_to_collection('A', y_pred)

那麼讀取時就依存檔順序取出:

# 取出集合內的值
x = tf.get_collection("A")[0]
y = tf.get_collection("A")[1]

2. 若是在即時辨識系統中,建議把重建的步驟擺在第一步:

# 以下為建議的寫法
with tf.Session() as sess:

    ##################################################

    # tf.train.restore...
    # 重建計算圖很費時,請放在無限迴圈的外面

    ##################################################
    
    while True:
        # 即時影像辨識流程...

        # 辨識影像,並印出結果
        result = sess.run(y, feed_dict = {x: img.reshape((-1, 42*42))})
        print(result)


# 以下為不建議的寫法
while True:

    # 即時影像辨識流程...
    # ...

    with tf.Session() as sess:
        # 重建計算圖
        # tf.train.restore...

        # 辨識影像,並印出結果
        result = sess.run(y, feed_dict = {x: img.reshape((-1, 42*42))})
        print(result)

不建議把重建的步驟放在迴圈內,速度大概差了 10 幾倍,模型愈大差愈多。
因為重建計算圖的成本是固定的,做愈少次愈好。

模型訓練進階 

明白該怎麼使用 Saver 和 restore 之後,下一個問題就是「訓練時被中斷」怎麼辦?

剛才的範例是在訓練結束之後,再把最後的結果儲存下來,
但實際上,大型的深度學習非常耗時,少則 7 天,多達數個月。

所以必須有一套可以在模型被中斷時能夠接續的方法。

還記得剛才提到的輸出檔嗎?
是不是有一個檔案名稱為 checkpoint?

這個檔案紀錄了目前最新存檔的狀態,我們可以在使用 Saver.save 時,加入第三個參數:global_step。

global_step 參數會在存檔時接續在檔名的末端,使得我們可以大量存檔,以下為 Saver.save 的調用方法:

spath = saver.save(sess, save_path, global_step=count)

其中,global_step 所代的值,會接在檔名後面,如下圖:

另外,新增一個變數:iscontinue ,來表示是否要接續原有模型進行訓練。

tf.train.Saver 物件的預設值是保留最近五次的訓練結果,預設值的修改在建立物件時就可以指定,
詳細的參數請參考官網說明:

若是要載入原有的模型,可以先調用 tf.train.latest_checkpoint 確認最新一次的訓練結果。
接著用 tf.train.import_meta_graph 和 saver.restore 把計算圖重建回來。

以下範例代碼,省略了前面定義權重與算式的過程,因為內容都一樣,夏恩就不複述了。
主要展示該如何延續未訓練完之模型的方法。

import tensorflow as tf

# 建立 tf.train.Saver 物件
saver = tf.train.Saver()

# 存檔路徑 #
save_path = './model/test_model'

# 狀態指定
iscontinue = 1

with tf.Session() as sess:
    # 初始化
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(coord=coord)
    
    if not iscontinue:
        # 迭代 100 次,看看訓練的成果
        for count in range(100):     
            # 讀取資料
            image_data, label_data = sess.run([image_batch_train, label_batch_train])
       
            # 訓練
            sess.run(train_step, feed_dict={x: image_data, y_: label_data})
            
            # 結果展示
            if count % 10 == 0:
                train_accuracy = accuracy.eval(feed_dict={x: image_data, y_: label_data})
                print('Iter %d, accuracy %4.2f%%' % (count, train_accuracy*100))
                
                # 存檔
                spath = saver.save(sess, save_path, global_step=count)
                print("Model saved in file: %s" % spath)
                
    else:
        # 重建 model #
        last_ckp = tf.train.latest_checkpoint("./model")
        saver = tf.train.import_meta_graph(last_ckp+'.meta')
        saver.restore(sess, last_ckp)
            
        # 延續舊有資料繼續迭代 100 次
        for count in range(100, 200):     
            # 讀取資料
            image_data, label_data = sess.run([image_batch_train, label_batch_train])
    
            # 訓練
            sess.run(train_step, feed_dict={x: image_data, y_: label_data})

            # 結果展示
            if count % 10 == 0:
                train_accuracy = accuracy.eval(feed_dict={x: image_data, y_: label_data})
                print('Iter %d, accuracy %4.2f%%' % (count, train_accuracy*100))
                
                spath = saver.save(sess, save_path, global_step=count)
                print("Model saved in file: %s" % spath)

    # 關掉文件名隊列
    coord.request_stop() 
    coord.join(threads)

到這邊,該如何存取模型資料可說是告一段落。
雖然比預想中佔用了更多了篇幅...

下一篇真的要來講深度學習的事情了...吧,沒意外的話。

【Python】TensorFlow學習筆記(六):卷積的那些小事