这篇博客讲的是MNIST数据集的读取,和Fashion_MNIST数据集的读取异曲同工,原文链接

四个文件

- train-images-idx3-ubyte.gz:训练集图片(training set images)
- train-labels-idx1-ubyte.gz:训练集标签(training set labels)
- t10k-images-idx3-ubyte.gz:测试集图片(test set images)
- t10k-labels-idx1-ubyte.gz:测试集标签(test set labels)

将上述的四个文件下载下来,然后解压,得到

train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte

原博客有一堆巴拉巴拉的讲了一堆的数据意思,我看不下去,就直接上代码了

读取

  1. from __future__ import division
  2. from __future__ import print_function
  3. import os
  4. import sys
  5. import struct
  6. file_list = [
  7. "train-images-idx3-ubyte",
  8. "train-labels-idx1-ubyte",
  9. "t10k-images-idx3-ubyte",
  10. "t10k-labels-idx1-ubyte",
  11. ]
  12. '''
  13. 如果不是已有的路径,则创建路径
  14. '''
  15. def create_path(path):
  16. if not os.path.isdir(path):
  17. os.makedirs(path)
  18. '''
  19. judge weather the last character of the path is /?
  20. if yes, return "path"+"name"
  21. if no ,return "path"+"/"+"name"
  22. '''
  23. def get_file_full_name(path, name):
  24. create_path(path)
  25. if path[-1] == "/":
  26. full_name = path + name
  27. else:
  28. full_name = path + "/" + name
  29. return full_name
  30. def read_mnist(file_name):
  31. file_path = "D:\\谷歌下载\\fashion-mnist-master\\data\\fashion"
  32. full_path = get_file_full_name(file_path, file_name)
  33. file_object = open(full_path, 'rb')
  34. return file_object
  35. '''
  36. the key part: I can't understand the goal of unpack,what does mean?
  37. '''
  38. def get_file_header_data(file_name, header_len, unpack_str):
  39. f = read_mnist(file_name)
  40. raw_header = f.read(header_len)
  41. header_data = struct.unpack(unpack_str, raw_header)
  42. return header_data
  43. def show_images_file_header(file_name):
  44. show_file_header(file_name, 16, ">4I")
  45. def show_labels_file_header(file_name):
  46. show_file_header(file_name, 8, ">2I")
  47. def show_file_header(file_name, header_len, unpack_str):
  48. header_data = get_file_header_data(file_name, header_len, unpack_str)
  49. print("%s header data:%s" % (file_name, header_data))
  50. def show_mnist_file_header():
  51. train_images_file_name = file_list[0]
  52. show_images_file_header(train_images_file_name)
  53. test_images_file_name = file_list[2]
  54. show_images_file_header(test_images_file_name)
  55. train_labels_file_name = file_list[1]
  56. show_labels_file_header(train_labels_file_name)
  57. test_labels_file_name = file_list[3]
  58. show_labels_file_header(test_labels_file_name)
  59. def run():
  60. show_mnist_file_header()
  61. run()

输出:

train-images-idx3-ubyte header data:(2051, 60000, 28, 28)

t10k-images-idx3-ubyte header data:(2051, 10000, 28, 28)

train-labels-idx1-ubyte header data:(2049, 60000)

t10k-labels-idx1-ubyte header data:(2049, 10000)

下面我们读取一张图片,并且展示图片和它的标记

  1. from __future__ import division
  2. from __future__ import print_function
  3. #gunzip *.gz
  4. #http://yann.lecun.com/exdb/mnist/
  5. import os
  6. import sys
  7. import struct
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. from PIL import Image
  11. file_list = [
  12. "train-images-idx3-ubyte",
  13. "train-labels-idx1-ubyte",
  14. "t10k-images-idx3-ubyte",
  15. "t10k-labels-idx1-ubyte",
  16. ]
  17. def create_path(path):
  18. if not os.path.isdir(path):
  19. os.makedirs(path)
  20. def get_file_full_name(path, name):
  21. create_path(path)
  22. if path[-1] == "/":
  23. full_name = path + name
  24. else:
  25. full_name = path + "/" + name
  26. return full_name
  27. def read_mnist(file_name):
  28. file_path = "/home/your/data/path"
  29. full_path = get_file_full_name(file_path, file_name)
  30. file_object = open(full_path, 'rb') #python3 need rb python2 r is ok
  31. return file_object
  32. def get_file_header_data(file_obj, header_len, unpack_str):
  33. raw_header = file_obj.read(header_len)
  34. header_data = struct.unpack(unpack_str, raw_header)
  35. return header_data
  36. def show_images_file_header(file_name):
  37. show_file_header(file_name, 16, ">4I")
  38. def show_labels_file_header(file_name):
  39. show_file_header(file_name, 8, ">2I")
  40. def show_file_header(file_name, header_len, unpack_str):
  41. file_obj = read_mnist(file_name)
  42. header_data = get_file_header_data(file_obj, header_len, unpack_str)
  43. show_file_header_data(file_name, header_data)
  44. file_obj.close()
  45. def show_mnist_file_header():
  46. train_images_file_name = file_list[0]
  47. show_images_file_header(train_images_file_name)
  48. test_images_file_name = file_list[2]
  49. show_images_file_header(test_images_file_name)
  50. train_labels_file_name = file_list[1]
  51. show_labels_file_header(train_labels_file_name)
  52. test_labels_file_name = file_list[3]
  53. show_labels_file_header(test_labels_file_name)
  54. def read_a_image(file_object):
  55. img = file_object.read(28*28)
  56. tp = struct.unpack(">784B",img)
  57. image = np.asarray(tp)
  58. image = image.reshape((28,28))
  59. #image = image.astype(np.float64)
  60. plt.imshow(image,cmap = plt.cm.gray)
  61. plt.show()
  62. def read_a_label(file_object):
  63. img = file_object.read(1)
  64. tp = struct.unpack(">B",img)
  65. print("the label is :%s" % tp[0])
  66. def show_file_header_data(file_name,header_data):
  67. print("%s header data:%s" % (file_name, header_data))
  68. def show_a_image():
  69. images_file_name = file_list[0]
  70. labels_file_name = file_list[1]
  71. images_file = read_mnist(images_file_name)
  72. header_data = get_file_header_data(images_file, 16, ">4I")
  73. show_file_header_data(images_file_name, header_data)
  74. labels_file = read_mnist(labels_file_name)
  75. header_data = get_file_header_data(labels_file, 8, ">2I")
  76. show_file_header_data(labels_file_name, header_data)
  77. read_a_image(images_file)
  78. read_a_label(labels_file)
  79. def run():
  80. #show_mnist_file_header()
  81. show_a_image()
  82. run()

输出:

train-images-idx3-ubyte header data:(2051, 60000, 28, 28)

train-labels-idx1-ubyte header data:(2049, 60000)

the label is :9

Figure_1.png
由标记对应图片的结果,看到

Label Description
0 T-shirt/top
1 Trouser
2 Pullover
3 Dress
4 Coat
5 Sandal
6 Shirt
7 Sneaker
8 Bag
9 Ankle boot

可以看到,输出的标签 9 对应上表中的 Ankle boot,与显示的图片一致。
然后我们修改成能自动生成批数据

  1. from __future__ import division
  2. from __future__ import print_function
  3. #gunzip *.gz
  4. #http://yann.lecun.com/exdb/mnist/
  5. import os
  6. import sys
  7. import struct
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. from PIL import Image
  11. file_list = [
  12. "train-images-idx3-ubyte",
  13. "train-labels-idx1-ubyte",
  14. "t10k-images-idx3-ubyte",
  15. "t10k-labels-idx1-ubyte",
  16. ]
  17. def show_images_file_header(file_name):
  18. show_file_header(file_name, 16, ">4I")
  19. def show_labels_file_header(file_name):
  20. show_file_header(file_name, 8, ">2I")
  21. def show_file_header(file_name, header_len, unpack_str):
  22. file_obj = read_mnist(file_name)
  23. header_data = get_file_header_data(file_obj, header_len, unpack_str)
  24. show_file_header_data(file_name, header_data)
  25. file_obj.close()
  26. def show_mnist_file_header():
  27. train_images_file_name = file_list[0]
  28. show_images_file_header(train_images_file_name)
  29. test_images_file_name = file_list[2]
  30. show_images_file_header(test_images_file_name)
  31. train_labels_file_name = file_list[1]
  32. show_labels_file_header(train_labels_file_name)
  33. test_labels_file_name = file_list[3]
  34. show_labels_file_header(test_labels_file_name)
  35. def show_a_image(file_object):
  36. image = read_a_image(images_file)
  37. image = np.asarray(tp)
  38. image = image.reshape((28,28))
  39. plt.imshow(image,cmap = plt.cm.gray)
  40. plt.show()
  41. def show_a_lebel(file_object):
  42. tp = read_a_label(file_object)
  43. print("the label is :%s" % tp)
  44. def show_file_header_data(file_name,header_data):
  45. print("%s header data:%s" % (file_name, header_data))
  46. def show_a_image():
  47. images_file_name = file_list[0]
  48. labels_file_name = file_list[1]
  49. images_file = read_mnist(images_file_name)
  50. header_data = get_file_header_data(images_file, 16, ">4I")
  51. show_file_header_data(images_file_name, header_data)
  52. labels_file = read_mnist(labels_file_name)
  53. header_data = get_file_header_data(labels_file, 8, ">2I")
  54. show_file_header_data(labels_file_name, header_data)
  55. show_a_image(images_file)
  56. read_a_label(labels_file)
  57. def create_path(path):
  58. if not os.path.isdir(path):
  59. os.makedirs(path)
  60. def get_file_full_name(path, name):
  61. create_path(path)
  62. if path[-1] == "/":
  63. full_name = path + name
  64. else:
  65. full_name = path + "/" + name
  66. return full_name
  67. def read_mnist(file_name):
  68. file_path = "/home/your/data/path"
  69. full_path = get_file_full_name(file_path, file_name)
  70. file_object = open(full_path, 'rb') #python3 need rb python2 r is ok
  71. return file_object
  72. def get_file_header_data(file_obj, header_len, unpack_str):
  73. raw_header = file_obj.read(header_len)
  74. header_data = struct.unpack(unpack_str, raw_header)
  75. return header_data
  76. def read_a_image(file_object):
  77. raw_img = file_object.read(28*28)
  78. img = struct.unpack(">784B",raw_img)
  79. return img
  80. def read_a_label(file_object):
  81. raw_label = file_object.read(1)
  82. label = struct.unpack(">B",raw_label)
  83. return label
  84. def generate_a_batch(images_file_name,labels_file_name,batch_size=8):
  85. images_file = read_mnist(images_file_name)
  86. header_data = get_file_header_data(images_file, 16, ">4I")
  87. #show_file_header_data(images_file_name, header_data)
  88. labels_file = read_mnist(labels_file_name)
  89. header_data = get_file_header_data(labels_file, 8, ">2I")
  90. #show_file_header_data(labels_file_name, header_data)
  91. while True:
  92. images = []
  93. labels = []
  94. for i in range(100):
  95. try:
  96. image = read_a_image(images_file)
  97. label = read_a_label(labels_file)
  98. images.append(image)
  99. labels.append(label)
  100. except Exception as err:
  101. print(err)
  102. break
  103. yield images,labels
  104. def get_train_data_generator():
  105. images_file_name = file_list[0]
  106. labels_file_name = file_list[1]
  107. gennerator = generate_a_batch(images_file_name,labels_file_name)
  108. return gennerator-
  109. def get_test_data_generator():
  110. images_file_name = file_list[2]
  111. labels_file_name = file_list[3]
  112. gennerator = generate_a_batch(images_file_name,labels_file_name)
  113. return gennerator
  114. def get_test_data_generator():
  115. images_file_name = file_list[2]
  116. labels_file_name = file_list[3]
  117. gennerator = generate_a_batch(images_file_name,labels_file_name)
  118. return gennerator-
  119. def get_a_batch(data_generator):
  120. if sys.version >'3':
  121. batch_img, batch_labels = data_generator.__next__()
  122. else:
  123. batch_img, batch_labels = data_generator.next()
  124. return batch_img,batch_labels
  125. def generate_test_batch():
  126. data_generator = get_test_data_generator()
  127. count = 1
  128. while count:
  129. batch_img,batch_labels = get_a_batch(data_generator)
  130. if not batch_img and not batch_labels:
  131. break
  132. batch_img = np.array(batch_img)
  133. batch_labels = np.array(batch_labels)
  134. print("img shape:%s label shape:%s count:%s" %(batch_img.shape,batch_labels.shape,count))
  135. count +=1
  136. def generate_train_batch():
  137. epoch = 0
  138. while epoch<10:
  139. epoch += 1
  140. data_generator = get_train_data_generator()
  141. count = 1
  142. while count:
  143. batch_img,batch_labels = get_a_batch(data_generator)
  144. if not batch_img and not batch_labels:
  145. break
  146. batch_img = np.array(batch_img)
  147. batch_labels = np.array(batch_labels)
  148. print("epoch:%s img shape:%s label shape:%s count:%s" %(epoch,batch_img.shape,batch_labels.shape,count))
  149. count +=1
  150. def run():
  151. generate_train_batch()
  152. generate_test_batch()
  153. run()

上面的代码好多没有用的代码,把没有用的代码删掉,我们得到

  1. from __future__ import division
  2. from __future__ import print_function
  3. #gunzip *.gz
  4. #http://yann.lecun.com/exdb/mnist/
  5. import os
  6. import sys
  7. import struct
  8. import numpy as np
  9. import matplotlib.pyplot as plt
  10. from PIL import Image
  11. file_list = [
  12. "train-images-idx3-ubyte",
  13. "train-labels-idx1-ubyte",
  14. "t10k-images-idx3-ubyte",
  15. "t10k-labels-idx1-ubyte",
  16. ]
  17. def create_path(path):
  18. if not os.path.isdir(path):
  19. os.makedirs(path)
  20. def get_file_full_name(path, name):
  21. create_path(path)
  22. if path[-1] == "/":
  23. full_name = path + name
  24. else:
  25. full_name = path + "/" + name
  26. return full_name
  27. def read_mnist(file_name):
  28. file_path = "/home/your/data/path"
  29. full_path = get_file_full_name(file_path, file_name)
  30. file_object = open(full_path, 'rb') #python3 need rb python2 r is ok
  31. return file_object
  32. def get_file_header_data(file_obj, header_len, unpack_str):
  33. raw_header = file_obj.read(header_len)
  34. header_data = struct.unpack(unpack_str, raw_header)
  35. return header_data
  36. def read_a_image(file_object):
  37. raw_img = file_object.read(28*28)
  38. img = struct.unpack(">784B",raw_img)
  39. return img
  40. def read_a_label(file_object):
  41. raw_label = file_object.read(1)
  42. label = struct.unpack(">B",raw_label)
  43. return label
  44. def generate_a_batch(images_file_name,labels_file_name,batch_size=8):
  45. images_file = read_mnist(images_file_name)
  46. header_data = get_file_header_data(images_file, 16, ">4I")
  47. labels_file = read_mnist(labels_file_name)
  48. header_data = get_file_header_data(labels_file, 8, ">2I")
  49. while True:
  50. images = []
  51. labels = []
  52. for i in range(100):
  53. try:
  54. image = read_a_image(images_file)
  55. label = read_a_label(labels_file)
  56. images.append(image)
  57. labels.append(label)
  58. except Exception as err:
  59. print(err)
  60. break
  61. yield images,labels
  62. def get_train_data_generator():
  63. images_file_name = file_list[0]
  64. labels_file_name = file_list[1]
  65. gennerator = generate_a_batch(images_file_name,labels_file_name)
  66. return gennerator
  67. def get_test_data_generator():
  68. images_file_name = file_list[2]
  69. labels_file_name = file_list[3]
  70. gennerator = generate_a_batch(images_file_name,labels_file_name)
  71. return gennerator
  72. def get_a_batch(data_generator):
  73. if sys.version >'3':
  74. batch_img, batch_labels = data_generator.__next__()
  75. else:
  76. batch_img, batch_labels = data_generator.next()
  77. return batch_img,batch_labels
  78. def generate_test_batch():
  79. data_generator = get_test_data_generator()
  80. count = 1
  81. while count:
  82. batch_img,batch_labels = get_a_batch(data_generator)
  83. if not batch_img and not batch_labels:
  84. break
  85. batch_img = np.array(batch_img)
  86. batch_labels = np.array(batch_labels)
  87. print("img shape:%s label shape:%s count:%s" %(batch_img.shape,batch_labels.shape,count))
  88. count +=1
  89. def generate_train_batch():
  90. epoch = 0
  91. while epoch<10:
  92. epoch += 1
  93. data_generator = get_train_data_generator()
  94. count = 1
  95. while count:
  96. batch_img,batch_labels = get_a_batch(data_generator)
  97. if not batch_img and not batch_labels:
  98. break
  99. batch_img = np.array(batch_img)
  100. batch_labels = np.array(batch_labels)
  101. print("epoch:%s img shape:%s label shape:%s count:%s" %(epoch,batch_img.shape,batch_labels.shape,count))
  102. count +=1
  103. def run():
  104. generate_train_batch()
  105. generate_test_batch()
  106. run()