TensorFlow1升TensorFlow2
参考:https://www.tensorflow.org/guide/migrate?hl=zh-cn
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()
cuda版本
Tensorflow与cuda版本的对应关系
Linux/MacOS: https://www.tensorflow.org/install/source?hl=zh-cn#linux
Windows: https://www.tensorflow.org/install/source_windows?hl=zh-cn
显卡驱动版本和cuda版本的对应关系
https://docs.nvidia.com/cuda/cuda-toolkit-release-notes/index.html
显存
默认情况下tensorflow的显卡只能跑一个程序,如果需要跑多个程序,需要指定per_process_gpu_memory_fraction设置占用多少显存,或者设置allow_growth=True不要一下占满显存
参考:https://www.jianshu.com/p/99fca5b7fd8a
# 指定GPU
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# 设置allow_growth,不要一下占满显存
configuration = tf.ConfigProto()
configuration.gpu_options.allow_growth = True # 需要多少用多少,但是不会释放显存,会造成内存碎片
configuration.gpu_options.allow_soft_placement=True # 如果你指定的设备不存在,允许TF自动分配设备
sess = tf.Session(config=configuration)
# 设置per_process_gpu_memory_fraction,占用显存的比例
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction = 0.7) # 只使用70%的显存
config = tf.ConfigProto(gpu_options = gpu_options)
config.gpu_options.allow_soft_placement=True # 如果你指定的设备不存在,允许TF自动分配设备
sess = tf.Session(config = config,....)
打印网络结构
summary_writer = tf.summary.FileWriter('./log/', sess.graph)
然后在shell中tensorboard --logdir=log --host=127.0.0.1