脑机新手指南(九):高性能脑文本通信:手写方式实现(上)
一、引言
在当今科技飞速发展的时代,脑机接口技术一直是备受关注的前沿领域。最近的一项研究成果为我们带来了新的惊喜 —— 通过手写实现高性能的脑文本通信。今天,我们就来深入了解一下这个令人兴奋的研究。
二、项目概述
该项目相关代码库可在对应的 GitHub 仓库中找到,它与一篇发表在《Nature》杂志上的论文相关,论文链接为High-performance brain-to-text communication via handwriting | Nature。项目的整体思路是将脑电信号转化为文本信息,并且每一步的结果都会保存到磁盘中,以便后续步骤使用。同时,中间结果和模型可以从Dryad | Data -- High-performance brain-to-text communication via handwriting下载,这样我们就可以探索某些步骤而无需运行所有先前的步骤(不过步骤 3 需要自己运行,因为它会产生约 100GB 的文件)。
三、主要结果
原作者已经进行了一次完整的代码运行,主要结果可以在Dryad | Data -- High-performance brain-to-text communication via handwriting查看。这些结果来自训练和测试分区('HeldOutTrials' 和 'HeldOutBlocks'),并且是使用SummarizeRNNPerformance.ipynb笔记本生成的。每个结果都报告了 95% 的置信区间。
四、依赖环境
要运行这个项目,我们需要安装一些必要的依赖库,以下是详细的依赖信息:
Python:版本要求python>=3.6。Python 是一种广泛使用的高级编程语言,具有丰富的库和工具,非常适合进行数据分析和机器学习任务。
TensorFlow:版本要求tensorflow=1.15。TensorFlow 是一个开源的机器学习框架,提供了丰富的工具和算法,用于构建和训练深度学习模型。
NumPy:测试版本为1.17。NumPy 是 Python 的一个重要科学计算库,提供了高效的多维数组对象和各种数学函数。
SciPy:测试版本为1.1.0。SciPy 是基于 NumPy 的科学计算库,提供了许多用于优化、积分、插值等的算法。
Scikit-learn:测试版本为0.20。Scikit-learn 是一个简单而高效的机器学习库,提供了各种分类、回归、聚类等算法。
五、核心代码解读
在代码库中,rnnEval.py文件包含了一些核心的函数,用于评估 RNN 的输出结果。下面我们来详细解读这些函数。
evaluateRNNOutput函数
def evaluateRNNOutput(rnnOutput, numBinsPerSentence, trueText, charDef, charStartThresh=0.3, charStartDelay=15): """ Converts the rnn output (character probabilities & a character start signal) into a discrete sentence and computes char/word error rates. Returns error counts and the decoded sentences. """ lgit = rnnOutput[:,:,0:-1] charStart = rnnOutput[:,:,-1] #convert output to character strings decStr = decodeCharStr(lgit, charStart, charStartThresh, charStartDelay, numBinsPerSentence, charDef['charListAbbr']) allErrCounts = {} allErrCounts['charCounts'] = np.zeros([len(trueText)]) allErrCounts['charErrors'] = np.zeros([len(trueText)]) allErrCounts['wordCounts'] = np.zeros([len(trueText)]) allErrCounts['wordErrors'] = np.zeros([len(trueText)]) allDecSentences = [] #compute error rates for t in range(len(trueText)): thisTrueText = trueText[t,0][0] thisTrueText = thisTrueText.replace(' ','') thisTrueText = thisTrueText.replace('>',' ') thisTrueText = thisTrueText.replace('~','.') thisTrueText = thisTrueText.replace('#','') thisDec = decStr[t] thisDec = thisDec.replace('>',' ') thisDec = thisDec.replace('~','.') nCharErrors = wer(list(thisTrueText), list(thisDec)) nWordErrors = wer(thisTrueText.strip().split(), thisDec.strip().split()) allErrCounts['charCounts'][t] = len(thisTrueText) allErrCounts['charErrors'][t] = nCharErrors allErrCounts['wordCounts'][t] = len(thisTrueText.strip().split()) allErrCounts['wordErrors'][t] = nWordErrors allDecSentences.append(thisDec) return allErrCounts, allDecSentences
这个函数的主要功能是将 RNN 的输出(字符概率和字符开始信号)转换为离散的句子,并计算字符和单词的错误率。它首先调用decodeCharStr函数将输出转换为字符字符串,然后遍历每个真实文本和对应的解码文本,计算字符和单词的错误数量。最后返回错误计数和解码后的句子。