当前位置: 首页 > ds >正文

使用 np.zeros_like(label) 保存预测概率时发现数据类型不匹配导致的隐式类型转换

🍉 CSDN 叶庭云https://yetingyun.blog.csdn.net/


下面这段代码中 predictions 数组在赋值后全为 0 的根本原因是数据类型不匹配导致的隐式类型转换,保存概率值时会被截断为 0 或 1。具体分析如下:

import numpy as npsamples = 100
y = np.random.randint(0, 2, size=samples)
y_pred = np.random.randint(0, 2, size=samples)
y_pred_proba = np.random.rand(samples).astype(np.float32)
print(y, len(y))
print(y_pred, len(y_pred))
print(y_pred_proba, len(y_pred_proba))print("=" * 100)
predictions = np.zeros_like(y_pred)
print(predictions, len(predictions))
predictions[:] = y_pred_proba
print(predictions, len(predictions))

在这里插入图片描述


1. 错误原因分析

  1. 初始化数据类型错误

    • y_prednp.random.randint(0, 2) 生成的,默认数据类型为 int

    • np.zeros_like(y_pred) 会继承 y_predint 类型,因此 predictions 是整型数组。

  2. 浮点数到整型的强制转换

    • y_pred_probanp.random.rand 生成的浮点数组(float32)。

    • 当执行 predictions[:] = y_pred_proba 时,右侧的浮点数会被强制转换为左侧的整型,导致小数部分被截断。例如,0.95 -> 00.19 -> 0

    • 最终,所有 y_pred_proba 中的浮点值都会变成 0,导致 predictions 全为 0。


2. 修正方法

核心思路:确保 predictions 的数据类型与 y_pred_proba 兼容(即浮点类型)。

2.1 高效且精确的修正代码
# 修正后的关键行:显式指定浮点类型
predictions = np.zeros_like(y_pred, dtype=np.float32)  # 强制为浮点类型
predictions[:] = y_pred_proba
2.2 修正原理
  1. 显式指定数据类型

    • np.zeros_like(y_pred, dtype=np.float32) 会创建一个与 y_pred 形状相同但数据类型为 float32 的数组。

    • 此时 predictions 可以正确存储浮点数值,避免类型转换。

  2. 赋值操作保留精度

    • 右侧的 y_pred_probafloat32)可以直接赋值给左侧的浮点数组,无精度损失。

3. 完整修正代码

import numpy as npsamples = 100
y = np.random.randint(0, 2, size=samples)
y_pred = np.random.randint(0, 2, size=samples)
y_pred_proba = np.random.rand(samples).astype(np.float32)
print(y, len(y))
print(y_pred, len(y_pred))
print(y_pred_proba, len(y_pred_proba))print("=" * 100)
predictions = np.zeros_like(y_pred)
print(predictions, len(predictions))
predictions[:] = y_pred_proba
print(predictions, len(predictions))
predictions = np.zeros_like(y_pred, dtype=np.float32)
predictions[:] = y_pred_proba
print(predictions, len(predictions))

4. 其他可行方案

  1. 直接使用浮点初始化
predictions = np.zeros(samples, dtype=np.float32)
  1. 复用 y_pred_proba 的数据类型
predictions = np.zeros_like(y_pred_proba)

5. 总结

  • 根本原因:整型数组无法存储浮点数值,这会引起隐式类型转换。

  • 修正关键:确保目标数组的数据类型与源数据相匹配(浮点数类型:np.float32、np.float64)。


http://www.xdnf.cn/news/3092.html

相关文章:

  • 新版权案件申请TRO,涵盖复古风吉他与头盔
  • 【LeetCode】螺旋矩阵
  • Maven根据Google proto文件自动生成java对象
  • 香港科技大学广州|智能制造学域硕、博研究生招生可持续能源与环境学域博士招生宣讲会—四川大学专场!
  • Unity-Shader详解-其三
  • 电子电器架构 --- 人工智能、固态电池和先进自动驾驶功能等新兴技术的影响
  • Centos Ubuntu RedOS系统类型下查看系统信息
  • 黑马Redis(四)
  • A2A与MCP:理解它们的区别以及何时使用
  • 除法未能拿下 一直运行超时
  • MySQL 实战 45 讲 笔记 ----来源《极客时间》
  • Markdown语法大全
  • 洛谷题解 | CF111C Petya and Spiders
  • Spark GraphX 机器学习:图计算
  • CertiK创始人顾荣辉出席Unchained Summit,探讨Web3.0安全与合规路径
  • 记录 Flink jdbc、mysql-cdc 连接 mysql8 碰到的适配问题
  • 4.28-4.29 Vue
  • phpstudy修改Apache端口号
  • Azure Synapse Dedicated SQL pool企业权限管理
  • 论文阅读:2024 arxiv FlipAttack: Jailbreak LLMs via Flipping
  • 怎样学习Electron
  • 驱动开发硬核特训 · Day 25 (附加篇):从设备树到驱动——深入理解Linux时钟子系统的实战链路
  • PSO详解变体上新!新型混合蛾焰粒子群优化(MFPSO)算法
  • GA-Transformer遗传算法优化编码器多特征分类预测/故障诊断,作者:机器学习之心
  • 【Redis——数据类型和内部编码和Redis使用单线程模型的分析】
  • EtherCAT 分布式时钟(DC)补偿技术解析
  • React Native 动态切换主题
  • 使用js写一个发布订阅者
  • 给 BBRv2/3 火上浇油的 drain-to-target
  • 26考研 | 王道 | 计算机网络 | 第一章 计算机网络的体系结构