如果深層神經(jīng)網(wǎng)絡(luò)模型的復(fù)雜度非常高的話,那么訓(xùn)練它可能需要相當(dāng)長的一段時間,當(dāng)然這也取決于你擁有的數(shù)據(jù)量,運行模型的硬件等等。在大多數(shù)情況下,你需要通過保存文件來保障你試驗的穩(wěn)定性,防止如果中斷(或一個錯誤),你能夠繼續(xù)從沒有錯誤的地方開始。 更重要的是,對于任何深度學(xué)習(xí)的框架,像TensorFlow,在成功的訓(xùn)練之后,你需要重新使用模型的學(xué)習(xí)參數(shù)來完成對新數(shù)據(jù)的預(yù)測。 在這篇文章中,我們來看一下如何保存和恢復(fù)TensorFlow模型,我們在此介紹一些最有用的方法,并提供一些例子。 1.首先我們將快速介紹TensorFlow模型 TensorFlow的主要功能是通過張量來傳遞其基本數(shù)據(jù)結(jié)構(gòu)類似于NumPy中的多維數(shù)組,而圖表則表示數(shù)據(jù)計算。它是一個符號庫,這意味著定義圖形和張量將僅創(chuàng)建一個模型,而獲取張量的具體值和操作將在會話(session)中執(zhí)行,會話(session)一種在圖中執(zhí)行建模操作的機制。會話關(guān)閉時,張量的任何具體值都會丟失,這也是運行會話后將模型保存到文件的另一個原因。 通過示例可以幫助我們更容易理解,所以讓我們?yōu)槎S數(shù)據(jù)的線性回歸創(chuàng)建一個簡單的TensorFlow模型。 首先,我們將導(dǎo)入我們的庫: import tensorflow as tf import numpy as np import matplotlib.pyplot as plt %matplotlib inline 下一步是創(chuàng)建模型。我們將生成一個模型,它將以以下的形式估算二次函數(shù)的水平和垂直位移: y = (x - h) ^ 2 + v 其中 以下是如何生成模型的過程(有關(guān)詳細(xì)信息,請參閱代碼中的注釋): # Clear the current graph in each run, to avoid variable duplicationtf.reset_default_graph()# Create placeholders for the x and y pointsX = tf.placeholder('float') Y = tf.placeholder('float')# Initialize the two parameters that need to be learnedh_est = tf.Variable(0.0, name='hor_estimate') v_est = tf.Variable(0.0, name='ver_estimate')# y_est holds the estimated values on y-axisy_est = tf.square(X - h_est) + v_est# Define a cost function as the squared distance between Y and y_estcost = (tf.pow(Y - y_est, 2))# The training operation for minimizing the cost function. The# learning rate is 0.001trainop = tf.train.GradientDescentOptimizer(0.001).minimize(cost) 在創(chuàng)建模型的過程中,我們需要有一個在會話中運行的模型,并且傳遞一些真實的數(shù)據(jù)。我們生成一些二次數(shù)據(jù)(Quadratic data),并給他們添加噪聲。 # Use some values for the horizontal and vertical shifth = 1 v = -2# Generate training data with noisex_train = np.linspace(-2,4,201) noise = np.random.randn(*x_train.shape) * 0.4 y_train = (x_train - h) ** 2 + v + noise# Visualize the data plt.rcParams['figure.figsize'] = (10, 6) plt.scatter(x_train, y_train) plt.xlabel('x_train') plt.ylabel('y_train') 2.The Saver class
2.1保存模型 在以下幾行代碼中,我們定義一個 # Create a Saver objectsaver = tf.train.Saver()init = tf.global_variables_initializer()# Run a session. Go through 100 iterations to minimize the costdef train_graph(): with tf.Session() as sess: sess.run(init) for i in range(100): for (x, y) in zip(x_train, y_train): # Feed actual data to the train operation sess.run(trainop, feed_dict={X: x, Y: y}) # Create a checkpoint in every iteration saver.save(sess, 'model_iter', global_step=i) # Save the final model saver.save(sess, 'model_final') h_ = sess.run(h_est) v_ = sess.run(v_est) return h_, v_ 現(xiàn)在讓我們用上述功能訓(xùn)練模型,并打印出訓(xùn)練的參數(shù)。 result = train_graph() print('h_est = %.2f, v_est = %.2f' % result) $ python tf_save.pyh_est = 1.01, v_est = -1.96 Okay,參數(shù)是非常準(zhǔn)確的。如果我們檢查我們的文件系統(tǒng),最后4次迭代中保存有文件以及最終的模型。 保存模型時,你會注意到需要4種類型的文件才能保存: “.meta”文件:包含圖形結(jié)構(gòu)。 “.data”文件:包含變量的值。 “.index”文件:標(biāo)識檢查點。 “checkpoint”文件:具有最近檢查點列表的協(xié)議緩沖區(qū)。 圖1:檢查點文件保存到磁盤 調(diào)用
如果你想要了解更多信息,請查看官方文檔的 3.Restoring Models 恢復(fù)TensorFlow模型時要做的第一件事就是將圖形結(jié)構(gòu)從“.meta”文件加載到當(dāng)前圖形中。 tf.reset_default_graph() imported_meta = tf.train.import_meta_graph('model_final.meta') 也可以使用以下命令探索當(dāng)前圖形tf.get_default_graph()。接著第二步是加載變量的值。提醒:值僅存在于會話(session)中。 with tf.Session() as sess: imported_meta.restore(sess, tf.train.latest_checkpoint('./')) h_est2 = sess.run('hor_estimate:0') v_est2 = sess.run('ver_estimate:0') print('h_est: %.2f, v_est: %.2f' % (h_est2, v_est2)) $ python tf_restore.pyINFO:tensorflow:Restoring parameters from ./model_final h_est: 1.01, v_est: -1.96 如前面所提到的,這種方法只保存圖形結(jié)構(gòu)和變量,這意味著通過占位符“X”和“Y”輸入的訓(xùn)練數(shù)據(jù)不會被保存。 無論如何,在這個例子中,我們將使用我們定義的訓(xùn)練數(shù)據(jù)tf,并且可視化模型擬合。 plt.scatter(x_train, y_train, label='train data') plt.plot(x_train, (x_train - h_est2) ** 2 + v_est2, color='red', label='model') plt.xlabel('x_train') plt.ylabel('y_train') plt.legend()
4.SavedModel格式(Format) 在TensorFlow中保存和恢復(fù)模型的一種新方法是使用SavedModel,Builder和loader功能。這個方法實際上是 雖然這種 4.1使用SavedModel Builder保存模型 接下來我們嘗試使用 tf.reset_default_graph()# Re-initialize our two variablesh_est = tf.Variable(h_est2, name='hor_estimate2') v_est = tf.Variable(v_est2, name='ver_estimate2')# Create a builderbuilder = tf.saved_model.builder.SavedModelBuilder('./SavedModel/')# Add graph and variables to builder and savewith tf.Session() as sess: sess.run(h_est.initializer) sess.run(v_est.initializer) builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING], signature_def_map=None, assets_collection=None)builder.save() $ python tf_saved_model_builder.pyINFO:tensorflow:No assets to save. INFO:tensorflow:No assets to write. INFO:tensorflow:SavedModel written to: b'./SavedModel/saved_model.pb' 運行此代碼時,你會注意到我們的模型已保存到位于“./SavedModel/saved_model.pb”的文件中。 4.2使用SavedModel Loader程序恢復(fù)模型 模型恢復(fù)使用 在下面的例子中,我們將加載模型,并打印出我們的兩個系數(shù)( with tf.Session() as sess: tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], './SavedModel/') h_est = sess.run('hor_estimate2:0') v_est = sess.run('ver_estimate2:0') print('h_est: %.2f, v_est: %.2f' % (h_est, v_est)) $ python tf_saved_model_loader.pyINFO:tensorflow:Restoring parameters from b'./SavedModel/variables/variables' h_est: 1.01, v_est: -1.96 5.結(jié)論 如果你知道你的深度學(xué)習(xí)網(wǎng)絡(luò)的訓(xùn)練可能會花費很長時間,保存和恢復(fù)TensorFlow模型是非常有用的功能。該主題太廣泛,無法在一篇博客文章中詳細(xì)介紹。不管怎樣,在這篇文章中我們介紹了兩個工具: 作者信息 作者:Mihajlo Pavloski,數(shù)據(jù)科學(xué)與機器學(xué)習(xí)的愛好者,博士生。 本文由阿里云云棲社區(qū)組織翻譯。 文章原標(biāo)題《TensorFlow : Saveand Restore Models》 作者:Mihajlo Pavloski 譯者:虎說八道,審閱: |
|