这篇博客讲的是MNIST数据集的读取,和Fashion_MNIST数据集的读取异曲同工,原文链接。
四个文件
- train-images-idx3-ubyte.gz: training set images
- train-labels-idx1-ubyte.gz: training set labels
- t10k-images-idx3-ubyte.gz: test set images
- t10k-labels-idx1-ubyte.gz: test set labels
将上述的四个文件下载下来,然后解压,得到
train-images-idx3-ubyte train-labels-idx1-ubyte t10k-images-idx3-ubyte t10k-labels-idx1-ubyte
原博客用了很长的篇幅解释数据格式的含义,这里不再赘述,直接上代码。
读取
from __future__ import division
from __future__ import print_function
import os
import sys
import struct
# Names of the four decompressed dataset files, addressed by index elsewhere
# in the script: 0/1 are the training images/labels, 2/3 the test ones.
file_list = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]
def create_path(path):
    """Create the directory ``path`` (with any missing parents) if needed.

    Args:
        path: Directory path to create.

    Raises:
        OSError: If the directory cannot be created, for example because
            ``path`` already exists but is not a directory.
    """
    # Try to create first and tolerate "already exists" afterwards: checking
    # with isdir() before makedirs() can fail when something else creates
    # the directory between the two calls.
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise
def get_file_full_name(path, name):
    """Return ``name`` appended to the directory ``path``.

    The directory is created first through ``create_path``. A "/" is
    inserted between the two parts only when ``path`` does not already end
    with one.
    """
    create_path(path)
    separator = "" if path[-1] == "/" else "/"
    return path + separator + name
def read_mnist(file_name, file_path="D:\\谷歌下载\\fashion-mnist-master\\data\\fashion"):
    """Open one dataset file for binary reading.

    Args:
        file_name: Name of the file inside the data directory.
        file_path: Directory that holds the decompressed dataset files.
            The default is the location that used to be hard-coded in the
            function body; pass your own directory instead of editing it.

    Returns:
        The open binary file object. The caller must close it.
    """
    full_path = get_file_full_name(file_path, file_name)
    return open(full_path, 'rb')
def get_file_header_data(file_name, header_len, unpack_str):
    """Read and decode the header of a dataset file.

    ``struct.unpack`` turns the raw header bytes into Python integers
    according to ``unpack_str``. For example ``">4I"`` decodes 16 bytes as
    four big-endian unsigned 32-bit integers.

    Args:
        file_name: Name of the dataset file, opened through ``read_mnist``.
        header_len: Number of header bytes to read.
        unpack_str: ``struct`` format string describing those bytes.

    Returns:
        A tuple with the decoded header fields.

    Raises:
        struct.error: If the bytes read do not match ``unpack_str``.
    """
    f = read_mnist(file_name)
    # Close the file as soon as the header has been read; the original left
    # the handle open.
    try:
        raw_header = f.read(header_len)
    finally:
        f.close()
    return struct.unpack(unpack_str, raw_header)
def show_images_file_header(file_name):
    """Print the 16-byte header of an images file, decoded with ">4I"."""
    show_file_header(file_name, header_len=16, unpack_str=">4I")
def show_labels_file_header(file_name):
    """Print the 8-byte header of a labels file, decoded with ">2I"."""
    show_file_header(file_name, header_len=8, unpack_str=">2I")
def show_file_header(file_name, header_len, unpack_str):
    """Decode the header of ``file_name`` and print it next to the name."""
    fields = get_file_header_data(file_name, header_len, unpack_str)
    message = "%s header data:%s" % (file_name, fields)
    print(message)
def show_mnist_file_header():
    """Print the headers of all four dataset files.

    The two image files (train, then test) are shown first, followed by the
    two label files (train, then test).
    """
    for images_index in (0, 2):
        show_images_file_header(file_list[images_index])
    for labels_index in (1, 3):
        show_labels_file_header(file_list[labels_index])
def run():
    """Script entry point: print the header of every dataset file."""
    show_mnist_file_header()


# Executed as soon as the script is loaded.
run()
输出:
train-images-idx3-ubyte header data:(2051, 60000, 28, 28)
t10k-images-idx3-ubyte header data:(2051, 10000, 28, 28)
train-labels-idx1-ubyte header data:(2049, 60000)
t10k-labels-idx1-ubyte header data:(2049, 10000)
下面我们读取一张图片,并且展示图片和它的标记
from __future__ import division
from __future__ import print_function
#gunzip *.gz
#http://yann.lecun.com/exdb/mnist/
import os
import sys
import struct
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# Decompressed dataset file names; the rest of the script picks them by index.
file_list = [
    # training set
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    # test set
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]
def create_path(path):
    """Create the directory ``path`` (with any missing parents) if needed.

    Args:
        path: Directory path to create.

    Raises:
        OSError: If the directory cannot be created, for example because
            ``path`` already exists but is not a directory.
    """
    # Create first, then tolerate "already exists": an isdir() check before
    # makedirs() can fail if the directory appears between the two calls.
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise
def get_file_full_name(path, name):
    """Return ``name`` appended to the directory ``path``.

    ``create_path`` is called on the directory first. A "/" separator is
    added only when ``path`` does not already end with one.
    """
    create_path(path)
    if path[-1] != "/":
        return path + "/" + name
    return path + name
def read_mnist(file_name, file_path="/home/your/data/path"):
    """Open one dataset file for binary reading.

    Args:
        file_name: Name of the file inside the data directory.
        file_path: Directory that holds the decompressed dataset files.
            The default is the placeholder that used to be hard-coded in
            the function body; pass the real directory instead of editing it.

    Returns:
        The open binary file object. The caller must close it.
    """
    full_path = get_file_full_name(file_path, file_name)
    # Binary mode: the files are raw bytes decoded later with struct.
    return open(full_path, 'rb')
def get_file_header_data(file_obj, header_len, unpack_str):
    """Read ``header_len`` bytes from ``file_obj`` and decode them.

    Returns the tuple produced by ``struct.unpack`` with ``unpack_str``.
    """
    return struct.unpack(unpack_str, file_obj.read(header_len))
def show_images_file_header(file_name):
    """Print the 16-byte header of an images file, decoded with ">4I"."""
    show_file_header(file_name, header_len=16, unpack_str=">4I")
def show_labels_file_header(file_name):
    """Print the 8-byte header of a labels file, decoded with ">2I"."""
    show_file_header(file_name, header_len=8, unpack_str=">2I")
def show_file_header(file_name, header_len, unpack_str):
    """Open ``file_name``, decode its header and print it.

    Args:
        file_name: Name of the dataset file, opened through ``read_mnist``.
        header_len: Number of header bytes to read.
        unpack_str: ``struct`` format string describing the header.
    """
    file_obj = read_mnist(file_name)
    # try/finally so the file is closed even when reading or decoding the
    # header raises.
    try:
        header_data = get_file_header_data(file_obj, header_len, unpack_str)
        show_file_header_data(file_name, header_data)
    finally:
        file_obj.close()
def show_mnist_file_header():
    """Print the headers of all four dataset files.

    The two image files (train, then test) are shown first, followed by the
    two label files (train, then test).
    """
    for images_index in (0, 2):
        show_images_file_header(file_list[images_index])
    for labels_index in (1, 3):
        show_labels_file_header(file_list[labels_index])
def read_a_image(file_object):
    """Read the next 28x28 image from ``file_object`` and display it.

    784 bytes are consumed from the file, decoded as unsigned bytes, and
    shown as a grayscale picture with matplotlib.
    """
    raw_pixels = file_object.read(28 * 28)
    pixels = np.asarray(struct.unpack(">784B", raw_pixels)).reshape((28, 28))
    plt.imshow(pixels, cmap=plt.cm.gray)
    plt.show()
def read_a_label(file_object):
    """Read the next one-byte label from ``file_object`` and print it."""
    (label,) = struct.unpack(">B", file_object.read(1))
    print("the label is :%s" % label)
def show_file_header_data(file_name, header_data):
    """Print ``file_name`` followed by its decoded header fields."""
    message = "%s header data:%s" % (file_name, header_data)
    print(message)
def show_a_image():
    """Show the first training image and print its label.

    Both file headers are read and printed first, which also positions each
    file at its first record. The two files are closed when the function
    finishes, including when an error occurs; the original left them open.
    """
    images_file_name = file_list[0]
    labels_file_name = file_list[1]
    images_file = read_mnist(images_file_name)
    try:
        header_data = get_file_header_data(images_file, 16, ">4I")
        show_file_header_data(images_file_name, header_data)
        labels_file = read_mnist(labels_file_name)
        try:
            header_data = get_file_header_data(labels_file, 8, ">2I")
            show_file_header_data(labels_file_name, header_data)
            read_a_image(images_file)
            read_a_label(labels_file)
        finally:
            labels_file.close()
    finally:
        images_file.close()
def run():
    """Script entry point: display one training image with its label."""
    # show_mnist_file_header() may be called here instead to print the
    # headers of all four files.
    show_a_image()


# Executed as soon as the script is loaded.
run()
输出:
train-images-idx3-ubyte header data:(2051, 60000, 28, 28)
train-labels-idx1-ubyte header data:(2049, 60000)
the label is :9
由标记对应图片的结果,看到
Label | Description |
---|---|
0 | T-shirt/top |
1 | Trouser |
2 | Pullover |
3 | Dress |
4 | Coat |
5 | Sandal |
6 | Shirt |
7 | Sneaker |
8 | Bag |
9 | Ankle boot |
可以看到,输出的标记与展示的图片是相对应的。
然后我们修改成能自动生成批数据
from __future__ import division
from __future__ import print_function
#gunzip *.gz
#http://yann.lecun.com/exdb/mnist/
import os
import sys
import struct
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# Names of the four decompressed dataset files, addressed by index elsewhere
# in the script: 0/1 are the training images/labels, 2/3 the test ones.
file_list = [
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]
def show_images_file_header(file_name):
    """Print the 16-byte header of an images file, decoded with ">4I"."""
    show_file_header(file_name, header_len=16, unpack_str=">4I")
def show_labels_file_header(file_name):
    """Print the 8-byte header of a labels file, decoded with ">2I"."""
    show_file_header(file_name, header_len=8, unpack_str=">2I")
def show_file_header(file_name, header_len, unpack_str):
    """Open ``file_name``, decode its header and print it.

    Args:
        file_name: Name of the dataset file, opened through ``read_mnist``.
        header_len: Number of header bytes to read.
        unpack_str: ``struct`` format string describing the header.
    """
    file_obj = read_mnist(file_name)
    # try/finally so the file is closed even when reading or decoding the
    # header raises.
    try:
        header_data = get_file_header_data(file_obj, header_len, unpack_str)
        show_file_header_data(file_name, header_data)
    finally:
        file_obj.close()
def show_mnist_file_header():
    """Print the headers of all four dataset files.

    The two image files (train, then test) are shown first, followed by the
    two label files (train, then test).
    """
    for images_index in (0, 2):
        show_images_file_header(file_list[images_index])
    for labels_index in (1, 3):
        show_labels_file_header(file_list[labels_index])
def show_a_image(file_object):
    """Read the next image from ``file_object`` and display it in grayscale.

    The original body referred to the undefined names ``images_file`` and
    ``tp``; it now uses the ``file_object`` argument and the pixels that
    ``read_a_image`` returns.

    NOTE: this script later defines a zero-argument function with the same
    name, which rebinds ``show_a_image`` at module level.

    Args:
        file_object: Open images file positioned at the start of a record.
    """
    pixels = read_a_image(file_object)
    image = np.asarray(pixels).reshape((28, 28))
    plt.imshow(image, cmap=plt.cm.gray)
    plt.show()
def show_a_lebel(file_object):
    """Read the next label from ``file_object`` and print it."""
    print("the label is :%s" % read_a_label(file_object))
def show_file_header_data(file_name, header_data):
    """Print ``file_name`` followed by its decoded header fields."""
    message = "%s header data:%s" % (file_name, header_data)
    print(message)
def show_a_image():
    """Show the first training image and print its label.

    This definition rebinds the module-level name ``show_a_image``, so the
    original call ``show_a_image(images_file)`` inside this body reached
    this zero-argument function and failed with a TypeError. The image is
    therefore displayed inline here. The label is printed through
    ``show_a_lebel``; the original called ``read_a_label`` and discarded the
    value. Both files are closed on exit, including when an error occurs.
    """
    images_file_name = file_list[0]
    labels_file_name = file_list[1]
    images_file = read_mnist(images_file_name)
    try:
        # Reading the header also positions the file at its first record.
        header_data = get_file_header_data(images_file, 16, ">4I")
        show_file_header_data(images_file_name, header_data)
        labels_file = read_mnist(labels_file_name)
        try:
            header_data = get_file_header_data(labels_file, 8, ">2I")
            show_file_header_data(labels_file_name, header_data)
            image = np.asarray(read_a_image(images_file)).reshape((28, 28))
            plt.imshow(image, cmap=plt.cm.gray)
            plt.show()
            show_a_lebel(labels_file)
        finally:
            labels_file.close()
    finally:
        images_file.close()
def create_path(path):
    """Create the directory ``path`` (with any missing parents) if needed.

    Args:
        path: Directory path to create.

    Raises:
        OSError: If the directory cannot be created, for example because
            ``path`` already exists but is not a directory.
    """
    # Create first, then tolerate "already exists": an isdir() check before
    # makedirs() can fail if the directory appears between the two calls.
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise
def get_file_full_name(path, name):
    """Return ``name`` appended to the directory ``path``.

    ``create_path`` is called on the directory first. A "/" separator is
    added only when ``path`` does not already end with one.
    """
    create_path(path)
    separator = "" if path[-1] == "/" else "/"
    return path + separator + name
def read_mnist(file_name, file_path="/home/your/data/path"):
    """Open one dataset file for binary reading.

    Args:
        file_name: Name of the file inside the data directory.
        file_path: Directory that holds the decompressed dataset files.
            The default is the placeholder that used to be hard-coded in
            the function body; pass the real directory instead of editing it.

    Returns:
        The open binary file object. The caller must close it.
    """
    full_path = get_file_full_name(file_path, file_name)
    # Binary mode: the files are raw bytes decoded later with struct.
    return open(full_path, 'rb')
def get_file_header_data(file_obj, header_len, unpack_str):
    """Read ``header_len`` bytes from ``file_obj`` and decode them.

    Returns the tuple produced by ``struct.unpack`` with ``unpack_str``.
    """
    return struct.unpack(unpack_str, file_obj.read(header_len))
def read_a_image(file_object):
    """Read the next image record from ``file_object``.

    Returns a tuple of 784 integers, one per unsigned byte read.
    """
    return struct.unpack(">784B", file_object.read(28 * 28))
def read_a_label(file_object):
    """Read the next one-byte label from ``file_object``.

    Returns the result of ``struct.unpack``, i.e. a 1-tuple ``(label,)``.
    """
    return struct.unpack(">B", file_object.read(1))
def generate_a_batch(images_file_name, labels_file_name, batch_size=8):
    """Yield successive ``(images, labels)`` batches from a pair of files.

    Each batch is two lists of equal length holding the values returned by
    ``read_a_image`` and ``read_a_label``. When the files run out, the last
    non-empty batch may be shorter than ``batch_size``, and every batch
    after it is a pair of empty lists. The generator never finishes on its
    own; the callers in this script stop when they receive an empty batch.

    Args:
        images_file_name: Name of the images file.
        labels_file_name: Name of the matching labels file.
        batch_size: Maximum number of samples per batch. The original
            ignored this argument and always used 100.

    Yields:
        A tuple ``(images, labels)`` of lists.
    """
    # The with-block closes both files when the generator is closed or
    # garbage-collected; the original never closed them.
    with read_mnist(images_file_name) as images_file, \
            read_mnist(labels_file_name) as labels_file:
        # Consume the headers so the next reads start at the first record.
        get_file_header_data(images_file, 16, ">4I")
        get_file_header_data(labels_file, 8, ">2I")
        while True:
            images = []
            labels = []
            for _ in range(batch_size):
                try:
                    image = read_a_image(images_file)
                    label = read_a_label(labels_file)
                except struct.error:
                    # A short read at end of file cannot be unpacked; that
                    # is the normal end of data, so stop filling the batch.
                    break
                images.append(image)
                labels.append(label)
            yield images, labels
def get_train_data_generator():
    """Return a batch generator over the training images and labels.

    The original ended with ``return gennerator-``, which is a syntax error;
    the stray ``-`` has been removed.
    """
    images_file_name = file_list[0]
    labels_file_name = file_list[1]
    return generate_a_batch(images_file_name, labels_file_name)
def get_test_data_generator():
    """Return a batch generator over the test images and labels.

    The function was defined twice in a row with the same body, and the
    second copy ended with the syntax error ``return gennerator-``. A single
    valid definition is kept.
    """
    images_file_name = file_list[2]
    labels_file_name = file_list[3]
    return generate_a_batch(images_file_name, labels_file_name)
def get_a_batch(data_generator):
    """Return the next ``(images, labels)`` pair from ``data_generator``.

    Args:
        data_generator: Iterator that yields two-element batches.

    Returns:
        A tuple ``(batch_img, batch_labels)``.
    """
    # The built-in next() works on Python 2.6+ and Python 3 alike, so the
    # fragile string comparison ``sys.version > '3'`` is no longer needed.
    batch_img, batch_labels = next(data_generator)
    return batch_img, batch_labels
def generate_test_batch():
    """Walk through the test set batch by batch and print each batch's shapes.

    The loop ends at the first batch in which both lists are empty.
    """
    batches = get_test_data_generator()
    count = 1
    while True:
        images, labels = get_a_batch(batches)
        if not (images or labels):
            break
        images = np.array(images)
        labels = np.array(labels)
        print("img shape:%s label shape:%s count:%s" % (images.shape, labels.shape, count))
        count += 1
def generate_train_batch(epochs=10):
    """Walk through the training set ``epochs`` times, printing batch shapes.

    A fresh generator is created for every epoch, and each epoch ends at the
    first batch in which both lists are empty.

    Args:
        epochs: Number of passes over the training data. Defaults to 10, the
            value that used to be hard-coded.
    """
    for epoch in range(1, epochs + 1):
        data_generator = get_train_data_generator()
        try:
            count = 1
            while True:
                batch_img, batch_labels = get_a_batch(data_generator)
                if not batch_img and not batch_labels:
                    break
                batch_img = np.array(batch_img)
                batch_labels = np.array(batch_labels)
                print("epoch:%s img shape:%s label shape:%s count:%s" % (epoch, batch_img.shape, batch_labels.shape, count))
                count += 1
        finally:
            # The generator never finishes on its own; close it explicitly
            # so it is not left suspended once the epoch is over.
            data_generator.close()
def run():
    """Script entry point: iterate the training data, then the test data."""
    generate_train_batch()
    generate_test_batch()


# Executed as soon as the script is loaded.
run()
上面的代码里有不少没有用到的部分,把它们删掉之后,我们得到
from __future__ import division
from __future__ import print_function
#gunzip *.gz
#http://yann.lecun.com/exdb/mnist/
import os
import sys
import struct
import numpy as np
import matplotlib.pyplot as plt
from PIL import Image
# Decompressed dataset file names; the rest of the script picks them by index.
file_list = [
    # training set
    "train-images-idx3-ubyte",
    "train-labels-idx1-ubyte",
    # test set
    "t10k-images-idx3-ubyte",
    "t10k-labels-idx1-ubyte",
]
def create_path(path):
    """Create the directory ``path`` (with any missing parents) if needed.

    Args:
        path: Directory path to create.

    Raises:
        OSError: If the directory cannot be created, for example because
            ``path`` already exists but is not a directory.
    """
    # Create first, then tolerate "already exists": an isdir() check before
    # makedirs() can fail if the directory appears between the two calls.
    try:
        os.makedirs(path)
    except OSError:
        if not os.path.isdir(path):
            raise
def get_file_full_name(path, name):
    """Return ``name`` appended to the directory ``path``.

    ``create_path`` is called on the directory first. A "/" separator is
    added only when ``path`` does not already end with one.
    """
    create_path(path)
    if path[-1] != "/":
        return path + "/" + name
    return path + name
def read_mnist(file_name, file_path="/home/your/data/path"):
    """Open one dataset file for binary reading.

    Args:
        file_name: Name of the file inside the data directory.
        file_path: Directory that holds the decompressed dataset files.
            The default is the placeholder that used to be hard-coded in
            the function body; pass the real directory instead of editing it.

    Returns:
        The open binary file object. The caller must close it.
    """
    full_path = get_file_full_name(file_path, file_name)
    # Binary mode: the files are raw bytes decoded later with struct.
    return open(full_path, 'rb')
def get_file_header_data(file_obj, header_len, unpack_str):
    """Read ``header_len`` bytes from ``file_obj`` and decode them.

    Returns the tuple produced by ``struct.unpack`` with ``unpack_str``.
    """
    return struct.unpack(unpack_str, file_obj.read(header_len))
def read_a_image(file_object):
    """Read the next image record from ``file_object``.

    Returns a tuple of 784 integers, one per unsigned byte read.
    """
    return struct.unpack(">784B", file_object.read(28 * 28))
def read_a_label(file_object):
    """Read the next one-byte label from ``file_object``.

    Returns the result of ``struct.unpack``, i.e. a 1-tuple ``(label,)``.
    """
    return struct.unpack(">B", file_object.read(1))
def generate_a_batch(images_file_name, labels_file_name, batch_size=8):
    """Yield successive ``(images, labels)`` batches from a pair of files.

    Each batch is two lists of equal length holding the values returned by
    ``read_a_image`` and ``read_a_label``. When the files run out, the last
    non-empty batch may be shorter than ``batch_size``, and every batch
    after it is a pair of empty lists. The generator never finishes on its
    own; the callers in this script stop when they receive an empty batch.

    Args:
        images_file_name: Name of the images file.
        labels_file_name: Name of the matching labels file.
        batch_size: Maximum number of samples per batch. The original
            ignored this argument and always used 100.

    Yields:
        A tuple ``(images, labels)`` of lists.
    """
    # The with-block closes both files when the generator is closed or
    # garbage-collected; the original never closed them.
    with read_mnist(images_file_name) as images_file, \
            read_mnist(labels_file_name) as labels_file:
        # Consume the headers so the next reads start at the first record.
        get_file_header_data(images_file, 16, ">4I")
        get_file_header_data(labels_file, 8, ">2I")
        while True:
            images = []
            labels = []
            for _ in range(batch_size):
                try:
                    image = read_a_image(images_file)
                    label = read_a_label(labels_file)
                except struct.error:
                    # A short read at end of file cannot be unpacked; that
                    # is the normal end of data, so stop filling the batch.
                    break
                images.append(image)
                labels.append(label)
            yield images, labels
def get_train_data_generator():
    """Return a batch generator over the training images and labels."""
    return generate_a_batch(file_list[0], file_list[1])
def get_test_data_generator():
    """Return a batch generator over the test images and labels."""
    return generate_a_batch(file_list[2], file_list[3])
def get_a_batch(data_generator):
    """Return the next ``(images, labels)`` pair from ``data_generator``.

    Args:
        data_generator: Iterator that yields two-element batches.

    Returns:
        A tuple ``(batch_img, batch_labels)``.
    """
    # The built-in next() works on Python 2.6+ and Python 3 alike, so the
    # fragile string comparison ``sys.version > '3'`` is no longer needed.
    batch_img, batch_labels = next(data_generator)
    return batch_img, batch_labels
def generate_test_batch():
    """Walk through the test set batch by batch and print each batch's shapes.

    The loop ends at the first batch in which both lists are empty.
    """
    batches = get_test_data_generator()
    count = 1
    while True:
        images, labels = get_a_batch(batches)
        if not (images or labels):
            break
        images = np.array(images)
        labels = np.array(labels)
        print("img shape:%s label shape:%s count:%s" % (images.shape, labels.shape, count))
        count += 1
def generate_train_batch(epochs=10):
    """Walk through the training set ``epochs`` times, printing batch shapes.

    A fresh generator is created for every epoch, and each epoch ends at the
    first batch in which both lists are empty.

    Args:
        epochs: Number of passes over the training data. Defaults to 10, the
            value that used to be hard-coded.
    """
    for epoch in range(1, epochs + 1):
        data_generator = get_train_data_generator()
        try:
            count = 1
            while True:
                batch_img, batch_labels = get_a_batch(data_generator)
                if not batch_img and not batch_labels:
                    break
                batch_img = np.array(batch_img)
                batch_labels = np.array(batch_labels)
                print("epoch:%s img shape:%s label shape:%s count:%s" % (epoch, batch_img.shape, batch_labels.shape, count))
                count += 1
        finally:
            # The generator never finishes on its own; close it explicitly
            # so it is not left suspended once the epoch is over.
            data_generator.close()
def run():
    """Script entry point: iterate the training data, then the test data."""
    generate_train_batch()
    generate_test_batch()


# Executed as soon as the script is loaded.
run()