当前位置: 首页 > news >正文

简单神经网络(ANN)实现:从零开始构建第一个模型

本文将手把手带你用 Python + Numpy 实现一个最基础的人工神经网络(Artificial Neural Network, ANN)。不依赖任何深度学习框架,适合入门理解神经网络的本质。


一、项目目标

构建一个三层神经网络(输入层、隐藏层、输出层),用于解决一个简单的二分类任务,例如根据两个输入特征判断输出是 0 还是 1。


二、基本结构说明

我们将构建如下结构的神经网络:

 

复制编辑

输入层(2个神经元) → 隐藏层(4个神经元) → 输出层(1个神经元)

  • 激活函数:使用 Sigmoid

  • 损失函数:均方误差

  • 学习方式:批量梯度下降 + 手动反向传播


三、准备数据

我们使用一个简单的数据集(可类比于 AND/OR 操作):

import numpy as np# 输入数据:4组样本,每组2个特征
X = np.array([[0, 0],[0, 1],[1, 0],[1, 1]
])# 标签:这里我们尝试模拟逻辑或(OR)操作
y = np.array([[0], [1], [1], [1]])

四、初始化网络参数

np.random.seed(0)# 网络结构:2 → 4 → 1
input_size = 2
hidden_size = 4
output_size = 1# 权重初始化(正态分布)
W1 = np.random.randn(input_size, hidden_size)
b1 = np.zeros((1, hidden_size))W2 = np.random.randn(hidden_size, output_size)
b2 = np.zeros((1, output_size))

五、激活函数

def sigmoid(x):return 1 / (1 + np.exp(-x))def sigmoid_derivative(x):# 输入为 sigmoid 的输出值return x * (1 - x)

六、训练循环

我们进行 10000 次迭代,手动实现前向传播、损失计算和反向传播。

learning_rate = 0.1
for epoch in range(10000):# --- 正向传播 ---z1 = np.dot(X, W1) + b1a1 = sigmoid(z1)z2 = np.dot(a1, W2) + b2a2 = sigmoid(z2)  # 预测值# --- 损失计算(均方误差)---loss = np.mean((y - a2) ** 2)# --- 反向传播 ---error_output = y - a2d_output = error_output * sigmoid_derivative(a2)error_hidden = d_output.dot(W2.T)d_hidden = error_hidden * sigmoid_derivative(a1)# --- 参数更新 ---W2 += a1.T.dot(d_output) * learning_rateb2 += np.sum(d_output, axis=0, keepdims=True) * learning_rateW1 += X.T.dot(d_hidden) * learning_rateb1 += np.sum(d_hidden, axis=0, keepdims=True) * learning_rateif epoch % 1000 == 0:print(f"Epoch {epoch}, Loss: {loss:.4f}")

七、模型测试

print("预测结果:")
print(a2.round())

输出如下,接近 OR 操作的结果 [0, 1, 1, 1]

预测结果:
[[0.][1.][1.][1.]]

八、总结与拓展

通过这篇文章,我们实现了一个从零开始的神经网络:

  • 完整构建了网络结构(无需框架)

  • 实现了正向传播与反向传播

  • 成功对二分类任务进行了拟合

拓展建议:

  • 改用 ReLU 激活函数;

  • 增加网络层数,提升模型表达能力;

  • 用 Softmax 处理多分类问题;

  • 尝试用真实数据集,如鸢尾花(Iris)或 MNIST。


这类“纯手写”的 ANN 实战项目非常适合用来理解深度学习的本质机制。如果你打算继续深入,可以尝试逐步迁移到 PyTorch 或 TensorFlow 框架实现更复杂的模型。

http://www.xdnf.cn/news/496657.html

相关文章:

  • 【第二篇】 初步解析Spring Boot
  • 第9讲、深入理解Scaled Dot-Product Attention
  • 【漫话机器学习系列】264.内距(又称四分位差)Interquartile Range
  • 抽奖系统-抽奖
  • uni-app小程序登录后…
  • 数据分析_Python
  • arduino平台读取鼠标光电传感器
  • MATLAB学习笔记(七):MATLAB建模城市的雨季防洪排污的问题
  • Elasticsearch 性能优化面试宝典
  • LabVIEW声音与振动测量分析
  • STM32实战指南:SG90舵机控制原理与代码详解
  • Qt与Hid设备通信
  • 392. Is Subsequence
  • 天拓四方锂电池卷绕机 PLC 物联网解决方案
  • 从零开始认识 Node.js:异步非阻塞的魅力
  • Go语言 GORM框架 使用指南
  • c/c++的opencv模糊
  • exit耗时高
  • PYTHON训练营DAY28
  • AMD Vivado™ 设计套件生成加密比特流和加密密钥
  • 【React中虚拟DOM与Diff算法详解】
  • 免费商用字体下载
  • STM32IIC协议基础及Cube配置
  • 创建react工程并集成tailwindcss
  • C++(20): 文件输入输出库 —— <fstream>
  • Pytorch实现常用代码笔记
  • 从代码学习深度学习 - 词嵌入(word2vec)PyTorch版
  • 05、基础入门-SpringBoot-HelloWorld
  • 页面上如何显示特殊字符、Unicode字符?
  • 【001】RenPy打包安卓apk 流程源码级别分析