本文主要了解tensorflow中feed和fetch的用法。
feed¶
在tensorflow中,为了节约内存,可以使用placeholder来声明一个占位符,声明时不分配内存,使用时才分配。对于数据量很大的计算图,这样可以提升运算性能。而如果在运行时分配内存呢,就是使用feed了,顾名思义,feed就是将值填充进占位符。
In [1]:
import tensorflow as tf
In [2]:
#例一:一个简单的placeholder和feed,输出“hello tensorflow”
#定义一个字符串类型的placeholder
var_str=tf.placeholder(tf.string)
with tf.Session() as sess:
output=sess.run(var_str,feed_dict={var_str:"hello tensorflow"})
print(output)
In [3]:
#例二:浮点数乘积示例
#定义两个浮点数类型的placeholder
var_mul_1=tf.placeholder(tf.float32)
var_mul_2=tf.placeholder(tf.float32)
#定义一个乘法op
mul=tf.multiply(var_mul_1,var_mul_2)
with tf.Session() as sess:
output=sess.run(mul,feed_dict={var_mul_1:2.,var_mul_2:3.})
print(output)
In [12]:
#例三:矩阵乘法
#定义两个矩阵
var_matrix_1=tf.placeholder(dtype=tf.float32,shape=[2,3])
var_matrix_2=tf.placeholder(dtype=tf.float32,shape=[3,2])
#定义矩阵乘法
matrix_mul=tf.matmul(var_matrix_1,var_matrix_2)
#tf.random_normal 从正态分布中输出随机值
matrix1=tf.random_normal(shape=[2,3])
#tf.truncated_normal 从截断的正态分布中输出随机值
matrix2=tf.truncated_normal(shape=[3,2])
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
m1=sess.run(matrix1)
print('m1的值是\n{}\n'.format(m1))
m2=sess.run(matrix2)
print('m2的值是\n{}\n'.format(m2))
output=sess.run(matrix_mul,feed_dict={var_matrix_1:m1,var_matrix_2:m2})
print('矩阵乘法的结果是\n{}'.format(output))
fetch¶
Fetch操作是指TensorFlow的session可以一次run多个op 语法: 将多个op放入数组中然后传给run方法
In [20]:
input1 = tf.constant(3.0)
input2 = tf.constant(2.0)
input3 = tf.constant(5.0)
#定义两个op
add = tf.add(input2, input3)
mu1 = tf.multiply(input1, add)
with tf.Session() as sess:
#一次操作两个op, 按顺序返回结果
#分别fetch到多个变量中
result1,result2 = sess.run((mu1, add))
print(result1)
print(result2)
#将结果fetch到一个数组中
result3 = sess.run([mu1, add])
print(result3)