当前位置: 首页 > ds >正文

循环神经网络RNN(示例代码LSTM预测股价示例)

1 应用场景:

AI翻译,AI对话,股票预测,行为判断

2 背景介绍

3 代码:

import pandas as pd
import numpy as np
import baostock as bs#下载数据
def get_stock_data_and_save_to_csv(stock_code, start_date, end_date, file_path):# 登陆系统lg = bs.login()if lg.error_code != '0':print(f"登录失败,错误代码: {lg.error_code},错误信息: {lg.error_msg}")return# 获取股票数据rs = bs.query_history_k_data_plus(stock_code,"date,code,open,high,low,close,preclose,volume,amount,adjustflag,turn,tradestatus,pctChg,peTTM,pbMRQ,psTTM,pcfNcfTTM,isST",start_date=start_date, end_date=end_date,frequency="d", adjustflag="3")if rs.error_code != '0':print(f"获取数据失败,错误代码: {rs.error_code},错误信息: {rs.error_msg}")bs.logout()return# 打印结果集data_list = []while (rs.error_code == '0') & rs.next():# 获取一条记录,将记录合并在一起data_list.append(rs.get_row_data())result = pd.DataFrame(data_list, columns=rs.fields)try:# 保存到CSV文件result.to_csv(file_path, index=False)print(f"数据已成功保存到 {file_path}")except Exception as e:print(f"保存数据到文件时出错: {e}")# 登出系统bs.logout()if __name__ == "__main__":stock_code = "sh.603236"start_date = "2024-08-22"end_date = "2025-08-22"file_path = "stock_data_603236_short.csv"get_stock_data_and_save_to_csv(stock_code, start_date, end_date, file_path)
%matplotlib inline
from matplotlib import pyplot as plt
fig1 = plt.figure(figsize=(8,5))
plt.plot(price)
plt.title('close price')
plt.xlabel('time')
plt.ylabel('price')
plt.show()

# define X and y 
# define method to extract X and y
def extract_data(data, time_step):X=[]y=[]for i in range(len(data)-time_step):X.append([a for a in data[i:i+time_step]])y.append(data[i+time_step])X = np.array(X)X = X.reshape(X.shape[0], X.shape[1], 1)return X, y#define X and y
time_step = 8
X,y = extract_data(price_norm, time_step)
print(X[0,:,:])
print(y)
print(X.shape, len(y))
#set up the model
from keras.models import Sequential
from keras.layers import Dense, SimpleRNN
model = Sequential()
#add RNN layer
model.add(SimpleRNN(units=5,input_shape=(time_step,1), activation='relu'))
#add output layer
model.add(Dense(units=1,activation='linear'))
#configure the model
model.compile(optimizer='adam', loss = 'mean_squared_error')
model.summary()
#train the model
y = np.array(y)
model.fit(X,y,batch_size=30, epochs=200)
#make prediction base on the traing data
y_train_predict = model.predict(X)*max(price)
y_train = y*max(price)
print(y_train_predict, y_train)
%matplotlib inline
from matplotlib import pyplot as plt
fig2 = plt.figure(figsize=(8,5))
plt.plot(y_train_predict, label = 'predict price',marker='*', markersize=5,markerfacecolor='none')
plt.plot(y_train, label = 'real price',marker='*', markersize=5,markerfacecolor='none')
plt.title('close price')
plt.xlabel('time')
plt.ylabel('price')
plt.legend()
plt.show()

http://www.xdnf.cn/news/1876.html

相关文章:

  • 【硬核干货】SonarQube安全功能
  • 上篇:深入剖析 BLE 底层物理层与链路层(约5000字)
  • FreeRTOS【2】任务、优先级知识重点
  • 【C语言】C语言结构体:从基础到高级特性
  • 深入解析 doas:有望替代 sudo 的极简权限管理工具
  • Dify快速入门之发布应用
  • Trae 编程工具 Cline 插件安装与 Claude 3.7 API Key 自定义配置详解
  • 修改RK3568 UBUNTU开机画面
  • C++ Lambda 表达式
  • 黑马点评商户查询缓存--缓存更新策略
  • shell练习(2)
  • github 简单访问方法(无魔法)
  • 数据库-数据类型、约束 和 DQL语言
  • QComboBox自适应下拉展开区域宽度但控件本身限制宽度
  • leetcode刷题日记——有效的括号
  • IOMUXC_SetPinMux的0,1参数解释
  • Java集合框架解析
  • 【TS入门笔记1---初识TS】
  • A*迷宫寻路
  • 【频谱分析仪与信号分析仪】异同比较
  • 【力扣刷题|第五天作业】二分查找-寻找旋转排序数组中的最小值 II
  • Redis Bitmaps
  • 网络编程!
  • Android 16强制横竖屏设置
  • SQL进阶知识:七、数据库设计
  • 每日英语:每周背10句
  • PyQt6实例_pyqtgraph散点图显示工具_代码分享
  • AI大模型从0到1记录学习 数据结构和算法 day20
  • 分片算法详解:原理、类型与实现方案
  • 链表-两两交换链表中的结点