在jupyter notebook环境下如何调试tensorflow命令行参数

Tensorflow中可以使用flags来通过命令行动态更改代码中的参数,达到灵活调节模型中超参数的问题。

  • 定义参数

通过tf.flags.DEFINE_###()系列,括号内的三个参数分别为(参数名称,默认值,参数描述)

  1. import tensorflow as tf
  2. flags = tf.flags
  3. # Define a string
  4. flags.DEFINE_string("label_dir",label_dir,"")
  5. # Define a list
  6. class_list = ["liver","kidney"]
  7. flags.DEFINE_list("class_list",class_list,"")
  8. # Define a integer
  9. flags.DEFINE_integer("num_classes",len(class_list)+1,"")
  10. # Define a float
  11. flags.DEFINE_float("lr",3e-4,"")
  12. # Define a bool
  13. flags.DEFINE_bool("use_augment",False,"")
  • 调用参数
  1. config = flags.FLAGS
  • 在jupyter环境中测试

如果直接print(config.lr)此时会出现Unknown command line flag 'f'报错,需要加上如下命令。

  1. tf.app.flags.DEFINE_string('f', '', 'kernel')
  2. print(config.lr)

TFRecords格式数据集的制作和存储

.tfrecords格式文件是tensorflow标准数据集格式,适用于大量数据读取。

  • TFRecords格式文件的生成
  1. # record_file为TFRecords格式文件的存储路径
  2. writer = tf.python_io.TFRecordWriter(record_file)

TFRecord支持写入三种格式的数据:string,int64,float32,以列表的形式分别通过tf.train.BytesList,tf.train.Int64List,tf.train.FloatList 写入 tf.train.Feature。

  1. tf.train.BytesList(value=[value]) # value转化为字符串(二进制)列表
  2. tf.train.FloatList(value=[value]) # value转化为浮点型列表
  3. tf.train.Int64List(value=[value]) # value转化为整型列表
  4. # 内层特征编码方式
  5. feature_internal = {
  6. "A":tf.train.Feature(int64_list=tf.train.Int64List(value=[A])),
  7. "B":tf.train.Feature(float_list=tf.train.FloatList(value=[B])),
  8. "C":tf.train.Feature(bytes_list=tf.train.BytesList(value=[C]))
  9. }

tf.train.Feature生成协议信息通过内层和外层这两层Features存储信息。对于图像文件,其像素点矩阵不能直接存储,需要转换为string字符型进行存储。以下演示如何存储DICOM格式的医学图像及其对应的ID号。

  1. # 使用pydicom库对DICOM格式文件进行读取
  2. # 具体细节可参见https://www.yuque.com/oliviagao/iggd2c/xmxrse
  3. import pydicom as dicom
  4. array_dicom = image_ds.pixel_array
  5. rescaleIntercept = np.int(image_ds.RescaleIntercept)
  6. rescaleSlope = np.int(image_ds.RescaleSlope)
  7. array_ct = array_dicom*rescaleSlope + rescaleIntercept
  8. array_img = setDicomWinWidthWinCenter(array_ct,400,60,512,512)
  9. #array数据类型强行转换,注意要保持存储和读取时,前后一致
  10. array_img = array_img.astype(np.int16)
  11. #使用tostring将图像矩阵转换为string格式进行存储
  12. image = array_img.tostring()
  13. #ID号转换为int格式
  14. patient_id = int(image_ds.PatientID)
  15. # 内层特征,使用f.train.Feature构建
  16. feature_dict = {
  17. "image": tf.train.Feature(bytes_list=tf.train.BytesList(value=[image])),
  18. "id": tf.train.Feature(int64_list=tf.train.Int64List(value=[patient_id]))
  19. }
  20. # 外层特征,使用tf.train.Features构建,将内层字典特征编码
  21. features = tf.train.Features(feature=feature_dict)
  22. # 用tf.train.Example将features编码数据封装成特定的PB协议格式
  23. example = tf.train.Example(features=features)
  24. # 将example数据系列化为字符串,并且将系列化为字符串的example数据写入协议缓冲区
  25. writer.write(example.SerializeToString())


  • TFRecord格式文件的读取

parser解析函数的作用是把数据从TFRecord文件中解析出来。对于image array, 生成文件时使用了tostring()进行编码,在文件读取步骤时需要decode_raw()进行解码,注意⚠️一定要保持数据类型前后一致,否则会造成shape前后发生变化。例如512512尺寸的image array数据类型为int32,进行编码,如果之后读取时以int16数据类型解码,shape会变大,不在时512512。

  1. image_size = 512
  2. def parser(example):
  3. features = {
  4. "image": tf.FixedLenFeature([], tf.string),
  5. "id": tf.FixedLenFeature([], tf.int64)
  6. }
  7. features = tf.parse_single_example(example, features)
  8. shape = tf.stack([image_size, image_size, 1])
  9. # tf.decode_raw()步骤中的类型值特别重要!!!!!
  10. # 一定要和生成TFRecord文件时的arrary类型保持一致,否则shape会发生变化,造成报错
  11. image = tf.decode_raw(features["image"], tf.int16)
  12. image = tf.reshape(image, shape)
  13. image = tf.cast(image, tf.float32)
  14. image = tf.concat([image] * 3, axis=-1)
  15. patient_id = tf.cast(features["id"], tf.int64)
  16. return image,patient_id

map(parser)表示对数据集中的每一条数据调用parser方法。dataset.batch函数是设置一次取出的batch的大小。dataset.repeat函数表示取全部数据几个epoch

  1. dataset = tf.data.TFRecordDataset(record_file).\
  2. map(parser).\
  3. batch(1).\
  4. repeat()
  5. predict_iterator = dataset.make_one_shot_iterator()
  6. with tf.Session() as sess:
  7. image_test,mask_test,id_test = sess.run(predict_iterator.get_next())

Tensorboard的对模型的可视化

在终端中输入以下命令:
tensorboard --logdir="./tensorboard"
会给出指定的localhost端口,比如http://localhost:6006

PS: IP地址查询

https://www.ipaddress.com/