tf.train.Coordinator and the queue-runner starter tf.train.start_queue_runners in TensorFlow






## Example of operating a queue and a coordinator in TensorFlow with tf.train.Coordinator and tf.train.start_queue_runners:

```python
# -*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np

# Number of samples
sample_num = 5
# Number of epochs
epoch_num = 2
# Number of samples in a batch
batch_size = 3
# Number of batches contained in each epoch
batch_total = int(sample_num / batch_size) + 1

# Generate sample_num random images and their labels
def generate_data(sample_num=sample_num):
    labels = np.asarray(range(0, sample_num))
    images = np.random.random([sample_num, 224, 224, 3])
    print('image size {}, label size: {}'.format(images.shape, labels.shape))
    return images, labels

def get_batch_data(batch_size=batch_size):
    images, label = generate_data()
    # Convert the data to TensorFlow dtypes
    images = tf.cast(images, tf.float32)
    label = tf.cast(label, tf.int32)
    # Slice the tensor list sequentially (or randomly) into the filename queue
    input_queue = tf.train.slice_input_producer([images, label],
                                                num_epochs=epoch_num,
                                                shuffle=False)
    # Read from the filename queue and assemble batches
    image_batch, label_batch = tf.train.batch(input_queue,
                                              batch_size=batch_size,
                                              num_threads=2,
                                              capacity=64,
                                              allow_smaller_final_batch=False)
    return image_batch, label_batch

image_batch, label_batch = get_batch_data(batch_size=batch_size)

with tf.Session() as sess:
    # Initialize first (num_epochs is backed by a local variable)
    sess.run(tf.global_variables_initializer())
    sess.run(tf.local_variables_initializer())
    # Start a coordinator
    coord = tf.train.Coordinator()
    # Use start_queue_runners to start the queue-filling threads
    threads = tf.train.start_queue_runners(sess, coord)
    try:
        while not coord.should_stop():
            print('************')
            # Fetch batch_size samples and labels in each batch
            image_batch_v, label_batch_v = sess.run([image_batch, label_batch])
            print(image_batch_v.shape, label_batch_v)
    except tf.errors.OutOfRangeError:
        # Raised when the end of the file queue is reached
        print('done! now lets kill all the threads……')
    finally:
        # The coordinator sends the stop signal to all threads
        coord.request_stop()
        print('all threads are asked to stop!')
        coord.join(threads)  # Wait for the queue threads to join the main thread and finish
        print('all threads are stopped!')
```

Output:



```
************
((3, 224, 224, 3), array([0, 1, 2], dtype=int32))
************
((3, 224, 224, 3), array([3, 4, 0], dtype=int32))
************
((3, 224, 224, 3), array([1, 2, 3], dtype=int32))
************
done! now lets kill all the threads……
all threads are asked to stop!
all threads are stopped!
```
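The batch count in this output follows from the example's parameters: the input queue holds `sample_num * epoch_num = 10` slices, and with `allow_smaller_final_batch=False` the incomplete tail is discarded, so only three full batches of 3 come out before `OutOfRangeError`. The following plain-Python sketch (no TensorFlow, just the arithmetic) reproduces the label pattern seen above:

```python
# Why only three batches appear: the queue holds sample_num * epoch_num
# slices, and allow_smaller_final_batch=False drops the incomplete tail.
sample_num = 5
epoch_num = 2
batch_size = 3

total_slices = sample_num * epoch_num       # 10 slices in the queue
full_batches = total_slices // batch_size   # 3 complete batches
leftover = total_slices % batch_size        # 1 slice is discarded

# With shuffle=False the label stream simply wraps across epochs:
# 0 1 2 | 3 4 0 | 1 2 3 | (4 is left over)
labels = [i % sample_num for i in range(total_slices)]
batches = [labels[i * batch_size:(i + 1) * batch_size]
           for i in range(full_batches)]
print(full_batches, leftover)  # 3 1
print(batches)                 # [[0, 1, 2], [3, 4, 0], [1, 2, 3]]
```

Note that the batch boundaries do not align with epoch boundaries, which is why the second batch `[3, 4, 0]` mixes the end of epoch 1 with the start of epoch 2.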

Because the program above sets `num_epochs` in `tf.train.slice_input_producer`, an end-of-queue marker is placed at the end of the file queue. When this marker is read, an `OutOfRangeError` is raised, which lets each thread be terminated cleanly.
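The stop/join protocol that `tf.train.Coordinator` provides can be illustrated with only the standard library. The sketch below is an analogy, not TensorFlow's actual implementation: a `threading.Event` plays the role of `coord.should_stop()`/`coord.request_stop()`, a producer thread plays the role of a queue runner, and `t.join()` corresponds to `coord.join(threads)`:

```python
# Minimal sketch of the coordinator stop/join pattern using only the
# standard library (an analogy to tf.train.Coordinator, not its code).
import threading
import queue

stop_event = threading.Event()   # plays the role of coord.should_stop()
q = queue.Queue(maxsize=4)

def filler(n):
    # Like a queue runner: keep filling the queue until asked to stop.
    for i in range(n):
        while not stop_event.is_set():
            try:
                q.put(i, timeout=0.1)  # time out so a stop request is noticed
                break
            except queue.Full:
                continue
        if stop_event.is_set():
            return

def consume(n):
    results = [q.get() for _ in range(n)]
    stop_event.set()             # like coord.request_stop()
    return results

t = threading.Thread(target=filler, args=(100,))
t.start()
out = consume(5)
t.join()                         # like coord.join(threads)
print(out)                       # [0, 1, 2, 3, 4]
```

The timeout on `q.put` matters: without it, a producer blocked on a full queue would never observe the stop signal and `t.join()` would hang, which is exactly the kind of shutdown race the real Coordinator handles for you.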

Original link: https://blog.csdn.net/dcrmg/article/details/79780331