返回

GNN+PPO: 图神经网络与强化学习控制足式机器人

Ai

GNN + PPO:让足式机器人在 Gymnasium Ant-v4 环境中奔跑

我一直在尝试用图神经网络 (GNN) 和强化学习 (RL) 解决足式机器人的控制问题,主要参考了 NerveNet 这篇论文的方法。我用的环境是 Gymnasium 的 Ant-v4,打算把这个环境建模成一个图,图上的每个节点代表机器人结构中的一个关节。

我选了 Proximal Policy Optimization (PPO) 算法来优化强化学习过程。尽管搞了好几个星期,我还是没能把这些东西很好地融合在一起。主要问题似乎是如何有效地将基于图的模型与强化学习算法结合起来,以提升学习效果。

一、 问题分析:为什么GNN和PPO结合困难?

要把 GNN 和 PPO 结合起来,主要难点有这么几个:

  1. 状态表示的转换: Ant-v4 环境原本的状态是一个连续的向量,而 GNN 需要的是图结构的数据。怎么把原始状态转换成图,并且这个图还能准确反映机器人的状态和环境信息,是个问题。
  2. 网络结构的设计: 要设计一个合适的 GNN 架构,它既能处理图数据,又能输出 PPO 算法需要的动作策略和价值估计。
  3. 梯度回传问题: GNN 处理的 graph 数据, 需要定义如何在GNN上对每个 node 的特征做回传。
  4. 算法的兼容性: PPO 算法和 GNN 的训练方式可能不一样,怎么让它们在一个框架下协同工作,需要仔细考虑。

二、 解决方案:GNN和PPO结合的具体步骤

针对上面的问题,我整理了以下几个步骤和建议,希望能帮到同样遇到困难的朋友们:

1. 将 Ant-v4 环境建模成图

把 Ant-v4 这种物理环境建模成图,核心是找到节点和边的对应关系。我的思路是:

  • 节点: 机器人的每个关节(hinge)作为图中的一个节点。
  • 边: 连接关节的身体部位(比如大腿、小腿)作为边。

这样做的好处是,节点和边的关系直接反映了机器人的物理结构。

代码示例:

import torch as th
from torch_geometric.data import Data
import torch.nn.functional as F

# 预先定义好 Ant 机器人的边连接关系 (根据 Ant-v4 的结构)
edge_index = th.tensor([[0, 1], [0, 3], [0, 5], [0, 7], [1, 2], [3, 4], [5, 6], [7, 8]], dtype=th.long).t().contiguous()

def model_ant_as_graph(state: th.Tensor) -> Data:
    """
    将 Ant-v4 的状态向量转换为图数据。
    
    Args:
        state:  Ant-v4 环境的原始状态 (一维向量, 例如 shape: (27,))。
        
    Returns:
        一个 torch_geometric.data.Data 对象, 表示转换后的图。
    """
    # 如果是批量观测,取第一个
    if state.dim() == 2:
        state = state[0]

    # 根据 Ant-v4 环境文档,确定每个关节和速度数据在状态向量中的位置。
    node_indices = {
        'torso': slice(0, 5),       # 躯干
        'hip_1': slice(5, 6),    'ankle_1': slice(6, 7),  # 第一条腿
        'hip_2': slice(7, 8),    'ankle_2': slice(8, 9),  # 第二条腿
        'hip_3': slice(9, 10),   'ankle_3': slice(10, 11), # 第三条腿
        'hip_4': slice(11, 12),  'ankle_4': slice(12, 13), # 第四条腿
    }
    velocity_indices = {
        'torso': slice(13, 19),
        'hip_1': slice(19, 20), 'ankle_1': slice(20, 21),
        'hip_2': slice(21, 22), 'ankle_2': slice(22, 23),
        'hip_3': slice(23, 24), 'ankle_3': slice(24, 25),
        'hip_4': slice(25, 26), 'ankle_4': slice(26, 27),
    }

    node_features = []
    for key in node_indices.keys():
        # 把关节的位置、角度和速度信息组合起来, 作为节点的特征
        combined_features = th.cat((state[node_indices[key]], state[velocity_indices[key]]))
        # 将特征进行一个Padding。
        padded_features = F.pad(combined_features, (0, 11 - combined_features.size(0)), "constant", 0)
        
        node_features.append(padded_features)

    # 将节点特征合并成一个张量
    node_features_tensor = th.stack(node_features).float()
    # 创建图数据对象
    graph = Data(x=node_features_tensor, edge_index=edge_index, num_nodes=9)
    return graph

代码解释:

  • edge_index 定义了图的连接关系,这里是根据 Ant 机器人的身体结构预先定义好的。
  • model_ant_as_graph 函数接收 Ant-v4 环境的原始状态向量,将其转换为图数据。
  • 根据官方环境的信息, 节点特征包含了关节的位置,和对应的关节速度。
  • torch_geometric.data.Data 对象是 PyTorch Geometric 中表示图数据的标准方式。

2. 设计 GNN 网络结构

有了图数据后,接下来要设计一个 GNN 网络。我用的是 GATv2Conv,因为它在处理图数据时效果不错。

代码示例:

from torch_geometric.nn import GATv2Conv
from torch import nn

class GATv2ConvWrapper(nn.Module):
    """
    一个简单的包装器,封装 GATv2Conv 层, 并固定 edge_index。
    """
    def __init__(self, in_channels, out_channels, edge_index):
        super(GATv2ConvWrapper, self).__init__()
        self.conv = GATv2Conv(in_channels, out_channels)
        self.edge_index = edge_index  # 保存 edge_index

    def forward(self, x):
        return self.conv(x, self.edge_index) # 用固定的 edge_index

class CustomNetwork(nn.Module):
    """
    自定义的 GNN 网络,用于策略网络和价值网络。
    """
    def __init__(
        self,
        feature_dim: int = 11,  #每个节点有11个特征, 不固定长度则会产生问题
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super().__init__()

        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # 策略网络 (policy network)
        self.policy_net = nn.Sequential(
            GATv2ConvWrapper(feature_dim, last_layer_dim_pi, edge_index),
            nn.BatchNorm1d(64),
            GATv2ConvWrapper(last_layer_dim_pi, last_layer_dim_pi, edge_index),
            nn.BatchNorm1d(64),
        )

        # 价值网络 (value network)
        self.value_net = nn.Sequential(
            GATv2ConvWrapper(feature_dim, last_layer_dim_pi, edge_index),
            nn.BatchNorm1d(64),
            GATv2ConvWrapper(last_layer_dim_pi, last_layer_dim_pi, edge_index),
            nn.BatchNorm1d(64),
        )
        # 动作的 Log Probs
        self.log_std = nn.Parameter(th.zeros(1, 8))

    def forward(self, features):

        actions = self.forward_actor(features).mean(dim=1)
        values =  self.forward_critic(features).mean(dim=1)
        return actions, values

    def forward_actor(self, features: Data):
        """
        策略网络的前向传播。
        """
        x, edge_index = features.x, features.edge_index # 获得data里面的节点和索引。
        return self.policy_net(x)

    def forward_critic(self, features: Data):
        """
        价值网络的前向传播。
        """
        x, edge_index = features.x, features.edge_index # 获得data里面的节点和索引。
        return self.value_net(x)

代码解释:

  • GATv2ConvWrapper 包装了 GATv2Conv 层,并固定了 edge_index,这样在网络中就不用每次都传入 edge_index 了。
  • CustomNetwork 定义了整个 GNN 网络结构,包括策略网络和价值网络。
  • 这里用了两个 GATv2Conv 层,每层后面加了 BatchNorm1d 来加速训练。

3. 与 PPO 算法集成

最关键的一步是把 GNN 网络和 PPO 算法集成起来。我用的是 stable-baselines3 这个库,它提供了 PPO 算法的实现。

代码示例:

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy
from gymnasium import spaces
from typing import Callable

class CustomActorCriticPolicy(ActorCriticPolicy):
    """
    自定义策略类,将 GNN 集成到 Actor-Critic 框架中。
    """
    def __init__(
        self,
        observation_space: spaces.Space,
        action_space: spaces.Space,
        lr_schedule: Callable[[float], float],
        *args,
        **kwargs,
    ):
        #  正交初始化会让算法更难学习到有效内容, 我们不需要
        kwargs["ortho_init"] = False
        super().__init__(
            observation_space,
            action_space,
            lr_schedule,
            *args,
            **kwargs,
        )

    def forward(self, obs):
        #  把观察转换为 graph 数据
        graph_data = model_ant_as_graph(obs)
        actions, values = self.mlp_extractor(graph_data)
        
        return actions, values

    def _build_mlp_extractor(self) -> None:
        #  构建 GNN 网络
        self.mlp_extractor = CustomNetwork(11)

# 创建 Ant-v4 环境
env = gym.make("Ant-v4")

# 创建 PPO 模型,使用自定义的策略网络
model = PPO(CustomActorCriticPolicy, env, verbose=1)

# 打印策略网络的结构,确认一下
print(model.policy)

代码解释:

  • CustomActorCriticPolicy 继承了 stable-baselines3ActorCriticPolicy 类,这是自定义策略网络的关键。
  • _build_mlp_extractor 方法中,创建了我们设计的 CustomNetwork
  • forward 方法中,先把原始观察 obs 转换成图数据 graph_data,然后传给 GNN 网络,得到动作 actions 和价值估计 values

4. 安全建议(可选项)

在使用强化学习和 GNN 时, 有一些问题是通用的:

  • 超参数调整: PPO 算法和 GNN 都有很多超参数,要仔细调整才能达到好的效果。
  • 探索和利用: 强化学习需要在探索(尝试新的动作)和利用(选择已知的最佳动作)之间找到平衡。
  • 环境的复杂性: Ant-v4 是一个比较复杂的环境,可能需要更长的训练时间。

5. 进阶使用方法(可选项)

除了以上介绍的步骤外,对于高级使用者, 还有以下提升的技巧:

  • Edge Features: 为 edge 也加入 features。 例如定义前后,左右腿等 features,以加快学习。
  • Temporal-GNNs : 将循环网络加入 GNN 中, 学习更高层面的控制关系。
  • Meta Learning + RL : 使用类似 MAML 的方法,将 GNN 作为其中的一个组件进行元学习。

三、 总结

将GNN 与 PPO 结合是一个较新的方法, 这个方法最大的优势就是直接把机器人的结构信息编码到神经网络中。 这对于需要精细控制的机器人运动来说是有用的。当然还有很多可以探索的地方。如果大家有什么好的想法或者实践经验,欢迎一起交流。