当前位置: 首页 > ai >正文

Kaggle-Bag of Words Meets Bags of Popcorn-(二分类+NLP+Bert模型)

Bag of Words Meets Bags of Popcorn

题意:

有很多条电影评论记录,问你每一条记录是积极性的评论还是消极性的评论。

数据处理:

1.首先这是文件是zip形式,要先解压,注意sep = ‘\t’。
2.加载预训练的 BERT 分词器 bert-base-uncased,用于将文本转换为模型可接受的输入格式。
3.将 pandas DataFrame 转换为 Hugging Face 的 Dataset 对象,便于与 Transformers 库集成。
4.定义 data_tokenize 函数,使用 BERT 分词器对文本进行分词、填充和截断,最大长度为 512。
5.对训练集、验证集和测试集应用分词函数。
6.将数据集格式设置为 PyTorch 张量,并指定需要使用的列。
7.加载预训练的 BERT 模型 bert-base-uncased,并修改最后一层以适配二分类任务(num_labels=2)。

建立模型:

1.训练参数配置:定义训练参数,包括评估策略、保存策略、学习率、批次大小、训练轮数、权重衰减、日志记录步长、最佳模型加载和保存限制。
2.训练器初始化:定义 Trainer 对象,指定模型、训练参数、训练集、验证集、评估指标和分词器。
3.调用 trainer.train() 开始训练模型。评估模型在验证集上的性能,并打印验证准确率。

代码:
!pip install transformers datasets scikit-learn pandas torch
import os
os.environ["WANDB_DISABLED"] = "true"	
//注意这里要关掉实验跟踪功能,因为本题用了Hugging Face Transformers框架,该框架默认与WANDB集成。import sys
import pandas as pd
import numpy as np
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
from datasets import Dataset
from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
import zipfiledef data_tokenize(data):return tokenizer(data['review'],padding='max_length',truncation=True,max_length=512)def compute_metrics(pred):labels = pred.label_idspreds = pred.predictions.argmax(-1)return {"accuracy": accuracy_score(labels, preds)}if __name__ == '__main__':input_dir = "/kaggle/input/word2vec-nlp-tutorial/"with zipfile.ZipFile(f"{input_dir}labeledTrainData.tsv.zip", "r") as zip_ref:zip_ref.extractall(".")with zipfile.ZipFile(f"{input_dir}testData.tsv.zip", "r") as zip_ref:zip_ref.extractall(".")data_train_pre = pd.read_csv('labeledTrainData.tsv',sep='\t')data_test_pre = pd.read_csv('testData.tsv',sep='\t')X = data_train_pre['review']Y = data_train_pre['sentiment']X = X[:30000]Y = Y[:30000]X_train,X_val,Y_train,Y_val = train_test_split(X,Y,test_size=0.2,random_state=42)tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')data_train = Dataset.from_dict({'review':X_train.tolist(),'label':Y_train.tolist()})data_val = Dataset.from_dict({'review':X_val.tolist(),'label':Y_val.tolist()})data_test = Dataset.from_dict({'review': data_test_pre['review'].tolist()})data_train = data_train.map(data_tokenize,batched=True)data_val = data_val.map(data_tokenize,batched=True)data_test = data_test.map(data_tokenize,batched=True)data_train.set_format(type='torch',columns=['input_ids','attention_mask','label'])data_val.set_format(type='torch',columns=['input_ids','attention_mask','label'])data_test.set_format(type='torch',columns=['input_ids','attention_mask'])bert_model = BertForSequenceClassification.from_pretrained('bert-base-uncased',num_labels=2)training_args = TrainingArguments(eval_strategy="epoch",save_strategy="epoch",learning_rate=3e-5,per_device_train_batch_size=32,per_device_eval_batch_size=32,num_train_epochs=5,weight_decay=0.01,logging_steps=10,load_best_model_at_end=True,metric_for_best_model="accuracy",save_total_limit=2,fp16=True,)trainer = Trainer(model=bert_model,args=training_args,train_dataset=data_train,eval_dataset=data_val,compute_metrics=compute_metrics,tokenizer=tokenizer,)trainer.train()print("val accuracy: ",trainer.evaluate()['eval_accuracy'])Submission = pd.DataFrame({'id': data_test_pre['id'],'sentiment': trainer.predict(data_test).predictions.argmax(-1)})Submission.to_csv("/kaggle/working/submission.csv", index=False)
http://www.xdnf.cn/news/180.html

相关文章:

  • Mac 选择下载安装工具 x86 还是 arm64 ?
  • gl-matrix 库简介
  • 【java 13天进阶Day06】Map集合,HashMapTreeMap,斗地主、图书管理系统,排序算法
  • 实验2:turtle 库绘制进阶图形
  • Linux服务器配置Anaconda环境、Pytorch库(图文并茂的教程)
  • java基础从入门到上手(九):Java - List、Set、Map
  • 每天学一个 Linux 命令(20):find
  • 23种设计模式-创建型模式之抽象工厂模式(Java版本)
  • 【含文档+PPT+源码】基于Python的股票数据可视化及推荐系统的设计与实现
  • Oracle 11g通过dg4odbc配置dblink连接PostgreSQL
  • 从头学 | 目标函数、梯度下降相关知识笔记(一)
  • 边缘计算网关组态功能的定义
  • 阀门轴承电动车工件一键精修软件
  • vue2.6.12 安装babel 以使用 可选链 ?. 和空值合并 ??
  • 【Vue3代理机制详解:从原理到实践】
  • 医疗行业如何构建合成数据平台?——技术、合规与实践全景
  • Jenkins的使用及Pipeline语法讲解
  • 简易 Python 爬虫实现,10min可完成带效果源码
  • LIB-ZC, 一个跨平台(Linux)平台通用C/C++扩展库, 网络socket
  • Linux和Ubuntu的驱动适配情况
  • 数据结构-Map和Set
  • Oracle日志系统之附加日志
  • 学习海康VisionMaster之中线查找
  • 新手蓝桥杯冲击国一练习题单(四)
  • C++ 二叉搜索树
  • LINUX418 加载YUM源 wireshark ping程序 解析
  • 亚远景-ASPICE评估标准与车企供应商准入要求的关联性
  • 串口通信实战:从寄存器操作到数据处理的完全指南
  • 人像面部关键点检测
  • 力扣刷题Day 20:柱状图中最大的矩形(84)