【Python】TensorFlow學習筆記(三):再探 TFRecord

費了好一番功夫,終於把自己的資料集轉成 TensorFlow 可以用的檔案格式,
接著就是讀取檔案,但是如何讀檔還真不是件容易的事。

讀取一般檔案

學習 TensorFlow 的過程中,總是一直遇到大大小小的問題,到現在也差不多該習慣了。
若是一天沒有遇到問題,說不定還會覺得渾身不對勁...

開玩笑的,拜託問題不要來。
夏恩只是隨便說說,您老人家就隨便聽聽。

言歸正傳。
在眾多問題中,比較惱人的地方我覺得有兩個:

其一,是該怎麼建立一個好的模型?
其二,是該怎麼讀取 TFRecord 檔?

第一的問題有點大,同時也是全世界都在研究的問題。
夏恩覺得應該沒有什麼特別的解法,想要找到一個好的模型大概只能多看點文獻,
把理論弄熟,了解各項參數的意義等,更多的實務經驗,意味著能夠更快找到適合的模型及參數。

第二個問題就是這次要分享的主題。

不得不說,TensorFlow 在讀取數據這一塊真的很抽象。
夏恩會建議可以閱讀以下這篇文章:

十圖詳解TensorFlow數據讀取機制

這篇文章畫得圖還真不錯!

在此,先撇開 TFRecord 檔,因為許多讀者說這一段跳太快,令人費解。
既然這樣,那我們就先「單純地」來看看一般的讀檔。

舉個例子:在這個資料夾中,有四個檔案,如下:


本恩要讀取這四個檔案,然後寫出這四個檔案。

在 TensorFlow 中,讀檔的機制分為三個部分:文件本身、文件隊列、主程式。
這邊和我們所熟悉的讀檔有所不同,我們熟悉的是主程式直接去讀取文件。
其運作示意像是:

文件 < ---- > 主程式 

意思就是可以直接讀、寫,不會有任何阻礙。

而在 TensorFlow 的運作示意像是:

讀:
文件 ==> 文件隊列(由獨立的執行緒運作) ==> 主程式

寫:
主程式 ==> 文件

意思就是寫的時候沒問題,讀的時候則是讀取文件隊列,而非文件本身。
再次強調,主程式和文件隊列是由不同的執行緒負責。
看到這邊,應該可以發現「啟動文件隊列」的這個動作是免不了的,等會兒看到這個步驟也不用太訝異。

在 TensorFlow 中,一般讀檔的步驟如下:

1. 使用 tf.string_input_producer 產生文件名隊列。

shuffle 參數表示是否要打亂文件讀取的順序;
num_epochs = N 表示將這個數據集中的圖片全部讀取 N 遍。

具體來說,就是如果有 3 個檔案在裡面,那麼就會送 3 個檔案出來。
若 num_epochs = 2,那就送 6 個檔案出來,
每個檔案送完一次之後,都會重複再送一次,依此類推...

在機器學習中,epoch = N,就是將數據集從頭到尾運算 N 次的意思,
若 num_epochs = None,表示不限循環次數,直到其他的中止條件達成才會停止計算。

2. 使用 tf.WholeFileReader 去讀文件名隊列。

請注意,因為現在要讀取一般文件,所以是用這個,如果是 TFRecord 檔的話,請使用 tf.TFRecordReader。

3. 使用 tf.train.Coordinator 和 tf.train.start_queue_runners 啟動隊列。

4. 結束後記得把多餘的隊列關一關。

以下附上範例程式:

# -*- coding: utf-8 -*-
"""
作者:Shayne
程式簡介:很普通的範例程式
"""
import tensorflow as tf

# 全部要讀取的文件名
filename = ['0.jpg', '2.jpg', '6.jpg', '7.jpg']

# 產生文件名隊列
filename_queue = tf.train.string_input_producer(filename, 
                                                shuffle=False, 
                                                num_epochs=1)

# 數據讀取器,不要用錯囉!
#reader = tf.TFRecordReader()

reader = tf.WholeFileReader()
key, value = reader.read(filename_queue)

with tf.Session() as sess:
    # 初始化是必要的動作
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    
    # 建立執行緒協調器
    coord = tf.train.Coordinator()
    
    # 啟動文件隊列,開始讀取文件
    threads = tf.train.start_queue_runners(coord=coord)

    count = 0
    
    try:
        while not coord.should_stop():

            # 這邊讀檔
            image_data = sess.run(value)
            
            # 這邊寫檔
            # 請注意,因為 WholeFileReader 讀進來的是二進制檔,
            # 輸出的時候也要使用二進制的方式。
            with open('./test_%d.jpg' % count, 'wb') as f:
                f.write(image_data)
                count += 1
                
    except tf.errors.OutOfRangeError:
            print('Done!')

    finally:
        # 最後要記得把文件隊列關掉
        coord.request_stop()
    
    coord.join(threads)

執行之後,同樣的圖片會出現。

這隻程式有幾個地方您可以親自測試。

第一:請更改 tf.train.string_input_producer 中的 shuffle 參數。

改了之後可以發現輸出的結果就不是按照順序囉!

第二:請更改 tf.train.string_input_producer 中的 num_epochs 參數。

把 1 改成 2,原本輸出 4 個檔案就會變成 8 個;
把 1 改成 3,原本輸出 4 個檔案就會變成 12 個;
把 1 改成 4,原本輸出 4 個檔案就會變成 16 個...

這句話可以重複一百次,不過我猜您已經知道本恩想要表示什麼了。
講這麼多,還是自己動手做比較有體會對不對!

讀取 TFRecord 檔案

有了讀取一般檔案的經驗之後,想必挑戰 TFRecord 檔案也是可以。
讀取 TFRecord 檔的流程大概分為幾個步驟:

Step 1: 把之前包裝好的「TFRecord」檔案,利用「tf.train.string_input_producer」做成文件隊列。

這邊的 filename 就要使用稍早前我們已經包裝好的「TFRecord」檔案。

同樣地,shuffle 參數表示是否要打亂文件讀取的順序;
num_epochs = N 表示將這個數據集中的圖片全部讀取 N 遍。

Step 2: 啟動「tf.TFRecordReader()」讀取數據。

Step 3: 解析數據。

還記得剛才讀取一般檔案的時候,讀進來的格式為二進制檔,對吧?
這邊同樣也是,不過更複雜。

因為 TFRecord 檔有經過我們自己賦予的特殊規則。

像是影像的標籤、影像本身的資訊等。
所以我們必須根據剛才的包裝順序,依序解開這個檔案。

到這邊,先複習一下我們是如何包裝數據的?

寫入 TFRecord 檔案的標準做法是:

step 1. 把所有資料轉換成「tf.train.Feature」格式。

step 2. 把所有的「tf.train.Feature」包裝成「tf.train.Features」格式。

step 3. 把所有的「tf.train.Features」組合成「tf.train.Example」格式。

step 4. 利用「tf.python_io.TFRecordWriter」將「tf.train.Example」寫入成 TFRecord 檔案。

看清楚了沒?
原始資料 > Feature > Features > Example > TFRecord

所以說,之前怎麼包裝,現在就要依序反向拆開。
TFRecord > Example > Features > Feature > 原始資料

在剛才那一步,我們已經使用 tf.TFRecordReader() 讀取 TF 檔,並得到 Example 了。
接下來要解析 Example 為 Features。

Step 4: 使用「tf.parse_single_example」將「tf.train.Example」檔案解析為「tf.train.Features」。

img_features = tf.parse_single_example(
               serialized_example,
               features={'Label'    : tf.FixedLenFeature([], tf.int64),
                         'image_raw': tf.FixedLenFeature([], tf.string),})

這邊可以看到,由於之前包裝的時候「Label」的資料格式為「tf.int64」;「image_raw」為「tf.string」。
在這邊都要一樣,怎麼包就怎麼拆。

Step 5: 使用「tf.decode_raw」或是「tf.cast」將「tf.train.Features」檔案解析為「tf.train.Feature」。

image = tf.decode_raw(img_features['image_raw'], tf.uint8)
image = tf.reshape(image, [42, 42])

label = tf.cast(img_features['Label'], tf.int64)

到這邊,我們成功地把 TF 檔案,還原成一開始包裝的原始資料。

這邊針對「tf.uint8」要特別說明一下:
夏恩在包裝的時候是把影像檔轉換成文字檔,所以反轉換要使用「tf.uint8」格式轉回影像檔。
之前曾不小心把「tf.uint8」寫成「tf.float32」,所以原本42*42的影像立刻縮水一半,
變成21*21...因為兩者大小相差一倍。這個 bug 太小,且非常不起眼,找很久才找到。

把上述的段落結合,可以寫成一支程式如下:

# -*- coding: utf-8 -*-
"""
作者:Shayne
程式簡介:有點難範例程式
"""

import cv2
import tensorflow as tf

# TF檔
filename = './py_Train.tfrecords'

# 產生文件名隊列
filename_queue = tf.train.string_input_producer([filename], 
                                                 shuffle=False, 
                                                 num_epochs=1)

# 數據讀取器
reader = tf.TFRecordReader()
key, serialized_example = reader.read(filename_queue)

# 數據解析
img_features = tf.parse_single_example(
            serialized_example,
            features={ 'Label'    : tf.FixedLenFeature([], tf.int64),
                       'image_raw': tf.FixedLenFeature([], tf.string), })
    
image = tf.decode_raw(img_features['image_raw'], tf.uint8)
image = tf.reshape(image, [42, 42])

label = tf.cast(img_features['Label'], tf.int64)

with tf.Session() as sess:
    # 初始化是必要的動作
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    
    # 建立執行緒協調器
    coord = tf.train.Coordinator()
    
    # 啟動文件隊列,開始讀取文件
    threads = tf.train.start_queue_runners(coord=coord)

    count = 0 
    try:
        # 讀 10 張影像
        while count<10:
            
            # 這邊讀取
            image_data, label_data = sess.run([image, label])
            
            # 這邊輸出
            # 因為已經經過解碼,二進制的資料已經轉換成影像檔,因此可以直接使用
            # 影像檔的方式輸出資料。
            cv2.imwrite('./tf_%d_%d.jpg' % (label_data, count), image_data)
            count += 1

        print('Done!')        
        
    except tf.errors.OutOfRangeError:
        print('Done!')

    finally:
        # 最後要記得把文件隊列關掉
        coord.request_stop()
    
    coord.join(threads)

這支程式執行後的結果如下,原本只有一個 TFRecord 檔案,執行後從中取出了 10 的影像出來。

到這邊就算是完成 TFRecord 檔案讀取了!
下一個章節,夏恩準備仔細聊聊,有關該如何連接 TF 檔與數字辨識模型。

【Python】TensorFlow學習筆記(四):用 TFRecord 餵食 Softmax Model