Tensorflow에서 훈련된 모델 저장 및 복원
Tensorflow에서 모델을 훈련한 후에는 모델을 보존하고 재사용하는 것이 중요합니다. 모델 저장을 효과적으로 처리하는 방법은 다음과 같습니다.
학습된 모델 저장(Tensorflow 버전 0.11 이상):
예제 코드:
import tensorflow as tf # Prepare input placeholders w1 = tf.placeholder("float", name="w1") w2 = tf.placeholder("float", name="w2") # Define test operation w3 = tf.add(w1, w2) w4 = tf.multiply(w3, tf.Variable(2.0, name="bias"), name="op_to_restore") # Initialize variables and run session sess = tf.Session() sess.run(tf.global_variables_initializer()) # Create saver object saver = tf.train.Saver() # Save the model saver.save(sess, 'my_test_model', global_step=1000)
저장된 내용 복원 모델:
예제 코드:
# Restore model saver = tf.train.import_meta_graph('my_test_model-1000.meta') saver.restore(sess, tf.train.latest_checkpoint('./')) # Get placeholders and feed data w1 = sess.graph.get_tensor_by_name("w1:0") w2 = sess.graph.get_tensor_by_name("w2:0") feed_dict = {w1: 13.0, w2: 17.0} # Run saved operation op_to_restore = sess.graph.get_tensor_by_name("op_to_restore:0") result = sess.run(op_to_restore, feed_dict)
위 내용은 TensorFlow에서 훈련된 모델을 효과적으로 저장하고 복원하는 방법은 무엇입니까?의 상세 내용입니다. 자세한 내용은 PHP 중국어 웹사이트의 기타 관련 기사를 참조하세요!