tf.train.SessionRunHook
这个方法是为了打印运行时的数据,方便debug,初学者理解
import tensorflow as tfimport numpy as npx = tf.placeholder(shape=(10, 2), dtype=tf.float32)w = tf.Variable(initial_value=[[10.], [10.]])w0 = [[1], [1.]]y = tf.matmul(x, w0)loss = tf.reduce_mean((tf.matmul(x, w) - y) ** 2)optimizer = tf.train.AdamOptimizer(0.001).minimize(loss)class _Hook(tf.train.SessionRunHook):def __init__(self, loss, w):self.loss = lossself.w = wdef begin(self):passdef before_run(self, run_context):return tf.train.SessionRunArgs([self.loss, self.w]) #这里输入的[]对应run_values.resultsdef after_run(self, run_context, run_values):print(type(run_values))print(run_values)loss_value = run_values.results[0]w_value = run_values.results[1]print("loss value:", loss_value)print("w:",w_value)#sess = tf.train.MonitoredSession(hooks=[_Hook(loss,w)])sess = tf.train.MonitoredTrainingSession(hooks=[_Hook(loss,w)])for _ in range(10):x_ = np.random.random((10, 2))sess.run(optimizer, {x: x_})
tf.nn.embedding_lookup_sparse
加载meta文件,查看模型结构
import tensorflow as tffrom tensorflow.python.platform import gfile#这是从文件格式的meta文件加载模型graph = tf.get_default_graph()graphdef = graph.as_graph_def()# graphdef.ParseFromString(gfile.FastGFile("/data/TensorFlowAndroidMNIST/app/src/main/expert-graph.pb", "rb").read())# _ = tf.import_graph_def(graphdef, name="")_ = tf.train.import_meta_graph("./dsp_graph.meta")summary_write = tf.summary.FileWriter("./log" , graph)#然后tensorboard --logdir log即可
tf.add_to_collection
在一个计算图中,可以通过集合collection来管理不同类别的资源。可以通过tf.add_to_collection函数将资源计入一个或多个集合中,然后通过tf.get_collection获取一个集合里面的所有资源。tensorflow也自动管理来一些常用的集合如下
add:0的表示方式
下面脚本里面add:0表示result这个张量是计算节点”add”输出的第一个结果,编号从0开始
会话和计算图的关系
tf会自动生成一个默认图,不会自动生成默认会话
import tensorflow as tf//使用默认的计算图a = tf.constant([1,2],name="a")b = tf.constant([3,4],name="b")result = a + b//自己创建计算图g1 = tf.Graph()with g1.as_default():a = tf.constant([1,2],name="a")b = tf.constant([3,4],name="b")result = a + b//使用默认计算图with tf.Session() as sess:print(sess.run(result))//使用自己创建的计算图with tf.Session(graph=g1) as sess:print(sess.run(result))
