【发布时间】:2023-03-18 09:55:01
【问题描述】:
我已经尝试使用this 代码从 google collab 运行我的代码,但它在下载代码时卡住了。我也在硬件加速器之间切换,但仍然没有。这个问题有解决办法吗?
【问题讨论】:
标签: python-3.x scikit-learn google-colaboratory
我已经尝试使用this 代码从 google collab 运行我的代码,但它在下载代码时卡住了。我也在硬件加速器之间切换,但仍然没有。这个问题有解决办法吗?
【问题讨论】:
标签: python-3.x scikit-learn google-colaboratory
您可以从github repository 下载它。
将下载的文件(来自自述文件链接)放在您当前路径中名为data/fashion/ 的目录中,然后您可以使用它们的加载器。
def load_mnist(path, kind='train'):
import os
import gzip
import numpy as np
"""Load MNIST data from `path`"""
labels_path = os.path.join(path,
'%s-labels-idx1-ubyte.gz'
% kind)
images_path = os.path.join(path,
'%s-images-idx3-ubyte.gz'
% kind)
with gzip.open(labels_path, 'rb') as lbpath:
labels = np.frombuffer(lbpath.read(), dtype=np.uint8,
offset=8)
with gzip.open(images_path, 'rb') as imgpath:
images = np.frombuffer(imgpath.read(), dtype=np.uint8,
offset=16).reshape(len(labels), 784)
return images, labels
X_train, y_train = load_mnist('data/fashion', kind='train')
X_test, y_test = load_mnist('data/fashion', kind='t10k')
另一种选择是使用torchvision FMNIST 数据集。
你也可以使用:
from tensorflow.examples.tutorials.mnist import input_data
data = input_data.read_data_sets('data/fashion', source_url='http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/')
这是下载文件的代码(可以通过一些try-catch来改进):
import os
import requests
path = 'data/fashion'
def download_fmnist(path):
DEFAULT_SOURCE_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
files = dict(
TRAIN_IMAGES='train-images-idx3-ubyte.gz',
TRAIN_LABELS='train-labels-idx1-ubyte.gz',
TEST_IMAGES='t10k-images-idx3-ubyte.gz',
TEST_LABELS='t10k-labels-idx1-ubyte.gz')
if not os.path.exists(path):
os.mkdir(path)
for f in files:
filepath = os.path.join(path, files[f])
if not os.path.exists(filepath):
url = DEFAULT_SOURCE_URL + files[f]
r = requests.get(url, allow_redirects=True)
open(filepath, 'wb').write(r.content)
print('Successfully downloaded', f)
download_fmnist(path)
【讨论】:
load_mnist)加载数据。否则,您可以直接使用您在 tf 返回的变量data 中找到的内容。
x_train, y_train = (data.train._images, data.train._labels)。
keras.datasets.fashion_mnist.load_data() 命令返回一个 numpy 数组元组:(xtrain, ytrain) 和 (xtest, ytest)。
数据集不会以这种方式下载到您的本地存储中。这就是命令cd fashion-mnist/ 引发错误的原因。没有创建目录。 fashion-mnist 数据集已正确加载到您的代码中的(xtrain, ytrain) 和(xtest, ytest)。
【讨论】:
keras.datasets.fashion_mnist.load_data() 函数加载数据集,您的代码应该可以工作。我已经在 google colab 中测试了你的代码。执行需要很长时间,但它可以正常工作而不会出现任何错误。
对于 Google Colab
在顶部写!pip install mnist。
使用import mnist。
然后简单地存储图像和标签:
train_images = mnist.train_images()
train_labels = mnist.train_labels()
test_images = mnist.test_images()
test_labels = mnist.test_labels()
就是这样!!!
【讨论】: