PyTorch MNIST Dataset

1. MNIST Dataset

https://pytorch.org/vision/main/generated/torchvision.datasets.MNIST.html
```python
torchvision.datasets.MNIST(root: Union[str, Path], train: bool = True, transform: Optional[Callable] = None, target_transform: Optional[Callable] = None, download: bool = False)
```
Parameters:


  • root (str or pathlib.Path) - Root directory of dataset where MNIST/raw/train-images-idx3-ubyte and MNIST/raw/t10k-images-idx3-ubyte exist.
  • train (bool, optional) - If True, creates dataset from train-images-idx3-ubyte, otherwise from t10k-images-idx3-ubyte.
  • download (bool, optional) - If True, downloads the dataset from the internet and puts it in root directory. If dataset is already downloaded, it is not downloaded again.
  • transform (callable, optional) - A function/transform that takes in a PIL image and returns a transformed version. E.g., transforms.RandomCrop
  • target_transform (callable, optional) - A function/transform that takes in the target and transforms it.
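The script in section 2.1 passes `transforms.ToTensor()` as `transform`. As a rough, torch-free illustration (the helper names are hypothetical and plain nested lists stand in for tensors), this is what that transform does to a grayscale 28x28 MNIST image: it adds a channel dimension and rescales uint8 pixels from [0, 255] to floats in [0.0, 1.0].

```python
from typing import List

Image2D = List[List[int]]           # H x W, pixel values 0..255
Tensor3D = List[List[List[float]]]  # C x H x W, values 0.0..1.0


def to_tensor_like(img: Image2D) -> Tensor3D:
    """Hypothetical stand-in for ToTensor on a grayscale image:
    wrap the H x W grid in one channel and scale pixels to [0, 1]."""
    return [[[px / 255.0 for px in row] for row in img]]


def shape(t: Tensor3D) -> tuple:
    """Shape of a C x H x W nested list, analogous to torch.Tensor.shape."""
    return (len(t), len(t[0]), len(t[0][0]))


if __name__ == "__main__":
    img = [[0] * 28 for _ in range(28)]  # a 28x28 all-black image
    img[14][14] = 255                     # one white pixel in the middle
    t = to_tensor_like(img)
    print(shape(t))      # (1, 28, 28), cf. torch.Size([1, 28, 28]) in the log
    print(t[0][14][14])  # 1.0
```

The leading 1 in the shape is the single grayscale channel, which is why `ImShow` in the script reshapes the tensor back to (28, 28) before plotting.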
Special-members:
```python
__getitem__(index: int) -> Tuple[Any, Any]
```


  • Parameters: index (int)
  • Returns: (image, target) where target is index of the target class.
  • Return type: tuple
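The order in which `__getitem__` applies `transform` and `target_transform` can be sketched with a hypothetical stand-in class (`TinyDigits` is not part of torchvision). One simplification to keep in mind: the real MNIST class first converts the stored uint8 image into a PIL image in mode "L" before calling `transform`; here the "images" are plain strings so the sketch runs without torch or PIL. The `one_hot` helper is likewise only an example of a possible `target_transform`.

```python
from typing import Any, Callable, List, Optional, Sequence, Tuple


class TinyDigits:
    """Hypothetical stand-in mimicking MNIST's (image, target) indexing."""

    def __init__(self, data: Sequence[Any], targets: Sequence[int],
                 transform: Optional[Callable] = None,
                 target_transform: Optional[Callable] = None) -> None:
        assert len(data) == len(targets), "one target per image"
        self.data = data
        self.targets = targets
        self.transform = transform
        self.target_transform = target_transform

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, index: int) -> Tuple[Any, Any]:
        # Look up the raw sample and its integer class index.
        img, target = self.data[index], int(self.targets[index])
        # Apply the optional transforms independently to image and target.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)
        return img, target


def one_hot(label: int, num_classes: int = 10) -> List[int]:
    """Example target_transform: class index -> one-hot list."""
    return [1 if i == label else 0 for i in range(num_classes)]


if __name__ == "__main__":
    ds = TinyDigits(data=["img0", "img1"], targets=[5, 0],
                    transform=str.upper, target_transform=one_hot)
    image, target = ds[0]
    print(image, target)  # IMG0 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0]
```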
2. Source code for torchvision.datasets.mnist

https://pytorch.org/vision/main/_modules/torchvision/datasets/mnist.html
```python
    # Excerpt of class attributes from torchvision.datasets.MNIST.
    # Mirrors are tried in order until a download succeeds.
    mirrors = [
        "http://yann.lecun.com/exdb/mnist/",
        "https://ossci-datasets.s3.amazonaws.com/mnist/",
    ]
    # (compressed file name, MD5 checksum) for each raw file.
    resources = [
        ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
        ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
        ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
        ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
    ]
    training_file = "training.pt"
    test_file = "test.pt"
    # Human-readable class names; the label value equals the list index.
    classes = [
        "0 - zero",
        "1 - one",
        "2 - two",
        "3 - three",
        "4 - four",
        "5 - five",
        "6 - six",
        "7 - seven",
        "8 - eight",
        "9 - nine",
    ]
```
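The log in section 2.1 shows the first mirror answering HTTP 404 and the download falling through to the S3 mirror. That control flow can be sketched as follows; `download_with_fallback` and the injected `fetch` callable are hypothetical (injecting `fetch` keeps the sketch testable without network access), and the printed messages only imitate the log.

```python
from typing import Callable, Sequence
from urllib.error import URLError

MIRRORS = [
    "http://yann.lecun.com/exdb/mnist/",
    "https://ossci-datasets.s3.amazonaws.com/mnist/",
]


def download_with_fallback(filename: str, mirrors: Sequence[str],
                           fetch: Callable[[str], bytes]) -> bytes:
    """Try each mirror in order and return the first successful payload.

    `fetch` is a hypothetical callable (for example a thin wrapper around
    urllib.request.urlopen) that raises URLError on failure. HTTPError,
    which carries codes such as 404, is a subclass of URLError.
    """
    for mirror in mirrors:
        url = f"{mirror}{filename}"  # mirrors end with "/"
        print(f"Downloading {url}")
        try:
            return fetch(url)
        except URLError as error:
            print(f"Failed to download (trying next):\n{error}")
    raise RuntimeError(f"Error downloading {filename}")


if __name__ == "__main__":
    def fake_fetch(url: str) -> bytes:
        # Simulate the dead first mirror seen in the log.
        if url.startswith("http://yann.lecun.com/"):
            raise URLError("simulated outage")
        return b"payload"

    print(download_with_fallback("train-images-idx3-ubyte.gz", MIRRORS, fake_fetch))
```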
The files can also be downloaded manually in a browser from the links below and copied into the data/MNIST/raw/ directory.
```
https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
```
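After a manual download it is worth checking each .gz file against the MD5 sums listed in `resources`, and, since the `root` parameter description refers to the extracted files (for example MNIST/raw/train-images-idx3-ubyte), decompressing them next to the archives. A minimal stdlib sketch, assuming the files sit in data/MNIST/raw/; `md5_of` and `verify_and_extract` are hypothetical helpers, not torchvision functions:

```python
import gzip
import hashlib
import shutil
from pathlib import Path

# (filename, md5) pairs copied from torchvision's MNIST.resources.
RESOURCES = [
    ("train-images-idx3-ubyte.gz", "f68b3c2dcbeaaa9fbdd348bbdeb94873"),
    ("train-labels-idx1-ubyte.gz", "d53e105ee54ea40749a09fcbcd1e9432"),
    ("t10k-images-idx3-ubyte.gz", "9fb629c4189551a2d022fa330f9573f3"),
    ("t10k-labels-idx1-ubyte.gz", "ec29112dd5afa0611ce80d1b7f02629c"),
]


def md5_of(path: Path, chunk_size: int = 1 << 20) -> str:
    """Hex MD5 digest of a file, read in chunks to bound memory use."""
    h = hashlib.md5()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()


def verify_and_extract(raw_dir: Path, resources=RESOURCES) -> list:
    """Check each .gz against its MD5, then gunzip it beside the archive.

    Returns the extracted paths; raises ValueError on a checksum mismatch.
    """
    extracted = []
    for filename, expected in resources:
        gz_path = raw_dir / filename
        actual = md5_of(gz_path)
        if actual != expected:
            raise ValueError(f"{filename}: md5 {actual} != {expected}")
        out_path = gz_path.with_suffix("")  # drop the trailing ".gz"
        with gzip.open(gz_path, "rb") as src, open(out_path, "wb") as dst:
            shutil.copyfileobj(src, dst)
        extracted.append(out_path)
    return extracted


if __name__ == "__main__":
    raw = Path("data/MNIST/raw")
    if all((raw / name).is_file() for name, _ in RESOURCES):
        for p in verify_and_extract(raw):
            print("extracted", p)
```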
2.1. mnist-dataset.py

/home/yongqiang/llm_work/ggml_25_02_15/ggml/examples/mnist/mnist-dataset.py
```python
import torch
import torchvision
import torchvision.datasets
import torchvision.transforms
import numpy as np
import matplotlib.pyplot as plt

print(torch.__version__)

# Download (if needed) into ./data/MNIST/raw; ToTensor turns each image into
# a float tensor of shape [1, 28, 28] with values in [0, 1].
train_data = torchvision.datasets.MNIST(root='./data', train=True, transform=torchvision.transforms.ToTensor(), download=True)
test_data  = torchvision.datasets.MNIST(root='./data', train=False, transform=torchvision.transforms.ToTensor(), download=True)
assert len(train_data) == 60000
assert len(test_data)  == 10000
print("len(train_data) =", len(train_data))
print("len(test_data) =", len(test_data))

# Each sample is an (image, label) tuple, so train_data[0][0] is the image tensor.
print("type(train_data[0]):", type(train_data[0]))
print("train_data[0].shape:", train_data[0][0].shape)

classes = train_data.classes
print("train_data.classes:", classes)
print("train_data.class_to_idx: ", train_data.class_to_idx)


def ImShow(sample_element, shape = (28, 28)):
    """Display one (image, label) sample as a grayscale picture."""
    # Drop the channel dimension: [1, 28, 28] -> [28, 28].
    plt.imshow(sample_element[0].numpy().reshape(shape), cmap='gray')
    plt.title('Label = ' + str(sample_element[1]))
    plt.show()


ImShow(train_data[0])
```
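The 60000/10000 lengths and the 28x28 image size checked above come from the headers of the raw IDX files extracted into data/MNIST/raw. A minimal stdlib reader makes the format visible: two zero bytes, a type byte (0x08 for unsigned bytes), a dimension count, one big-endian 32-bit size per dimension, then the data. This is written for illustration only; it is not torchvision's loader, and the function names are hypothetical.

```python
import struct
from typing import List, Tuple


def parse_idx_header(buf: bytes) -> Tuple[int, List[int], int]:
    """Parse an IDX header and return (type_code, dims, data_offset)."""
    zero1, zero2, type_code, ndim = struct.unpack_from(">BBBB", buf, 0)
    if zero1 != 0 or zero2 != 0:
        raise ValueError("not an IDX file: first two bytes must be zero")
    dims = list(struct.unpack_from(f">{ndim}I", buf, 4))  # big-endian uint32s
    return type_code, dims, 4 + 4 * ndim


def read_idx_ubyte(buf: bytes) -> Tuple[List[int], bytes]:
    """Return (dims, raw payload) for an unsigned-byte (0x08) IDX file."""
    type_code, dims, offset = parse_idx_header(buf)
    if type_code != 0x08:
        raise ValueError(f"expected unsigned byte data (0x08), got {type_code:#x}")
    count = 1
    for d in dims:
        count *= d
    payload = buf[offset:offset + count]
    if len(payload) != count:
        raise ValueError("truncated IDX payload")
    return dims, payload


if __name__ == "__main__":
    from pathlib import Path

    p = Path("data/MNIST/raw/train-images-idx3-ubyte")
    if p.is_file():
        dims, _ = read_idx_ubyte(p.read_bytes())
        print(dims)  # MNIST training images: [60000, 28, 28]
```

Read as a single big-endian integer, the first four bytes are 0x00000803 (2051) for the image files and 0x00000801 (2049) for the label files.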
```
(base) yongqiang@yongqiang:~/llm_work/ggml_25_02_15/ggml/examples/mnist$ python mnist-dataset.py
2.5.1
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to ./data/MNIST/raw/train-images-idx3-ubyte.gz
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 9.91M/9.91M [00:03<00:00, 3.20MB/s]
Extracting ./data/MNIST/raw/train-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to ./data/MNIST/raw/train-labels-idx1-ubyte.gz
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 28.9k/28.9k [00:00<00:00, 117kB/s]
Extracting ./data/MNIST/raw/train-labels-idx1-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw/t10k-images-idx3-ubyte.gz
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████| 1.65M/1.65M [00:02<00:00, 656kB/s]
Extracting ./data/MNIST/raw/t10k-images-idx3-ubyte.gz to ./data/MNIST/raw
Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 404: Not Found
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████| 4.54k/4.54k [00:00<00:00, 3.05MB/s]
Extracting ./data/MNIST/raw/t10k-labels-idx1-ubyte.gz to ./data/MNIST/raw
len(train_data) = 60000
len(test_data) = 10000
type(train_data[0]): <class 'tuple'>
train_data[0].shape: torch.Size([1, 28, 28])
train_data.classes: ['0 - zero', '1 - one', '2 - two', '3 - three', '4 - four', '5 - five', '6 - six', '7 - seven', '8 - eight', '9 - nine']
train_data.class_to_idx:  {'0 - zero': 0, '1 - one': 1, '2 - two': 2, '3 - three': 3, '4 - four': 4, '5 - five': 5, '6 - six': 6, '7 - seven': 7, '8 - eight': 8, '9 - nine': 9}
(base) yongqiang@yongqiang:~/llm_work/ggml_25_02_15/ggml/examples/mnist$
```
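The printed `class_to_idx` simply pairs each class name with its position in `classes`. A small sketch that rebuilds it from the list shown in the log, plus the inverse mapping, which is handy for turning a predicted label index back into a name; the helper names are hypothetical.

```python
from typing import Dict, List

# Class names as printed by train_data.classes in the log.
CLASSES: List[str] = [
    "0 - zero", "1 - one", "2 - two", "3 - three", "4 - four",
    "5 - five", "6 - six", "7 - seven", "8 - eight", "9 - nine",
]


def build_class_to_idx(classes: List[str]) -> Dict[str, int]:
    """Map each class name to its index in the list."""
    return {name: i for i, name in enumerate(classes)}


def build_idx_to_class(classes: List[str]) -> Dict[int, str]:
    """Inverse mapping: label index -> class name."""
    return dict(enumerate(classes))


if __name__ == "__main__":
    print(build_class_to_idx(CLASSES)["3 - three"])  # 3
    print(build_idx_to_class(CLASSES)[7])            # 7 - seven
```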

3. SSL: CERTIFICATE_VERIFY_FAILED

```
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1007)>
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: certificate has expired (_ssl.c:1129)>
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed: unable to get local issuer certificate (_ssl.c:1123)>
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:833)>
<urlopen error [SSL: CERTIFICATE_VERIFY_FAILED] certificate verify failed (_ssl.c:581)>
```
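Disable certificate verification (run this before the dataset download is triggered):
```python
import ssl

# Process-wide change: HTTPS requests made through urllib will no longer
# verify server certificates.
ssl._create_default_https_context = ssl._create_unverified_context
```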
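Because that assignment affects every later HTTPS request in the process, one option is to scope it to the download and restore the previous default afterwards. A sketch, assuming the download goes through urllib (whose HTTPS connections consult `ssl._create_default_https_context`); the `unverified_https` helper is hypothetical, it swaps a private `ssl` attribute, and since the attribute is global it is not thread-safe.

```python
import contextlib
import ssl
from typing import Iterator


@contextlib.contextmanager
def unverified_https() -> Iterator[None]:
    """Temporarily disable HTTPS certificate verification for urllib.

    The previous default context factory is restored on exit, even if the
    body raises.
    """
    previous = ssl._create_default_https_context
    ssl._create_default_https_context = ssl._create_unverified_context
    try:
        yield
    finally:
        ssl._create_default_https_context = previous


if __name__ == "__main__":
    with unverified_https():
        # e.g. torchvision.datasets.MNIST(root="./data", download=True)
        print(ssl._create_default_https_context is ssl._create_unverified_context)  # True
    print(ssl._create_default_https_context is ssl._create_unverified_context)  # False
```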
References

[1] Yongqiang Cheng, https://yongqiang.blog.csdn.net/
