
FPGA Multi-Channel Convolution Accelerator: Building a Hardware Engine for Handwriting Recognition from Scratch

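I have recently been working on an interesting project: deploying a CNN on an FPGA to recognize handwritten images. This article is my second step. The full code has been published on GitHub.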

Module overview

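A convolutional neural network (CNN) is built from convolutional layers, pooling layers, activation layers, and fully connected layers. This article implements the convolution arithmetic module inside the convolutional layer.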

The convolution operation is illustrated below:

[Figure: the convolution operation]

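Once the weights are fixed, the whole process can be viewed as two steps repeated over and over: sliding a window across the data, and computing the convolution of each window with the weights. The window module was implemented in the previous article. Here we implement the convolution arithmetic module.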

The computation works as follows:
$$\begin{bmatrix}1&2&3\\4&5&6\\7&8&9\end{bmatrix} \ast \begin{bmatrix}1&0&1\\0&1&0\\1&1&2\end{bmatrix} = 1\cdot1 + 2\cdot0 + 3\cdot1 + 4\cdot0 + 5\cdot1 + 6\cdot0 + 7\cdot1 + 8\cdot1 + 9\cdot2 = 42$$
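The module extends this single-channel example across several input channels. It adds the per-channel results into a single value. In the notation below, C = IN_CHANNEL and K = KERNEL_SIZE. The symbol x_{c,i} stands for the i-th window element of channel c, and w_{c,i} for the matching weight; these symbols are introduced here only for illustration. As is usual in CNNs, the kernel is not flipped, so strictly speaking the operation is a cross-correlation.

```latex
\[
\texttt{total\_sum} = \sum_{c=0}^{C-1} \sum_{i=0}^{K^2-1} x_{c,i}\, w_{c,i},
\qquad C = 3,\; K = 3 \;\Rightarrow\; 3 \cdot 9 = 27 \text{ products}
\]
```

With the default parameters, the module therefore instantiates 27 multipliers that all operate in parallel. The sum is saturated to OUTPUT_WIDTH bits afterwards.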

Code

  1. Configurable parameters, inputs, and outputs

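To support parallel processing of multiple channels, the inputs carry the data of all input channels flattened into one dimension: a 1-D window vector and a 1-D weight vector.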

DATA_WIDTH and WEIGHT_WIDTH are defined separately because the weights will later be quantized to fixed-point numbers.

module mult_acc_comb #(
    parameter DATA_WIDTH   = 8,
    parameter KERNEL_SIZE  = 3,
    parameter IN_CHANNEL   = 3,
    parameter WEIGHT_WIDTH = 8,
    parameter OUTPUT_WIDTH = 20,  // configurable output width
    parameter ACC_WIDTH    = 2*DATA_WIDTH + 4 + $clog2(KERNEL_SIZE*KERNEL_SIZE*IN_CHANNEL) // Ensure ACC_WIDTH is sufficient
)(
    // Input data interface
    input window_valid,
    input [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH-1:0] multi_channel_window_in,
    input weight_valid,
    input [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*WEIGHT_WIDTH-1:0] multi_channel_weight_in,
    // Output data interface
    output [OUTPUT_WIDTH-1:0] conv_out, // uses the configurable output width
    output conv_valid
);
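It helps to spell out the width arithmetic behind ACC_WIDTH, because the parameters will change once the weights are quantized. The operands are unsigned, as in the module. Each product then needs DATA_WIDTH + WEIGHT_WIDTH bits. Summing N such products adds at most ⌈log₂ N⌉ bits. With the default parameters:

```latex
\[
\begin{aligned}
N &= \texttt{IN\_CHANNEL}\cdot\texttt{KERNEL\_SIZE}^2 = 3 \cdot 9 = 27\\
W_{\min} &= \texttt{DATA\_WIDTH} + \texttt{WEIGHT\_WIDTH} + \lceil \log_2 N \rceil = 8 + 8 + 5 = 21\\
\texttt{ACC\_WIDTH} &= 2 \cdot \texttt{DATA\_WIDTH} + 4 + \lceil \log_2 N \rceil = 16 + 4 + 5 = 25 \;\ge\; W_{\min}
\end{aligned}
\]
```

Two consequences follow.

First, OUTPUT_WIDTH = 20 is one bit short of W_min, so large inputs can exceed the output range and the saturation logic is genuinely needed. For example, 27 · 255 · 255 = 1,755,675 is larger than 2^20 − 1 = 1,048,575.

Second, the default formula uses 2·DATA_WIDTH rather than DATA_WIDTH + WEIGHT_WIDTH. It therefore remains sufficient only while WEIGHT_WIDTH ≤ DATA_WIDTH + 4, which is worth rechecking after quantization.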
  2. Define the internal signals
// Derived weight parameter
localparam WEIGHTS_PER_FILTER = IN_CHANNEL * KERNEL_SIZE * KERNEL_SIZE;

// Unpacked per-channel window and weight data (unsigned)
wire [DATA_WIDTH-1:0]   channel_window_data [0:IN_CHANNEL-1][0:KERNEL_SIZE*KERNEL_SIZE-1];
wire [WEIGHT_WIDTH-1:0] channel_weight_data [0:IN_CHANNEL-1][0:KERNEL_SIZE*KERNEL_SIZE-1];

// Product for every position of every channel (unsigned)
wire [DATA_WIDTH+WEIGHT_WIDTH-1:0] mult_results [0:IN_CHANNEL-1][0:KERNEL_SIZE*KERNEL_SIZE-1];

// Accumulated sum of each channel
wire [ACC_WIDTH-1:0] channel_sums [0:IN_CHANNEL-1];

// Final sum across all channels
wire [ACC_WIDTH-1:0] total_sum;

// Loop variables
genvar ch, i_idx, k_idx, c_idx;
  3. Unpack the input data
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : unpack_gen
        for (i_idx = 0; i_idx < KERNEL_SIZE*KERNEL_SIZE; i_idx = i_idx + 1) begin : element_gen
            // Unpack window data: element 0 of channel 0 sits in the least significant bits
            assign channel_window_data[ch][i_idx] = multi_channel_window_in[(ch*KERNEL_SIZE*KERNEL_SIZE + i_idx)*DATA_WIDTH +: DATA_WIDTH];
            // Unpack weight data: the order is reversed, so element 0 of channel 0 sits in the most significant bits
            assign channel_weight_data[ch][i_idx] = multi_channel_weight_in[(WEIGHTS_PER_FILTER - 1 - (ch*KERNEL_SIZE*KERNEL_SIZE + i_idx))*WEIGHT_WIDTH +: WEIGHT_WIDTH];
        end
    end
endgenerate

a[b +: c] selects c bits of a, starting at bit b and counting upward. It is equivalent to a[b+c-1:b]; for example, a[8 +: 8] is a[15:8].
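A tiny self-contained check of the part-select rule follows. The module name partsel_demo and its ports are made up for illustration. Both outputs select the same byte of a.

```verilog
// Hypothetical demo module: shows that a[8 +: 8] and a[15:8] select the same bits.
module partsel_demo (
    input  [31:0] a,
    output [7:0]  byte1_plus,   // indexed part-select: 8 bits starting at bit 8, going up
    output [7:0]  byte1_range   // equivalent fixed range a[8+8-1:8]
);
    assign byte1_plus  = a[8 +: 8];
    assign byte1_range = a[15:8];
endmodule
```

With a = 32'hDDCCBBAA, both outputs are 8'hBB. The indexed form is what makes the unpacking loops possible: the base may depend on a genvar or a variable, while the width must be a constant.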

The window and weight inputs are rearranged as follows:

[Figure: layout of the window and weight data before and after unpacking]
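To see the two packing orders concretely, the sketch below packs the 3×3 example from the top of the article as a single channel. It assumes IN_CHANNEL = 1 and row-major element order. It then recomputes the result with the same index expressions as unpack_gen. The module pack_example is hypothetical and not part of the author's code. It reflects one reading of the unpacking expressions: window element 0 lives in the least significant byte, and weight element 0 lives in the most significant byte.

```verilog
// Hypothetical demo: single channel (IN_CHANNEL = 1), KERNEL_SIZE = 3, 8-bit data and weights.
module pack_example (
    output [71:0] window_bus,  // 9 elements x 8 bits
    output [71:0] weight_bus,  // 9 elements x 8 bits
    output [19:0] dot          // unsaturated sum of the 9 products
);
    // Window element i sits at bits [i*8 +: 8]; the leftmost item of a
    // concatenation lands in the MSBs, so element 0 (value 1) is written last.
    assign window_bus = {8'd9, 8'd8, 8'd7, 8'd6, 8'd5, 8'd4, 8'd3, 8'd2, 8'd1};
    // Weight element i sits at bits [(9-1-i)*8 +: 8], so element 0 (value 1) is in the MSBs.
    // Row-major kernel [1 0 1; 0 1 0; 1 1 2].
    assign weight_bus = {8'd1, 8'd0, 8'd1, 8'd0, 8'd1, 8'd0, 8'd1, 8'd1, 8'd2};

    wire [19:0] running_sum [0:9];
    assign running_sum[0] = 20'd0;
    genvar i;
    generate
        for (i = 0; i < 9; i = i + 1) begin : mac
            // Same index expressions as unpack_gen with ch = 0 and WEIGHTS_PER_FILTER = 9;
            // the 20-bit context widens the operands before the multiply.
            assign running_sum[i+1] = running_sum[i]
                                    + window_bus[i*8 +: 8] * weight_bus[(9 - 1 - i)*8 +: 8];
        end
    endgenerate
    assign dot = running_sum[9];
endmodule
```

The result is 42, matching the hand calculation. Reversing either bus would pair the wrong elements and generally give a different sum.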

  4. Parallel convolution

All channels perform their convolution arithmetic at the same time.

// Parallel multiplication: every position of every channel is computed at once
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : mult_ch_gen
        for (i_idx = 0; i_idx < KERNEL_SIZE*KERNEL_SIZE; i_idx = i_idx + 1) begin : mult_elem_gen
            assign mult_results[ch][i_idx] = channel_window_data[ch][i_idx] * channel_weight_data[ch][i_idx];
        end
    end
endgenerate

// Per-channel accumulation (combinational adders)
generate
    for (ch = 0; ch < IN_CHANNEL; ch = ch + 1) begin : sum_ch_gen
        if (KERNEL_SIZE == 3) begin : kernel3_sum
            // 3x3 kernel: add the nine products in a single expression
            assign channel_sums[ch] = mult_results[ch][0] + mult_results[ch][1] + mult_results[ch][2] +
                                      mult_results[ch][3] + mult_results[ch][4] + mult_results[ch][5] +
                                      mult_results[ch][6] + mult_results[ch][7] + mult_results[ch][8];
        end else begin : general_sum
            // Other kernel sizes: a chain of partial sums
            wire [ACC_WIDTH-1:0] partial_sums [0:KERNEL_SIZE*KERNEL_SIZE-1];
            assign partial_sums[0] = mult_results[ch][0];
            for (k_idx = 1; k_idx < KERNEL_SIZE*KERNEL_SIZE; k_idx = k_idx + 1) begin : acc_gen
                assign partial_sums[k_idx] = partial_sums[k_idx-1] + mult_results[ch][k_idx];
            end
            assign channel_sums[ch] = partial_sums[KERNEL_SIZE*KERNEL_SIZE-1];
        end
    end
endgenerate
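As written, the general_sum branch chains the K² products one after another, so its logic depth grows linearly with the kernel size. A balanced adder tree keeps the depth at ⌈log₂ N⌉ adder levels. The sketch below is one possible way to express such a tree for the general case. It pads the inputs to a power of two and stores the nodes in a heap-indexed array, where node i has children 2i+1 and 2i+2. The module adder_tree and its ports are hypothetical; this is not the author's code.

```verilog
// Hypothetical balanced combinational adder tree for N unsigned IN_W-bit inputs.
module adder_tree #(
    parameter N     = 9,
    parameter IN_W  = 16,
    parameter OUT_W = IN_W + $clog2(N)   // enough bits for the sum of N inputs
)(
    input  [N*IN_W-1:0] in_flat,         // input k at bits [k*IN_W +: IN_W]
    output [OUT_W-1:0]  sum
);
    localparam P = 1 << $clog2(N);       // number of leaves, padded to a power of two

    // Heap layout: nodes 0..P-2 are adders, nodes P-1..2P-2 are leaves.
    wire [OUT_W-1:0] node [0:2*P-2];

    genvar i;
    generate
        for (i = 0; i < P; i = i + 1) begin : leaves
            if (i < N) begin : real_leaf
                assign node[P-1+i] = in_flat[i*IN_W +: IN_W];   // zero-extended to OUT_W
            end else begin : pad_leaf
                assign node[P-1+i] = {OUT_W{1'b0}};             // padding contributes nothing
            end
        end
        for (i = 0; i < P-1; i = i + 1) begin : internal
            assign node[i] = node[2*i+1] + node[2*i+2];
        end
    endgenerate

    assign sum = node[0];                // root of the tree
endmodule
```

For N = 9, the inputs are padded to 16 leaves, which gives 4 adder levels instead of a chain of 8 adders. The padded zero adders are constant and would normally be optimized away by synthesis.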
  5. Cross-channel accumulation and output

The results of all channels are added together, saturated, and then output.

// Cross-channel accumulation (combinational)
generate
    if (IN_CHANNEL == 3) begin : channel3_sum
        assign total_sum = channel_sums[0] + channel_sums[1] + channel_sums[2];
    end else begin : general_channel_sum
        wire [ACC_WIDTH-1:0] channel_partial_sums [0:IN_CHANNEL-1];
        assign channel_partial_sums[0] = channel_sums[0];
        for (c_idx = 1; c_idx < IN_CHANNEL; c_idx = c_idx + 1) begin : ch_acc_gen
            assign channel_partial_sums[c_idx] = channel_partial_sums[c_idx-1] + channel_sums[c_idx];
        end
        assign total_sum = channel_partial_sums[IN_CHANNEL-1];
    end
endgenerate

// Output logic (combinational)
assign conv_valid = window_valid && weight_valid;
assign conv_out = conv_valid ? saturate(total_sum) : {OUTPUT_WIDTH{1'b0}};

// Saturation function (combinational), UNSIGNED
function [OUTPUT_WIDTH-1:0] saturate;
    input [ACC_WIDTH-1:0] value; // UNSIGNED
    localparam [ACC_WIDTH-1:0] MAX_UNSIGNED_VAL_SAT = (1 << OUTPUT_WIDTH) - 1;
    // MIN_UNSIGNED_VAL is 0
    begin
        if (value > MAX_UNSIGNED_VAL_SAT)
            saturate = MAX_UNSIGNED_VAL_SAT[OUTPUT_WIDTH-1:0]; // truncate to OUTPUT_WIDTH
        else
            saturate = value[OUTPUT_WIDTH-1:0];                // truncate to OUTPUT_WIDTH
    end
endfunction

endmodule
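The saturation threshold is 2^OUTPUT_WIDTH − 1. The comparison value > MAX_UNSIGNED_VAL_SAT is therefore true exactly when some bit above OUTPUT_WIDTH − 1 is set. The sketch below expresses the same saturation as an OR-reduction of those upper bits.

- sat_unsigned is a hypothetical standalone module, not part of the original design.
- It assumes ACC_WIDTH > OUTPUT_WIDTH.
- Synthesis tools may well derive equivalent logic from the original comparison anyway.

```verilog
// Hypothetical standalone version of the unsigned saturation (assumes ACC_WIDTH > OUTPUT_WIDTH).
module sat_unsigned #(
    parameter ACC_WIDTH    = 25,
    parameter OUTPUT_WIDTH = 20
)(
    input  [ACC_WIDTH-1:0]    value,
    output [OUTPUT_WIDTH-1:0] sat_out
);
    // value > 2^OUTPUT_WIDTH - 1 exactly when any bit above OUTPUT_WIDTH-1 is set
    wire overflow = |value[ACC_WIDTH-1:OUTPUT_WIDTH];
    // Clamp to all ones on overflow, otherwise pass the low bits through unchanged
    assign sat_out = overflow ? {OUTPUT_WIDTH{1'b1}} : value[OUTPUT_WIDTH-1:0];
endmodule
```

With the default widths, 270,000 passes through unchanged, and 1,048,575 is the largest value that does not saturate. 1,755,675 (the Test 8 sum) is clamped to 1,048,575.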

Testing

mult_acc_comb_tb.v

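To verify functionality, the module is run through multiple test cases and the actual outputs are compared against the expected values.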

`timescale 1ns / 1ps

module mult_acc_comb_tb;

parameter DATA_WIDTH = 8;
parameter KERNEL_SIZE = 3;
parameter IN_CHANNEL = 3;
parameter WEIGHT_WIDTH = 8;
parameter OUTPUT_WIDTH = 20;
parameter ACC_WIDTH = 2*DATA_WIDTH + 4 + $clog2(KERNEL_SIZE*KERNEL_SIZE*IN_CHANNEL);

reg window_valid;
reg [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*DATA_WIDTH-1:0] multi_channel_window_in;
reg weight_valid;
reg [IN_CHANNEL*KERNEL_SIZE*KERNEL_SIZE*WEIGHT_WIDTH-1:0] multi_channel_weight_in;

wire [OUTPUT_WIDTH-1:0] conv_out;
wire conv_valid;

localparam MAX_UNSIGNED_OUT_VAL = (1 << OUTPUT_WIDTH) - 1;
// Example: Test 2 raw sum for unsigned context
localparam EXPECTED_SUM_TEST2_UNSIGNED_RAW = 3 * 9 * 2 * 3; // 162
localparam EXPECTED_CONV_OUT_TEST2_UNSIGNED_SAT = (EXPECTED_SUM_TEST2_UNSIGNED_RAW > MAX_UNSIGNED_OUT_VAL) ? MAX_UNSIGNED_OUT_VAL : EXPECTED_SUM_TEST2_UNSIGNED_RAW;
localparam MAX_ELEMENT_VAL_TB = (1 << DATA_WIDTH) - 1;
localparam MAX_WEIGHT_ELEMENT_VAL_TB = (1 << WEIGHT_WIDTH) - 1;

mult_acc_comb #(
    .DATA_WIDTH(DATA_WIDTH),
    .KERNEL_SIZE(KERNEL_SIZE),
    .IN_CHANNEL(IN_CHANNEL),
    .WEIGHT_WIDTH(WEIGHT_WIDTH),
    .OUTPUT_WIDTH(OUTPUT_WIDTH),
    .ACC_WIDTH(ACC_WIDTH)
) dut (
    .window_valid(window_valid),
    .multi_channel_window_in(multi_channel_window_in),
    .weight_valid(weight_valid),
    .multi_channel_weight_in(multi_channel_weight_in),
    .conv_out(conv_out),
    .conv_valid(conv_valid)
);

reg all_tests_passed_flag;
integer test_id_counter;
integer num_errors;

// Task to check results and display Expected/Actual for all
task check_and_report;
    input [OUTPUT_WIDTH-1:0] expected_out_val;
    input expected_valid_val;
    // Test description is displayed before calling this task
    begin
        test_id_counter = test_id_counter + 1;
        // Always display Expected and Actual
        $display("    Expected: conv_valid=%b, conv_out=%d", expected_valid_val, expected_out_val);
        $display("    Actual:   conv_valid=%b, conv_out=%d", conv_valid, conv_out);
        if (conv_valid === expected_valid_val &&
            ((expected_valid_val === 1'b0) ? (conv_out === {OUTPUT_WIDTH{1'b0}}) : (conv_out === expected_out_val))) begin
            $display("    Test ID %0d: Status: PASSED", test_id_counter);
        end else begin
            $display("    Test ID %0d: Status: FAILED", test_id_counter);
            all_tests_passed_flag = 1'b0;
            num_errors = num_errors + 1;
        end
        $display("--------------------------------------------------");
    end
endtask

initial begin
    $display("=== Comprehensive UNSIGNED Combinational MultAcc Test (OUTPUT_WIDTH=%0d) ===", OUTPUT_WIDTH);
    all_tests_passed_flag = 1'b1;
    test_id_counter = 0;
    num_errors = 0;

    // Initialize
    window_valid = 0;
    weight_valid = 0;
    multi_channel_window_in = 0;
    multi_channel_weight_in = 0;
    #10;

    // Test 1
    $display("Test Description: Simple Positive Values (1*1, sum 27)");
    multi_channel_window_in = {27{8'd1}};
    multi_channel_weight_in = {27{8'd1}};
    window_valid = 1;
    weight_valid = 1;
    #1; check_and_report(27, 1'b1);
    #10;

    // Test 2 (162 is far below the 20-bit limit, so no saturation actually occurs here)
    $display("Test Description: Positive Values with Saturation (2*3, raw %0d, sat %0d)", EXPECTED_SUM_TEST2_UNSIGNED_RAW, EXPECTED_CONV_OUT_TEST2_UNSIGNED_SAT);
    multi_channel_window_in = {27{8'd2}};
    multi_channel_weight_in = {27{8'd3}};
    #1; check_and_report(EXPECTED_CONV_OUT_TEST2_UNSIGNED_SAT, 1'b1);
    #10;

    // Test 3
    $display("Test Description: Invalid Inputs (both valid_n low)");
    window_valid = 0;
    weight_valid = 0;
    #1; check_and_report(0, 1'b0);
    #10;

    // Test 4
    $display("Test Description: Zero Window Data, Non-zero Weights");
    window_valid = 1;
    weight_valid = 1;
    multi_channel_window_in = {27{8'd0}};
    multi_channel_weight_in = {27{8'd5}};
    #1; check_and_report(0, 1'b1);
    #10;

    // Test 5
    $display("Test Description: Non-zero Window, Zero Weight Data");
    multi_channel_window_in = {27{8'd5}};
    multi_channel_weight_in = {27{8'd0}};
    #1; check_and_report(0, 1'b1);
    #10;

    // Test 6
    $display("Test Description: All Zero Inputs");
    multi_channel_window_in = {27{8'd0}};
    multi_channel_weight_in = {27{8'd0}};
    #1; check_and_report(0, 1'b1);
    #10;

    // Test 7
    $display("Test Description: Large values (no saturation with 20-bit output)");
    multi_channel_window_in = {27{8'd5}};
    multi_channel_weight_in = {27{8'd5}};
    #1; check_and_report(27*5*5, 1'b1);  // 27*25 = 675, well within 20-bit range
    #10;

    // Test 8
    $display("Test Description: Max Val Inputs (Win=%d, Wgt=%d), should saturate to %d", MAX_ELEMENT_VAL_TB, MAX_WEIGHT_ELEMENT_VAL_TB, MAX_UNSIGNED_OUT_VAL);
    multi_channel_window_in = {27{{DATA_WIDTH{1'b1}}}};
    multi_channel_weight_in = {27{{WEIGHT_WIDTH{1'b1}}}};
    #1;
    // 27 * 255 * 255 = 1,755,675, which exceeds the 20-bit max (1,048,575), so the output should saturate
    check_and_report(MAX_UNSIGNED_OUT_VAL, 1'b1);
    #10;

    // Test 8.5: Test 20-bit range capability
    $display("Test Description: Medium values to test 20-bit range (100*100, sum 270000)");
    multi_channel_window_in = {27{8'd100}};
    multi_channel_weight_in = {27{8'd100}};
    #1; check_and_report(27*100*100, 1'b1);  // 27*10000 = 270000, well within 20-bit range
    #10;

    // Test 9: Window valid toggles
    $display("--- Test Sequence 9: Window Valid Toggles (base inputs 1*1, sum 27) ---");
    multi_channel_window_in = {27{8'd1}};
    multi_channel_weight_in = {27{8'd1}};
    weight_valid = 1;
    $display("  Sub-Test Description: WinValid=1 (Start)");
    window_valid = 1; #1; check_and_report(27, 1'b1);
    $display("  Sub-Test Description: WinValid=0");
    window_valid = 0; #1; check_and_report(0, 1'b0);
    $display("  Sub-Test Description: WinValid=1 (End)");
    window_valid = 1; #1; check_and_report(27, 1'b1);
    #10;

    // Test 10: Weight valid toggles
    $display("--- Test Sequence 10: Weight Valid Toggles (base inputs 1*1, sum 27) ---");
    window_valid = 1; // inputs are still 1s
    $display("  Sub-Test Description: WeightValid=1 (Start)");
    weight_valid = 1; #1; check_and_report(27, 1'b1);
    $display("  Sub-Test Description: WeightValid=0");
    weight_valid = 0; #1; check_and_report(0, 1'b0);
    $display("  Sub-Test Description: WeightValid=1 (End)");
    weight_valid = 1; #1; check_and_report(27, 1'b1);
    #10;

    // Final Summary
    $display("==================================================");
    if (all_tests_passed_flag) begin
        $display("FINAL STATUS: SUCCESS! All %0d UNSIGNED Combinational MultAcc tests passed!", test_id_counter);
    end else begin
        $display("FINAL STATUS: FAILED. %0d out of %0d UNSIGNED Combinational MultAcc tests did not pass.", num_errors, test_id_counter);
    end
    $display("==================================================");
    $finish;
end

endmodule
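The expected values in the testbench follow directly from the formula for uniform inputs. With every window element equal to x and every weight equal to w, the sum is 27 · x · w, where 27 = IN_CHANNEL · KERNEL_SIZE². That sum is then saturated at 2^20 − 1.

```latex
\[
\begin{aligned}
\text{Test 2: }   & 27 \cdot 2 \cdot 3 = 162\\
\text{Test 7: }   & 27 \cdot 5 \cdot 5 = 675\\
\text{Test 8.5: } & 27 \cdot 100 \cdot 100 = 270\,000\\
\text{Test 8: }   & 27 \cdot 255 \cdot 255 = 1\,755\,675 > 2^{20} - 1 = 1\,048\,575
                    \;\Rightarrow\; \texttt{conv\_out} = 1\,048\,575
\end{aligned}
\]
```

In these tests, every element of each input vector is identical. They would therefore still pass if the element order of the window or weight bus were reversed. A test with distinct elements, such as the 3×3 example at the top of the article (expected result 42), is what would catch an ordering mistake.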

Results

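The window module delivers a new window every clock cycle, so the convolution is implemented in combinational logic. When both inputs are valid, i.e. window_valid and weight_valid are both high, mult_acc_comb performs the computation and drives conv_valid high, as shown below: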

[Figure: simulation waveform of window_valid, weight_valid, conv_valid, and conv_out]

Printed output:

=== Comprehensive UNSIGNED Combinational MultAcc Test (OUTPUT_WIDTH=20) ===
Test Description: Simple Positive Values (1*1, sum 27)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27

Test ID 1: Status: PASSED

Test Description: Positive Values with Saturation (2*3, raw 162, sat 162)
Expected: conv_valid=1, conv_out= 162
Actual: conv_valid=1, conv_out= 162

Test ID 2: Status: PASSED

Test Description: Invalid Inputs (both valid_n low)
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0

Test ID 3: Status: PASSED

Test Description: Zero Window Data, Non-zero Weights
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0

Test ID 4: Status: PASSED

Test Description: Non-zero Window, Zero Weight Data
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0

Test ID 5: Status: PASSED

Test Description: All Zero Inputs
Expected: conv_valid=1, conv_out= 0
Actual: conv_valid=1, conv_out= 0

Test ID 6: Status: PASSED

Test Description: Large values (no saturation with 20-bit output)
Expected: conv_valid=1, conv_out= 675
Actual: conv_valid=1, conv_out= 675

Test ID 7: Status: PASSED

Test Description: Max Val Inputs (Win= 255, Wgt= 255), should saturate to 1048575
Expected: conv_valid=1, conv_out=1048575
Actual: conv_valid=1, conv_out=1048575

Test ID 8: Status: PASSED

Test Description: Medium values to test 20-bit range (100*100, sum 270000)
Expected: conv_valid=1, conv_out= 270000
Actual: conv_valid=1, conv_out= 270000

Test ID 9: Status: PASSED

--- Test Sequence 9: Window Valid Toggles (base inputs 1*1, sum 27) ---
Sub-Test Description: WinValid=1 (Start)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27

Test ID 10: Status: PASSED

Sub-Test Description: WinValid=0
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0

Test ID 11: Status: PASSED

Sub-Test Description: WinValid=1 (End)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27

Test ID 12: Status: PASSED

--- Test Sequence 10: Weight Valid Toggles (base inputs 1*1, sum 27) ---
Sub-Test Description: WeightValid=1 (Start)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27

Test ID 13: Status: PASSED

Sub-Test Description: WeightValid=0
Expected: conv_valid=0, conv_out= 0
Actual: conv_valid=0, conv_out= 0

Test ID 14: Status: PASSED

Sub-Test Description: WeightValid=1 (End)
Expected: conv_valid=1, conv_out= 27
Actual: conv_valid=1, conv_out= 27

Test ID 15: Status: PASSED
