tensorflow系列教程之三:feed和fetch

本文主要了解tensorflow中feed和fetch的用法。

feed

在tensorflow中,为了节约内存,可以使用placeholder来声明一个占位符,声明时不分配内存,使用时才分配。对于数据量很大的计算图,这样可以提升运算性能。而如果在运行时分配内存呢,就是使用feed了,顾名思义,feed就是将值填充进占位符。

In [1]:
import tensorflow as tf
In [2]:
#例一:一个简单的placeholder和feed,输出“hello tensorflow”

#定义一个字符串类型的placeholder
var_str=tf.placeholder(tf.string)

with tf.Session() as sess:
    output=sess.run(var_str,feed_dict={var_str:"hello tensorflow"})
    print(output)
hello tensorflow
In [3]:
#例二:浮点数乘积示例

#定义两个浮点数类型的placeholder
var_mul_1=tf.placeholder(tf.float32)
var_mul_2=tf.placeholder(tf.float32)

#定义一个乘法op
mul=tf.multiply(var_mul_1,var_mul_2)

with tf.Session() as sess:
    output=sess.run(mul,feed_dict={var_mul_1:2.,var_mul_2:3.})
    print(output)
6.0
In [12]:
#例三:矩阵乘法

#定义两个矩阵
var_matrix_1=tf.placeholder(dtype=tf.float32,shape=[2,3])
var_matrix_2=tf.placeholder(dtype=tf.float32,shape=[3,2])

#定义矩阵乘法
matrix_mul=tf.matmul(var_matrix_1,var_matrix_2)

#tf.random_normal 从正态分布中输出随机值
matrix1=tf.random_normal(shape=[2,3])
#tf.truncated_normal 从截断的正态分布中输出随机值
matrix2=tf.truncated_normal(shape=[3,2])

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    m1=sess.run(matrix1)
    print('m1的值是\n{}\n'.format(m1))
    m2=sess.run(matrix2)
    print('m2的值是\n{}\n'.format(m2))
    output=sess.run(matrix_mul,feed_dict={var_matrix_1:m1,var_matrix_2:m2})
    print('矩阵乘法的结果是\n{}'.format(output))
    
m1的值是
[[-0.9469458  -1.0637448  -0.9293784 ]
 [-1.1452185  -0.06935152 -1.7329748 ]]

m2的值是
[[ 0.7995693  -1.5922078 ]
 [-0.25059786  0.28382644]
 [ 1.2194905  -0.37946603]]

矩阵乘法的结果是
[[-1.6239448  1.5584831]
 [-3.0116487  2.461347 ]]

fetch

Fetch操作是指TensorFlow的session可以一次run多个op 语法: 将多个op放入数组中然后传给run方法

In [20]:
input1 = tf.constant(3.0)
input2 = tf.constant(2.0)
input3 = tf.constant(5.0)

#定义两个op
add = tf.add(input2, input3)
mu1 = tf.multiply(input1, add)

with tf.Session() as sess:
    #一次操作两个op, 按顺序返回结果
    #分别fetch到多个变量中
    result1,result2 = sess.run((mu1, add))
    print(result1)
    print(result2)
    #将结果fetch到一个数组中
    result3 = sess.run([mu1, add])
    print(result3)
                               
21.0
7.0
[21.0, 7.0]