tf.train.SessionRunHook

这个方法是为了打印运行时的数据,方便debug,初学者理解

  1. import tensorflow as tf
  2. import numpy as np
  3. x = tf.placeholder(shape=(10, 2), dtype=tf.float32)
  4. w = tf.Variable(initial_value=[[10.], [10.]])
  5. w0 = [[1], [1.]]
  6. y = tf.matmul(x, w0)
  7. loss = tf.reduce_mean((tf.matmul(x, w) - y) ** 2)
  8. optimizer = tf.train.AdamOptimizer(0.001).minimize(loss)
  9. class _Hook(tf.train.SessionRunHook):
  10. def __init__(self, loss, w):
  11. self.loss = loss
  12. self.w = w
  13. def begin(self):
  14. pass
  15. def before_run(self, run_context):
  16. return tf.train.SessionRunArgs([self.loss, self.w]) #这里输入的[]对应run_values.results
  17. def after_run(self, run_context, run_values):
  18. print(type(run_values))
  19. print(run_values)
  20. loss_value = run_values.results[0]
  21. w_value = run_values.results[1]
  22. print("loss value:", loss_value)
  23. print("w:",w_value)
  24. #sess = tf.train.MonitoredSession(hooks=[_Hook(loss,w)])
  25. sess = tf.train.MonitoredTrainingSession(hooks=[_Hook(loss,w)])
  26. for _ in range(10):
  27. x_ = np.random.random((10, 2))
  28. sess.run(optimizer, {x: x_})

tf.nn.embedding_lookup_sparse

链接

加载meta文件,查看模型结构

  1. import tensorflow as tf
  2. from tensorflow.python.platform import gfile
  3. #这是从文件格式的meta文件加载模型
  4. graph = tf.get_default_graph()
  5. graphdef = graph.as_graph_def()
  6. # graphdef.ParseFromString(gfile.FastGFile("/data/TensorFlowAndroidMNIST/app/src/main/expert-graph.pb", "rb").read())
  7. # _ = tf.import_graph_def(graphdef, name="")
  8. _ = tf.train.import_meta_graph("./dsp_graph.meta")
  9. summary_write = tf.summary.FileWriter("./log" , graph)
  10. #然后tensorboard --logdir log即可

tf.add_to_collection

在一个计算图中,可以通过集合collection来管理不同类别的资源。可以通过tf.add_to_collection函数将资源计入一个或多个集合中,然后通过tf.get_collection获取一个集合里面的所有资源。tensorflow也自动管理来一些常用的集合如下
image.png

add:0的表示方式

下面脚本里面add:0表示result这个张量是计算节点”add”输出的第一个结果,编号从0开始
image.png

会话和计算图的关系

tf会自动生成一个默认图,不会自动生成默认会话

  1. import tensorflow as tf
  2. //使用默认的计算图
  3. a = tf.constant([1,2],name="a")
  4. b = tf.constant([3,4],name="b")
  5. result = a + b
  6. //自己创建计算图
  7. g1 = tf.Graph()
  8. with g1.as_default():
  9. a = tf.constant([1,2],name="a")
  10. b = tf.constant([3,4],name="b")
  11. result = a + b
  12. //使用默认计算图
  13. with tf.Session() as sess:
  14. print(sess.run(result))
  15. //使用自己创建的计算图
  16. with tf.Session(graph=g1) as sess:
  17. print(sess.run(result))