PyTorch_张量元素类型转换
- tensor.type([张量类型])
- torch.double()
代码
import torch
import numpy as np # 使用 type() 函数进行转换
def test01():data = torch.full([2,3], 10)print(data.dtype)# 注意:返回一个新的类型转换过的张量data = data.type(torch.DoubleTensor)#data = data.type(torch.IntTensor)print(data.dtype)# 使用具体类型函数进行转换
def test02():data = torch.full([2,3], 10)print(data.dtype)# 转换程 float64 类型data = data.double()print(data.dtype)"""# 转换成其他类型data.short() # 将张量元素转换成 int16 类型data. int() # 将张量转换成 int32 类型data.long() # 将张量转换成 int64 类型data.float() # 将张量转换成 float32 类型"""if __name__ == '__main__':test01()