这篇博客讲的是 MNIST 数据集的读取。Fashion-MNIST 与 MNIST 的文件格式完全相同，因此读取方法也一样（下面的代码读取的正是 Fashion-MNIST）。原文链接。
四个文件
- train-images-idx3-ubyte.gz: training set images
- train-labels-idx1-ubyte.gz: training set labels
- t10k-images-idx3-ubyte.gz: test set images
- t10k-labels-idx1-ubyte.gz: test set labels
将上述四个文件下载下来并解压，得到：
train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
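原博客用了大量篇幅讲解数据格式，这里只做简要说明，然后直接上代码。图像文件以 16 字节的文件头开头，包含 4 个大端（big-endian）无符号 32 位整数：魔数 2051、图片数量、行数和列数。其后每张图片占 28×28 = 784 个字节，每个字节是一个像素的灰度值。标签文件以 8 字节的文件头开头，包含魔数 2049 和标签数量，其后每个标签占 1 个字节。代码中的 struct.unpack(">4I", ...) 就是按这种格式解析文件头的。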
读取
from __future__ import division
from __future__ import print_function

import os
import sys
import struct

file_list = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]

# Directory that holds the four unpacked IDX files; adjust to your machine.
DATA_DIR = "D:\\谷歌下载\\fashion-mnist-master\\data\\fashion"


def create_path(path):
    """Create the directory *path* (including parents) if it does not exist."""
    if not os.path.isdir(path):
        os.makedirs(path)


def get_file_full_name(path, name):
    """Return the path of file *name* inside directory *path*.

    ``os.path.join`` only inserts a separator when *path* does not already
    end with one, and it uses the platform's separator, so it replaces the
    manual "ends with '/'" check. Unlike the earlier version this does not
    create *path*: building a path to read from should not leave an empty
    directory tree behind when the location is wrong.
    """
    return os.path.join(path, name)


def read_mnist(file_name, file_path=DATA_DIR):
    """Open the IDX file *file_name* located in *file_path* for binary reading.

    The caller owns the returned file object and must close it (for example
    with a ``with`` statement).
    """
    full_path = get_file_full_name(file_path, file_name)
    return open(full_path, "rb")


def get_file_header_data(file_name, header_len, unpack_str, file_path=DATA_DIR):
    """Read the first *header_len* bytes of an IDX file and decode them.

    ``struct.unpack`` converts raw bytes into Python values according to a
    format string: ``">"`` selects big-endian byte order (the order the IDX
    format uses) and ``"4I"`` means four unsigned 32-bit integers. For the
    image files these are (magic number, image count, rows, columns); ``">2I"``
    on a label file gives (magic number, label count).

    Returns:
        The tuple of decoded header fields.

    Raises:
        struct.error: if the file is shorter than *header_len* bytes.
    """
    # Close the file as soon as the header has been read.
    with read_mnist(file_name, file_path) as f:
        raw_header = f.read(header_len)
    return struct.unpack(unpack_str, raw_header)


def show_images_file_header(file_name):
    """Print the 16-byte header of an IDX images file."""
    show_file_header(file_name, 16, ">4I")


def show_labels_file_header(file_name):
    """Print the 8-byte header of an IDX labels file."""
    show_file_header(file_name, 8, ">2I")


def show_file_header(file_name, header_len, unpack_str):
    """Decode the header of *file_name* and print it."""
    header_data = get_file_header_data(file_name, header_len, unpack_str)
    print("%s header data:%s" % (file_name, header_data))


def show_mnist_file_header():
    """Print the headers of all four files: image files first, then labels."""
    for images_file_name in (file_list[0], file_list[2]):
        show_images_file_header(images_file_name)
    for labels_file_name in (file_list[1], file_list[3]):
        show_labels_file_header(labels_file_name)


def run():
    show_mnist_file_header()


if __name__ == "__main__":
    run()
输出: train-images-idx3-ubyte header data:(2051, 60000, 28, 28) t10k-images-idx3-ubyte header data:(2051, 10000, 28, 28) train-labels-idx1-ubyte header data:(2049, 60000) t10k-labels-idx1-ubyte header data:(2049, 10000)
下面读取一张图片，并显示这张图片及其标签。
from __future__ import division
from __future__ import print_function

# gunzip *.gz
# http://yann.lecun.com/exdb/mnist/
import os
import sys
import struct

import numpy as np

file_list = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]

# Directory that holds the four unpacked IDX files; adjust to your machine.
DATA_DIR = "/home/your/data/path"


def create_path(path):
    """Create the directory *path* (including parents) if it does not exist."""
    if not os.path.isdir(path):
        os.makedirs(path)


def get_file_full_name(path, name):
    """Return the path of file *name* inside directory *path*.

    Does not create *path*: a wrong data directory should fail on open
    rather than leave an empty directory tree behind.
    """
    return os.path.join(path, name)


def read_mnist(file_name, file_path=DATA_DIR):
    """Open the IDX file *file_name* in *file_path* for binary reading.

    Binary mode is required on Python 3 so that ``read`` returns bytes for
    ``struct.unpack``. The caller must close the returned file object.
    """
    return open(get_file_full_name(file_path, file_name), "rb")


def get_file_header_data(file_obj, header_len, unpack_str):
    """Read *header_len* bytes from *file_obj* and decode them.

    The bytes are decoded with ``struct.unpack(unpack_str, ...)``; ``">4I"``
    means four big-endian unsigned 32-bit integers. The file is left
    positioned just after the header, so image or label data can be read next.
    """
    raw_header = file_obj.read(header_len)
    return struct.unpack(unpack_str, raw_header)


def show_images_file_header(file_name):
    """Print the 16-byte header of an IDX images file."""
    show_file_header(file_name, 16, ">4I")


def show_labels_file_header(file_name):
    """Print the 8-byte header of an IDX labels file."""
    show_file_header(file_name, 8, ">2I")


def show_file_header(file_name, header_len, unpack_str):
    """Open *file_name*, print its decoded header, and close it again."""
    with read_mnist(file_name) as file_obj:
        header_data = get_file_header_data(file_obj, header_len, unpack_str)
    show_file_header_data(file_name, header_data)


def show_file_header_data(file_name, header_data):
    """Print already-decoded header fields for *file_name*."""
    print("%s header data:%s" % (file_name, header_data))


def show_mnist_file_header():
    """Print the headers of all four files: image files first, then labels."""
    for images_file_name in (file_list[0], file_list[2]):
        show_images_file_header(images_file_name)
    for labels_file_name in (file_list[1], file_list[3]):
        show_labels_file_header(labels_file_name)


def load_a_image(file_object):
    """Read the next image from an images file as a 28x28 NumPy array.

    Each image is stored as 28 * 28 = 784 unsigned bytes, one per pixel,
    which are reshaped row by row.

    Raises:
        struct.error: if fewer than 784 bytes are left in the file.
    """
    raw_img = file_object.read(28 * 28)
    pixels = struct.unpack(">784B", raw_img)
    return np.asarray(pixels).reshape((28, 28))


def read_a_image(file_object):
    """Read the next image, display it in grey scale, and return it."""
    # Imported here so the header and label helpers work without matplotlib.
    import matplotlib.pyplot as plt

    image = load_a_image(file_object)
    plt.imshow(image, cmap=plt.cm.gray)
    plt.show()
    return image


def read_a_label(file_object):
    """Read the next one-byte label, print it, and return it as an int."""
    raw_label = file_object.read(1)
    label = struct.unpack(">B", raw_label)[0]
    print("the label is :%s" % label)
    return label


def show_a_image():
    """Print both training-set headers, then show the first image and label."""
    images_file_name = file_list[0]
    labels_file_name = file_list[1]
    # Both files are closed when the blocks exit, even if display fails.
    with read_mnist(images_file_name) as images_file:
        header_data = get_file_header_data(images_file, 16, ">4I")
        show_file_header_data(images_file_name, header_data)
        with read_mnist(labels_file_name) as labels_file:
            header_data = get_file_header_data(labels_file, 8, ">2I")
            show_file_header_data(labels_file_name, header_data)
            read_a_image(images_file)
            read_a_label(labels_file)


def run():
    # show_mnist_file_header()
    show_a_image()


if __name__ == "__main__":
    run()
输出:
train-images-idx3-ubyte header data:(2051, 60000, 28, 28)
train-labels-idx1-ubyte header data:(2049, 60000)
the label is :9

标签与类别的对应关系如下：
| Label | Description |
|---|---|
| 0 | T-shirt/top |
| 1 | Trouser |
| 2 | Pullover |
| 3 | Dress |
| 4 | Coat |
| 5 | Sandal |
| 6 | Shirt |
| 7 | Sneaker |
| 8 | Bag |
| 9 | Ankle boot |
输出的标签为 9，对应 Ankle boot（短靴），与显示出的图片一致。
接下来把代码修改成能自动生成批（batch）数据的形式：
from __future__ import division
from __future__ import print_function

# gunzip *.gz
# http://yann.lecun.com/exdb/mnist/
import os
import sys
import struct

import numpy as np

file_list = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]

# Directory that holds the four unpacked IDX files; adjust to your machine.
DATA_DIR = "/home/your/data/path"


# --- file access -----------------------------------------------------------

def create_path(path):
    """Create the directory *path* (including parents) if it does not exist."""
    if not os.path.isdir(path):
        os.makedirs(path)


def get_file_full_name(path, name):
    """Return the path of file *name* inside directory *path*.

    Does not create *path*: a wrong data directory should fail on open
    rather than leave an empty directory tree behind.
    """
    return os.path.join(path, name)


def read_mnist(file_name, file_path=DATA_DIR):
    """Open the IDX file *file_name* in *file_path* for binary reading.

    Binary mode is required on Python 3 so that ``read`` returns bytes.
    The caller must close the returned file object.
    """
    return open(get_file_full_name(file_path, file_name), "rb")


def get_file_header_data(file_obj, header_len, unpack_str):
    """Read *header_len* bytes from *file_obj* and decode them.

    Uses ``struct.unpack(unpack_str, ...)`` and leaves the file positioned
    at the first data record.
    """
    raw_header = file_obj.read(header_len)
    return struct.unpack(unpack_str, raw_header)


def read_a_image(file_object):
    """Read the next image as a tuple of 784 pixel values.

    Raises:
        struct.error: if fewer than 784 bytes remain (end of file).
    """
    raw_img = file_object.read(28 * 28)
    return struct.unpack(">784B", raw_img)


def read_a_label(file_object):
    """Read the next label as a 1-tuple ``(label,)``.

    Raises:
        struct.error: at end of file.
    """
    raw_label = file_object.read(1)
    return struct.unpack(">B", raw_label)


# --- inspection / display ----------------------------------------------------

def show_images_file_header(file_name):
    """Print the 16-byte header of an IDX images file."""
    show_file_header(file_name, 16, ">4I")


def show_labels_file_header(file_name):
    """Print the 8-byte header of an IDX labels file."""
    show_file_header(file_name, 8, ">2I")


def show_file_header(file_name, header_len, unpack_str):
    """Open *file_name*, print its decoded header, and close it again."""
    with read_mnist(file_name) as file_obj:
        header_data = get_file_header_data(file_obj, header_len, unpack_str)
    show_file_header_data(file_name, header_data)


def show_file_header_data(file_name, header_data):
    """Print already-decoded header fields for *file_name*."""
    print("%s header data:%s" % (file_name, header_data))


def show_mnist_file_header():
    """Print the headers of all four files: image files first, then labels."""
    for images_file_name in (file_list[0], file_list[2]):
        show_images_file_header(images_file_name)
    for labels_file_name in (file_list[1], file_list[3]):
        show_labels_file_header(labels_file_name)


def show_image_from_file(file_object):
    """Read the next image from *file_object* and display it in grey scale.

    This is the single-image helper that the draft also called
    ``show_a_image``; the later zero-argument ``show_a_image`` shadowed it.
    """
    # Imported here so the batching code works without matplotlib.
    import matplotlib.pyplot as plt

    image = np.asarray(read_a_image(file_object)).reshape((28, 28))
    plt.imshow(image, cmap=plt.cm.gray)
    plt.show()


def show_a_label(file_object):
    """Read the next label from *file_object* and print it."""
    label = read_a_label(file_object)
    print("the label is :%s" % label[0])


# Backward-compatible alias for the original, misspelled name.
show_a_lebel = show_a_label


def show_a_image():
    """Print both training-set headers, then show the first image and label."""
    images_file_name = file_list[0]
    labels_file_name = file_list[1]
    with read_mnist(images_file_name) as images_file:
        header_data = get_file_header_data(images_file, 16, ">4I")
        show_file_header_data(images_file_name, header_data)
        with read_mnist(labels_file_name) as labels_file:
            header_data = get_file_header_data(labels_file, 8, ">2I")
            show_file_header_data(labels_file_name, header_data)
            show_image_from_file(images_file)
            show_a_label(labels_file)


# --- batching ------------------------------------------------------------------

def generate_a_batch(images_file_name, labels_file_name, batch_size=8,
                     file_path=DATA_DIR):
    """Yield ``(images, labels)`` batches read in lockstep from an IDX pair.

    Each batch is a pair of lists with up to *batch_size* entries. Images are
    784-tuples of pixel values, and labels are 1-tuples. The last batch may be
    shorter. After the data is exhausted the generator keeps yielding
    ``([], [])``, which callers treat as the end-of-data signal. Both files
    are closed when the generator is closed or garbage-collected.
    """
    with read_mnist(images_file_name, file_path) as images_file:
        # Skip the header so that reads start at the first image.
        get_file_header_data(images_file, 16, ">4I")
        with read_mnist(labels_file_name, file_path) as labels_file:
            get_file_header_data(labels_file, 8, ">2I")
            while True:
                images = []
                labels = []
                for _ in range(batch_size):
                    try:
                        image = read_a_image(images_file)
                        label = read_a_label(labels_file)
                    except struct.error:
                        # Short read: end of file reached, finish this batch.
                        break
                    images.append(image)
                    labels.append(label)
                yield images, labels


def get_train_data_generator(batch_size=100):
    """Return a batch generator over the training set.

    *batch_size* defaults to 100, the size the original loop hard-coded.
    """
    return generate_a_batch(file_list[0], file_list[1], batch_size)


def get_test_data_generator(batch_size=100):
    """Return a batch generator over the test set.

    *batch_size* defaults to 100, the size the original loop hard-coded.
    """
    return generate_a_batch(file_list[2], file_list[3], batch_size)


def get_a_batch(data_generator):
    """Return the next ``(images, labels)`` batch from *data_generator*."""
    # The builtin next() works on Python 2.6+ and 3, so no version sniffing
    # via sys.version string comparison is needed.
    batch_img, batch_labels = next(data_generator)
    return batch_img, batch_labels


def generate_test_batch():
    """Iterate once over the test set, printing the shape of every batch."""
    data_generator = get_test_data_generator()
    count = 1
    while True:
        batch_img, batch_labels = get_a_batch(data_generator)
        if not batch_img and not batch_labels:
            break  # empty batch: no more data
        batch_img = np.array(batch_img)
        batch_labels = np.array(batch_labels)
        print("img shape:%s label shape:%s count:%s"
              % (batch_img.shape, batch_labels.shape, count))
        count += 1
    data_generator.close()  # release both files deterministically


def generate_train_batch(epochs=10):
    """Iterate *epochs* times over the training set, printing batch shapes."""
    for epoch in range(1, epochs + 1):
        data_generator = get_train_data_generator()
        count = 1
        while True:
            batch_img, batch_labels = get_a_batch(data_generator)
            if not batch_img and not batch_labels:
                break  # empty batch: end of this epoch
            batch_img = np.array(batch_img)
            batch_labels = np.array(batch_labels)
            print("epoch:%s img shape:%s label shape:%s count:%s"
                  % (epoch, batch_img.shape, batch_labels.shape, count))
            count += 1
        data_generator.close()


def run():
    generate_train_batch()
    generate_test_batch()


if __name__ == "__main__":
    run()
上面的代码中有很多没有用到的部分，删掉之后得到：
from __future__ import division
from __future__ import print_function

# gunzip *.gz
# http://yann.lecun.com/exdb/mnist/
import os
import sys
import struct

import numpy as np

file_list = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]

# Directory that holds the four unpacked IDX files; adjust to your machine.
DATA_DIR = "/home/your/data/path"


def create_path(path):
    """Create the directory *path* (including parents) if it does not exist."""
    if not os.path.isdir(path):
        os.makedirs(path)


def get_file_full_name(path, name):
    """Return the path of file *name* inside directory *path*.

    Does not create *path*: a wrong data directory should fail on open
    rather than leave an empty directory tree behind.
    """
    return os.path.join(path, name)


def read_mnist(file_name, file_path=DATA_DIR):
    """Open the IDX file *file_name* in *file_path* for binary reading.

    Binary mode is required on Python 3 so that ``read`` returns bytes.
    The caller must close the returned file object.
    """
    return open(get_file_full_name(file_path, file_name), "rb")


def get_file_header_data(file_obj, header_len, unpack_str):
    """Read *header_len* bytes from *file_obj* and decode them.

    Uses ``struct.unpack(unpack_str, ...)`` and leaves the file positioned
    at the first data record.
    """
    raw_header = file_obj.read(header_len)
    return struct.unpack(unpack_str, raw_header)


def read_a_image(file_object):
    """Read the next image as a tuple of 784 pixel values.

    Raises:
        struct.error: if fewer than 784 bytes remain (end of file).
    """
    raw_img = file_object.read(28 * 28)
    return struct.unpack(">784B", raw_img)


def read_a_label(file_object):
    """Read the next label as a 1-tuple ``(label,)``.

    Raises:
        struct.error: at end of file.
    """
    raw_label = file_object.read(1)
    return struct.unpack(">B", raw_label)


def generate_a_batch(images_file_name, labels_file_name, batch_size=8,
                     file_path=DATA_DIR):
    """Yield ``(images, labels)`` batches read in lockstep from an IDX pair.

    Each batch is a pair of lists with up to *batch_size* entries. Images are
    784-tuples of pixel values, and labels are 1-tuples. The last batch may be
    shorter. After the data is exhausted the generator keeps yielding
    ``([], [])``, which callers treat as the end-of-data signal. Both files
    are closed when the generator is closed or garbage-collected.
    """
    with read_mnist(images_file_name, file_path) as images_file:
        # Skip the header so that reads start at the first image.
        get_file_header_data(images_file, 16, ">4I")
        with read_mnist(labels_file_name, file_path) as labels_file:
            get_file_header_data(labels_file, 8, ">2I")
            while True:
                images = []
                labels = []
                for _ in range(batch_size):
                    try:
                        image = read_a_image(images_file)
                        label = read_a_label(labels_file)
                    except struct.error:
                        # Short read: end of file reached, finish this batch.
                        break
                    images.append(image)
                    labels.append(label)
                yield images, labels


def get_train_data_generator(batch_size=100):
    """Return a batch generator over the training set.

    *batch_size* defaults to 100, the size the original loop hard-coded.
    """
    return generate_a_batch(file_list[0], file_list[1], batch_size)


def get_test_data_generator(batch_size=100):
    """Return a batch generator over the test set.

    *batch_size* defaults to 100, the size the original loop hard-coded.
    """
    return generate_a_batch(file_list[2], file_list[3], batch_size)


def get_a_batch(data_generator):
    """Return the next ``(images, labels)`` batch from *data_generator*."""
    # The builtin next() works on Python 2.6+ and 3, so no version sniffing
    # via sys.version string comparison is needed.
    batch_img, batch_labels = next(data_generator)
    return batch_img, batch_labels


def generate_test_batch():
    """Iterate once over the test set, printing the shape of every batch."""
    data_generator = get_test_data_generator()
    count = 1
    while True:
        batch_img, batch_labels = get_a_batch(data_generator)
        if not batch_img and not batch_labels:
            break  # empty batch: no more data
        batch_img = np.array(batch_img)
        batch_labels = np.array(batch_labels)
        print("img shape:%s label shape:%s count:%s"
              % (batch_img.shape, batch_labels.shape, count))
        count += 1
    data_generator.close()  # release both files deterministically


def generate_train_batch(epochs=10):
    """Iterate *epochs* times over the training set, printing batch shapes."""
    for epoch in range(1, epochs + 1):
        data_generator = get_train_data_generator()
        count = 1
        while True:
            batch_img, batch_labels = get_a_batch(data_generator)
            if not batch_img and not batch_labels:
                break  # empty batch: end of this epoch
            batch_img = np.array(batch_img)
            batch_labels = np.array(batch_labels)
            print("epoch:%s img shape:%s label shape:%s count:%s"
                  % (epoch, batch_img.shape, batch_labels.shape, count))
            count += 1
        data_generator.close()


def run():
    generate_train_batch()
    generate_test_batch()


if __name__ == "__main__":
    run()
