from timm.data import ToTensor
from torchvision.datasets import CIFAR10
from torchvision.transforms import Compose# 数据的基本信息
def test1():# train=True 训练数据# transform=Compose([ToTensor]) 将数据转换成tensor张量# 加载数据集train = CIFAR10(root='root', train=True, download=True, transform=Compose([ToTensor()]))test = CIFAR10(root='root', train=False, transform=Compose([ToTensor()]))# 数据集的数量print('训练集数量:', len(train))print('测试集数量:', len(test))print('-----------')print('数据集形状:', train[0][0].shape)print('-----------')print('数据集类别:', train.class_to_idx)if __name__ == '__main__':test1()