
EGPO: KeyError: "replay_sequence_length" when running train_egpo training

When training EGPO via train_egpo on ray 1.4.1, the execution_plan below raises KeyError: "replay_sequence_length". The workaround is to comment out the replay_sequence_length argument passed to LocalReplayBuffer (see the marked line in the code), probably because ray 1.4.1 does not have this parameter.

# Imports as used by ray 1.4.x RLlib (module paths may differ in other ray
# versions); UpdateSaverPenalty comes from the EGPO code base itself.
from ray.rllib.agents.dqn.dqn import calculate_rr_weights
from ray.rllib.evaluation.worker_set import WorkerSet
from ray.rllib.execution.concurrency_ops import Concurrently
from ray.rllib.execution.metric_ops import StandardMetricsReporting
from ray.rllib.execution.replay_buffer import LocalReplayBuffer
from ray.rllib.execution.replay_ops import Replay, StoreToReplayBuffer
from ray.rllib.execution.rollout_ops import ParallelRollouts
from ray.rllib.execution.train_ops import TrainOneStep, UpdateTargetNetwork
from ray.rllib.policy.policy import LEARNER_STATS_KEY
from ray.rllib.utils.typing import TrainerConfigDict
from ray.util.iter import LocalIterator


def execution_plan(workers: WorkerSet,
                   config: TrainerConfigDict) -> LocalIterator[dict]:
    if config.get("prioritized_replay"):
        prio_args = {
            "prioritized_replay_alpha": config["prioritized_replay_alpha"],
            "prioritized_replay_beta": config["prioritized_replay_beta"],
            "prioritized_replay_eps": config["prioritized_replay_eps"],
        }
    else:
        prio_args = {}

    local_replay_buffer = LocalReplayBuffer(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        # This line has to be commented out, otherwise the code does not run
        # at all -- probably because ray 1.4.1 does not have this parameter.
        # replay_sequence_length=config["replay_sequence_length"],
        **prio_args)

    rollouts = ParallelRollouts(workers, mode="bulk_sync")

    # Update penalty.
    rollouts = rollouts.for_each(UpdateSaverPenalty(workers))

    # We execute the following steps concurrently:
    # (1) Generate rollouts and store them in our local replay buffer.
    # Calling next() on store_op drives this.
    store_op = rollouts.for_each(
        StoreToReplayBuffer(local_buffer=local_replay_buffer))

    def update_prio(item):
        samples, info_dict = item
        if config.get("prioritized_replay"):
            prio_dict = {}
            for policy_id, info in info_dict.items():
                # TODO(sven): This is currently structured differently for
                #  torch/tf. Clean up these results/info dicts across
                #  policies (note: fixing this in torch_policy.py will
                #  break e.g. DDPPO!).
                td_error = info.get(
                    "td_error", info[LEARNER_STATS_KEY].get("td_error"))
                prio_dict[policy_id] = (
                    samples.policy_batches[policy_id].data.get(
                        "batch_indexes"), td_error)
            local_replay_buffer.update_priorities(prio_dict)
        return info_dict

    # (2) Read and train on experiences from the replay buffer. Every batch
    # returned from the LocalReplay() iterator is passed to TrainOneStep to
    # take an SGD step, and then we decide whether to update the target
    # network.
    post_fn = config.get("before_learn_on_batch") or (lambda b, *a: b)
    replay_op = Replay(local_buffer=local_replay_buffer) \
        .for_each(lambda x: post_fn(x, workers, config)) \
        .for_each(TrainOneStep(workers)) \
        .for_each(update_prio) \
        .for_each(UpdateTargetNetwork(
            workers, config["target_network_update_freq"]))

    # Alternate deterministically between (1) and (2). Only return the output
    # of (2) since training metrics are not available until (2) runs.
    train_op = Concurrently(
        [store_op, replay_op],
        mode="round_robin",
        output_indexes=[1],
        round_robin_weights=calculate_rr_weights(config))

    return StandardMetricsReporting(train_op, workers, config)
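If you would rather keep the replay_sequence_length option usable on newer ray versions instead of deleting the line outright, a version-tolerant alternative is to pass the argument only when the installed LocalReplayBuffer actually accepts it. The following is a minimal sketch under the assumption that the ray 1.4.x import path for LocalReplayBuffer applies; make_replay_buffer is a hypothetical helper, not part of EGPO or RLlib:

import inspect

from ray.rllib.execution.replay_buffer import LocalReplayBuffer


def make_replay_buffer(config, prio_args):
    # Arguments that LocalReplayBuffer accepts in ray 1.4.1.
    buffer_kwargs = dict(
        num_shards=1,
        learning_starts=config["learning_starts"],
        buffer_size=config["buffer_size"],
        replay_batch_size=config["train_batch_size"],
        replay_mode=config["multiagent"]["replay_mode"],
        **prio_args)

    # Only pass replay_sequence_length when this ray version's constructor
    # knows the parameter and the config actually provides it; this avoids
    # both the KeyError and an unexpected-keyword TypeError.
    ctor_params = inspect.signature(LocalReplayBuffer.__init__).parameters
    if ("replay_sequence_length" in ctor_params
            and "replay_sequence_length" in config):
        buffer_kwargs["replay_sequence_length"] = \
            config["replay_sequence_length"]

    return LocalReplayBuffer(**buffer_kwargs)

Inside execution_plan, the direct LocalReplayBuffer(...) call above would then be replaced by local_replay_buffer = make_replay_buffer(config, prio_args).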
