
github.com/pytorch/rl @v0.13.2

14,391 symbols · 79,404 edges · 678 files · 4,625 documented (32%)
README


TorchRL


TorchRL is a PyTorch-native toolkit for reinforcement learning, decision making, robotics, and simulation. It is not a single algorithm implementation or a narrow benchmark suite: it is a collection of composable pieces for building RL systems while keeping the code close to the PyTorch programming model. Recent work has made this especially strong for recurrent RL, MuJoCo-based control, multi-agent training, replay-buffer and collector infrastructure, and reusable loss/value-estimation components.

The library is built around three ideas:

  1. Data should have names, structure, batch dimensions, and devices all the way through the training loop.
  2. Environments, policies, replay buffers, objectives, and collectors should be independent modules that can be swapped without rewriting the rest of the stack.
  3. Research code should scale from a local prototype to vectorized, multiprocess, distributed, compiled, recurrent, multi-agent, model-based, or offline workflows without changing the data model.

That common data model is TensorDict, a dictionary-like tensor container with PyTorch operations, device transfers, shared-memory support, memmaps, lazy views, and nn.Module wrappers.

Getting started | API reference | Tutorials | Knowledge base | Examples | SOTA implementations

Recent highlights

TorchRL 0.13 and the preceding development cycle bring several user-visible improvements that are worth surfacing up front:

  • faster recurrent RL paths, including scan and Triton GRU/LSTM reset handling;
  • custom MuJoCo environments, satellite examples, and macro-control policies;
  • stronger multi-agent coverage through MAPPO, IPPO, MultiAgentGAE, value-normalization utilities, and mixer configs;
  • better collector and replay-buffer ergonomics, including async prioritized writes, ordered storage access, compact observations, HER, and optional CUDA wheels for the prioritized replay-buffer kernels;
  • new transforms and value-estimator improvements such as ActionScaling, FlattenAction, NextObservationDelta, compact shifted estimators, and chunked forwards.
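Reset handling in recurrent rollouts comes down to zeroing the carried hidden state wherever a new episode starts, so no information leaks across episode boundaries. The sketch below illustrates only that masking logic in plain Python, with a hypothetical scalar RNN cell and a per-step `is_init` flag; it is not TorchRL's scan or Triton implementation.

```python
import math


def rnn_cell(h: float, x: float, w_h: float = 0.5, w_x: float = 1.0) -> float:
    """Hypothetical scalar RNN cell: h' = tanh(w_h * h + w_x * x)."""
    return math.tanh(w_h * h + w_x * x)


def scan_with_resets(xs, is_init, h0=0.0):
    """Run the cell over time, zeroing the carried state at episode starts.

    xs[t] is the input at step t; is_init[t] is True when step t is the
    first step of a new episode.
    """
    h = h0
    hidden = []
    for x, first in zip(xs, is_init):
        if first:
            h = 0.0  # drop the state carried over from the previous episode
        h = rnn_cell(h, x)
        hidden.append(h)
    return hidden


xs = [1.0, 1.0, 1.0, 1.0]
is_init = [True, False, True, False]  # two episodes of length 2
hidden = scan_with_resets(xs, is_init)
```

Because both episodes see the same inputs from a zeroed state, their hidden trajectories are identical; without the reset, step 2 would inherit state from step 1.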

A quick mental model

TorchRL represents an RL interaction as a TensorDict that moves through a small number of reusable components:

TensorDict
  -> policy module writes actions and log-probs
  -> environment reads actions and writes next observations, rewards, done flags
  -> collector batches trajectories from one or many workers
  -> replay buffer stores, samples, prioritizes, and transforms data
  -> loss module reads named keys and writes differentiable losses
  -> optimizer updates ordinary PyTorch parameters

The same object can carry observations, pixels, actions, rewards, masks, recurrent states, agent groups, sampled indices, priorities, or custom task fields. The result is less glue code and fewer hidden assumptions about what each algorithm or environment returns.
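To make the flow concrete, here is a dependency-free sketch of the same key contract, using a plain dict in place of a TensorDict. `KeyedModule`, `toy_env_step`, and the key names are hypothetical stand-ins chosen for illustration; TorchRL's `TensorDictModule`, environments, and losses add specs, batch dimensions, and devices on top of this idea.

```python
from typing import Callable, Dict, List, Sequence

Data = Dict[str, object]


class KeyedModule:
    """Hypothetical keyed module: reads in_keys, calls fn, writes out_keys."""

    def __init__(self, fn: Callable, in_keys: Sequence[str], out_keys: Sequence[str]):
        self.fn, self.in_keys, self.out_keys = fn, list(in_keys), list(out_keys)

    def __call__(self, data: Data) -> Data:
        outputs = self.fn(*(data[k] for k in self.in_keys))
        if len(self.out_keys) == 1:
            outputs = (outputs,)
        for key, value in zip(self.out_keys, outputs):
            data[key] = value
        return data


# Policy: writes "action" from "observation" (a fixed linear rule here).
policy = KeyedModule(lambda obs: -0.5 * obs, ["observation"], ["action"])


def toy_env_step(data: Data) -> Data:
    """Hypothetical 1-D environment: reads "action", writes a "next" entry."""
    next_obs = data["observation"] + data["action"]
    data["next"] = {"observation": next_obs, "reward": -abs(next_obs), "done": False}
    return data


buffer: List[Data] = []
data: Data = {"observation": 4.0}
for _ in range(3):
    data = toy_env_step(policy(data))
    buffer.append(dict(data))  # the buffer stores the whole record
    data = {"observation": data["next"]["observation"]}  # move to the next state

# A downstream consumer (a loss, a logger) reads only the keys it needs.
mean_reward = sum(rec["next"]["reward"] for rec in buffer) / len(buffer)
```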

Quick demo

A local rollout is just a TensorDict passed between a PyTorch module and an environment:

import torch
from tensordict.nn import TensorDictModule
from torch import nn

from torchrl.envs import PendulumEnv, StepCounter, TransformedEnv

# A PyTorch-native environment with an ordinary transform stack.
env = TransformedEnv(PendulumEnv(), StepCounter(max_steps=200))

# Policies are regular nn.Modules wrapped with explicit TensorDict keys.
policy = TensorDictModule(
    nn.Sequential(
        nn.LazyLinear(64),
        nn.Tanh(),
        nn.Linear(64, 1),
        nn.Tanh(),
    ),
    in_keys=["observation"],
    out_keys=["action"],
)

rollout = env.rollout(max_steps=32, policy=policy)
assert rollout.batch_size == torch.Size([32])
assert rollout["next", "reward"].shape[:1] == torch.Size([32])

Nothing in this pattern is specific to Pendulum. The same keys-and-TensorDict interface is used by batched environments, multi-agent tasks, collectors, replay buffers, recurrent modules, transforms, and losses.

What TorchRL is today

TensorDict-first pipelines

RL code tends to accumulate special cases: tuples from one environment, dicts from another, separate arrays for recurrent states, masks next to data rather than inside it, and losses that silently assume a particular batch layout. TorchRL uses TensorDict to make those assumptions explicit.

TensorDict supports common tensor operations while preserving named fields:

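import torch
from tensordict import TensorDict

# Illustrative per-step data: 64 TensorDicts, each holding 4 environments.
list_of_tensordicts = [
    TensorDict({"next": {"reward": torch.randn(4, 1)},
                "agents": {"observation": torch.randn(4, 3, 8)},
                "recurrent_state": {"h": torch.zeros(4, 64)}}, batch_size=[4])
    for _ in range(64)
]
# These operations preserve the structure and operate on every compatible value.
batch = torch.stack(list_of_tensordicts, dim=0)  # batch_size [64, 4]
batch = batch.reshape(-1)  # batch_size [256]
batch = batch.to("cuda" if torch.cuda.is_available() else "cpu")
mini_batch = batch[:128]

# Nested keys make multi-agent, recurrent, and next-state data explicit.
reward = batch["next", "reward"]
agent_obs = batch["agents", "observation"]
hidden = batch["recurrent_state", "h"]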

This is the reason TorchRL components compose: a collector can emit a TensorDict, a replay buffer can store it without losing structure, a transform can add or remove keys, and a loss can read exactly the keys it needs.
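A single call like `batch[:128]` can act on every field because the container applies the operation leaf by leaf and keeps the key tree intact. The following is a minimal pure-Python sketch of that idea using nested dicts of lists; `apply_leaves` and `get_nested` are hypothetical helpers, not TensorDict methods.

```python
from typing import Any, Callable, Dict, Tuple, Union

Tree = Dict[str, Any]


def apply_leaves(tree: Tree, fn: Callable[[Any], Any]) -> Tree:
    """Return a new tree with fn applied to every non-dict leaf; keys are unchanged."""
    return {k: apply_leaves(v, fn) if isinstance(v, dict) else fn(v) for k, v in tree.items()}


def get_nested(tree: Tree, key: Union[str, Tuple[str, ...]]) -> Any:
    """Resolve a tuple key such as ("next", "reward") one level at a time."""
    for part in ((key,) if isinstance(key, str) else key):
        tree = tree[part]
    return tree


batch = {
    "observation": list(range(256)),
    "next": {"reward": [float(i) for i in range(256)]},
    "agents": {"observation": [[i, -i] for i in range(256)]},
}

mini_batch = apply_leaves(batch, lambda leaf: leaf[:128])  # slice every field at once
rewards = get_nested(mini_batch, ("next", "reward"))
```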

Environments and transforms

TorchRL includes native environments, wrappers for popular environment libraries, and vectorized containers for running many environments at once. The environment API exposes specs for observations, actions, rewards, and done flags, so policies and transforms can check shapes, devices, dtypes, and bounds before a training job runs for hours.
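A spec is essentially a declared shape, element type, and range that data can be validated against before training starts. The sketch below uses a hypothetical stdlib-only `BoundedSpec` (not TorchRL's own spec classes) to show the kind of check a policy or transform can run against an environment's declared action space.

```python
from dataclasses import dataclass
from typing import List, Sequence, Tuple


@dataclass(frozen=True)
class BoundedSpec:
    """Hypothetical spec: a fixed shape, an element type, and inclusive bounds."""

    shape: Tuple[int, ...]
    dtype: type
    low: float
    high: float

    def is_in(self, value: Sequence[float]) -> bool:
        # Only flat (1-D) specs are handled in this sketch.
        if len(self.shape) != 1 or len(value) != self.shape[0]:
            return False
        return all(isinstance(v, self.dtype) and self.low <= v <= self.high for v in value)

    def project(self, value: Sequence[float]) -> List[float]:
        """Clip each entry into [low, high]."""
        return [min(max(v, self.low), self.high) for v in value]


action_spec = BoundedSpec(shape=(1,), dtype=float, low=-2.0, high=2.0)
```

A shape, dtype, or bound mismatch is caught by `is_in` immediately rather than surfacing as a silent error hours into a run.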

Environment support includes:

  • PyTorch-native environments such as PendulumEnv and custom MuJoCo tasks.
  • Wrappers for Gymnasium, Gym, DM Control, Brax, Jumanji, PettingZoo, VMAS, OpenSpiel, Safety-Gymnasium, Isaac Lab, and other optional libraries.
  • SerialEnv, ParallelEnv, and batched wrappers for local vectorization and multiprocessing.
  • Environment transforms for observation normalization, image conversion, reward transforms, action masking, action scaling, auto-reset, frame stacking, state reconstruction, and more.
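Action scaling is a good example of a transform that only touches the action path: the policy emits values in a normalized range, and the transform maps them to the environment's bounds before the step. The following is a minimal sketch of that affine map, assuming a normalized range of [-1, 1]. `scale_action` and `unscale_action` are hypothetical helpers, and TorchRL's `ActionScaling` may differ in details such as which keys it reads and writes.

```python
from typing import List, Sequence


def scale_action(a: Sequence[float], low: Sequence[float], high: Sequence[float]) -> List[float]:
    """Map each entry from [-1, 1] to [low_i, high_i] with an affine transform."""
    return [lo + (x + 1.0) * 0.5 * (hi - lo) for x, lo, hi in zip(a, low, high)]


def unscale_action(a: Sequence[float], low: Sequence[float], high: Sequence[float]) -> List[float]:
    """Inverse map from [low_i, high_i] back to [-1, 1]."""
    return [2.0 * (x - lo) / (hi - lo) - 1.0 for x, lo, hi in zip(a, low, high)]


low, high = [-2.0, 0.0], [2.0, 10.0]
env_action = scale_action([0.0, 1.0], low, high)  # [0.0, 10.0]
```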

Transforms are first-class TorchRL modules. They can run on-device, participate in specs, and be inserted, removed, or composed without wrapping the whole environment in opaque adapter layers.

from torchrl.envs import Compose, DoubleToFloat, ObservationNorm, TransformedEnv
from torchrl.envs.libs.gym import GymEnv

base_env = GymEnv("HalfCheetah-v4", device="cuda:0")
env = TransformedEnv(
    base_env,
    Compose(
        ObservationNorm(in_keys=["observation"]),
        DoubleToFloat(),
    ),
)
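When `ObservationNorm` is created without `loc` and `scale`, as above, its statistics must be estimated before the environment is used; TorchRL exposes an `init_stats` method on the transform for this. The sketch below shows the underlying computation in plain Python: per-dimension mean and standard deviation from a handful of observations, followed by the usual (x - mean) / std map. `fit_norm_stats` and `normalize` are hypothetical names, and TorchRL's transform stores the result as `loc`/`scale` parameters whose exact convention depends on its options.

```python
import math
from typing import List, Sequence, Tuple


def fit_norm_stats(observations: Sequence[Sequence[float]],
                   eps: float = 1e-6) -> Tuple[List[float], List[float]]:
    """Per-dimension mean and standard deviation over a batch of observations."""
    n = len(observations)
    dims = len(observations[0])
    mean = [sum(o[d] for o in observations) / n for d in range(dims)]
    std = [
        math.sqrt(sum((o[d] - mean[d]) ** 2 for o in observations) / n) + eps
        for d in range(dims)
    ]
    return mean, std


def normalize(obs: Sequence[float], mean: Sequence[float], std: Sequence[float]) -> List[float]:
    return [(x - m) / s for x, m, s in zip(obs, mean, std)]


# Pretend these came from a few hundred random-policy steps.
samples = [[1.0, 10.0], [3.0, 30.0], [5.0, 50.0]]
mean, std = fit_norm_stats(samples)
normed = [normalize(o, mean, std) for o in samples]
```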

Collectors and execution models

Collectors are the bridge between policies and environments. A collector owns the execution loop, batches trajectories, handles devices, and can update policy weights while environments keep running.

TorchRL includes single-process, async, multiprocess, and distributed collectors. This lets the same policy and loss code be used across small smoke tests, GPU-heavy simulation, CPU environment farms, or asynchronous evaluation setups.

from torchrl.collectors import Collector

collector = Collector(
    create_env_fn=env,
    policy=policy,
    frames_per_batch=1024,
    total_frames=1_000_000,
)

for data in collector:
    # data is a TensorDict with time, environment, and key structure preserved.
    # train_step stands for your own update function; it is not part of TorchRL.
    train_step(data)
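Without devices, workers, and weight updates, the loop a collector owns looks roughly like the generator below: it steps the environment with the current policy, resets finished episodes, and yields fixed-size batches until a frame budget is spent. `toy_collector`, `CountdownEnv`, and the record layout are hypothetical stand-ins for illustration, not TorchRL's `Collector` internals.

```python
from typing import Callable, Iterator, List


class CountdownEnv:
    """Hypothetical environment whose episodes last exactly `horizon` steps."""

    def __init__(self, horizon: int = 5):
        self.horizon = horizon
        self.t = 0

    def reset(self) -> dict:
        self.t = 0
        return {"observation": float(self.horizon)}

    def step(self, action: float) -> dict:
        self.t += 1
        return {"observation": float(self.horizon - self.t), "reward": action,
                "done": self.t >= self.horizon}


def toy_collector(env: CountdownEnv, policy: Callable[[float], float],
                  frames_per_batch: int, total_frames: int) -> Iterator[List[dict]]:
    """Yield batches of transitions until at least total_frames have been collected."""
    obs = env.reset()
    frames = 0
    while frames < total_frames:
        batch = []
        for _ in range(frames_per_batch):
            action = policy(obs["observation"])
            nxt = env.step(action)
            batch.append({"observation": obs["observation"], "action": action, "next": nxt})
            obs = env.reset() if nxt["done"] else nxt  # auto-reset finished episodes
        frames += len(batch)
        yield batch


batches = list(toy_collector(CountdownEnv(horizon=5), policy=lambda o: 1.0,
                             frames_per_batch=8, total_frames=24))
```

Batches cut across episode boundaries, which is why the `done` flag travels with each transition instead of being implied by batch layout.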

For larger jobs, the collector family adds async execution, multiple worker processes, weight updaters, evaluator loops, profiling hooks, and fake-data helpers for testing downstream code without stepping an expensive environment.

Replay buffers and offline data

TorchRL replay buffers are modular: storage, sampler, writer, collate function, transforms, prefetching, priority updates, and device movement are separate pieces. That makes it possible to use the same interface for simple in-memory replay, memmap-backed storage, prioritized replay, CUDA-aware sampling, offline datasets, HER, or custom storage layouts.
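The split between storage, writer, and sampler can be sketched in a few lines of plain Python. The class names below (`FixedStorage`, `CircularWriter`, `UniformIndexSampler`, `MiniReplayBuffer`) are hypothetical and deliberately minimal: the storage holds items, the writer decides where the next item goes, and the sampler decides which indices to read.

```python
import random
from typing import Any, Iterable, List, Optional


class FixedStorage:
    """Holds up to `capacity` items in preallocated slots."""

    def __init__(self, capacity: int):
        self.slots: List[Optional[Any]] = [None] * capacity
        self.size = 0  # number of filled slots

    def set(self, index: int, item: Any) -> None:
        self.slots[index] = item
        self.size = max(self.size, index + 1)  # the writer fills slots in order


class CircularWriter:
    """Writes to the oldest slot first once the storage is full."""

    def __init__(self):
        self.cursor = 0

    def next_index(self, capacity: int) -> int:
        index = self.cursor
        self.cursor = (self.cursor + 1) % capacity
        return index


class UniformIndexSampler:
    """Draws indices uniformly, with replacement, from the filled slots."""

    def __init__(self, rng: random.Random):
        self.rng = rng

    def sample(self, size: int, batch_size: int) -> List[int]:
        return [self.rng.randrange(size) for _ in range(batch_size)]


class MiniReplayBuffer:
    def __init__(self, capacity: int, batch_size: int, seed: int = 0):
        self.storage = FixedStorage(capacity)
        self.writer = CircularWriter()
        self.sampler = UniformIndexSampler(random.Random(seed))
        self.batch_size = batch_size

    def extend(self, items: Iterable[Any]) -> None:
        for item in items:
            self.storage.set(self.writer.next_index(len(self.storage.slots)), item)

    def sample(self) -> List[Any]:
        indices = self.sampler.sample(self.storage.size, self.batch_size)
        return [self.storage.slots[i] for i in indices]


buffer = MiniReplayBuffer(capacity=4, batch_size=3)
buffer.extend(range(6))  # items 0 and 1 are overwritten by 4 and 5
```

Swapping any one piece (for example a prioritized sampler) leaves the other two untouched. In TorchRL the same roles are filled by configurable objects; the prioritized buffer below pairs memmap-backed storage with prioritized sampling.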

from torchrl.data import LazyMemmapStorage, TensorDictPrioritizedReplayBuffer

buffer = TensorDictPrioritizedReplayBuffer(
    storage=LazyMemmapStorage(1_000_000),
    alpha=0.7,
    beta=0.5,
    batch_size=256,
    prefetch=2,
)

# collector_batch stands for a TensorDict yielded by a collector, as in the loop above.
buffer.extend(collector_batch)
sample = buffer.sample()
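The `alpha` and `beta` arguments follow the standard prioritized-replay formulas: priorities raised to `alpha` define the sampling probabilities P(i) = p_i^alpha / sum_k p_k^alpha, and importance weights (N * P(i))^-beta correct the resulting bias. The stdlib sketch below computes both, normalizing the weights by the largest possible weight so they stay at most 1. `per_sample` is a hypothetical function, and TorchRL's own sampler may differ in data structures and normalization details.

```python
import random
from typing import List, Sequence, Tuple


def per_sample(priorities: Sequence[float], batch_size: int, alpha: float,
               beta: float, rng: random.Random) -> Tuple[List[int], List[float]]:
    """Sample indices with P(i) proportional to p_i**alpha; return importance weights."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    probs = [s / total for s in scaled]
    n = len(priorities)
    indices = rng.choices(range(n), weights=probs, k=batch_size)
    # w_i = (N * P(i)) ** -beta, divided by the largest possible weight so w <= 1.
    max_w = (n * min(probs)) ** -beta
    weights = [((n * probs[i]) ** -beta) / max_w for i in indices]
    return indices, weights


rng = random.Random(0)
priorities = [1.0, 1.0, 4.0, 16.0]
indices, weights = per_sample(priorities, batch_size=256, alpha=0.7, beta=0.5, rng=rng)
```

High-priority items are drawn more often but receive smaller weights, so the weighted loss stays closer to an unbiased estimate as `beta` approaches 1.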

Replay buffers understand TensorDict structure, so they can store trajectories, nested agent data, recurrent states, HER relabeling metadata, or offline datasets without flattening everything into parallel Python containers.
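Hindsight experience replay (HER) is one case where keeping structure pays off. A stored transition is copied with its desired goal replaced by a goal the agent actually achieved later in the same trajectory, and the reward is recomputed for that goal. The following is a minimal sketch of the "future" relabeling strategy, assuming a sparse reward of 0 on reaching the goal and -1 otherwise. The field names and `her_future_relabel` are hypothetical, not TorchRL's HER utilities.

```python
import random
from typing import Dict, List


def sparse_reward(achieved: int, goal: int) -> float:
    """Hypothetical goal-conditioned reward: 0 on success, -1 otherwise."""
    return 0.0 if achieved == goal else -1.0


def her_future_relabel(trajectory: List[Dict], k: int, rng: random.Random) -> List[Dict]:
    """For each step, append k copies whose goal was achieved at the same or a later step."""
    relabeled = []
    for t, step in enumerate(trajectory):
        for _ in range(k):
            future = rng.randrange(t, len(trajectory))
            new_goal = trajectory[future]["achieved_goal"]
            relabeled.append({**step, "desired_goal": new_goal,
                              "reward": sparse_reward(step["achieved_goal"], new_goal)})
    return trajectory + relabeled


# A 1-D walk through positions 1..4 that never reaches its goal of 10.
traj = [{"achieved_goal": x, "desired_goal": 10, "reward": -1.0} for x in range(1, 5)]
augmented = her_future_relabel(traj, k=2, rng=random.Random(0))
```

The original failed trajectory is kept, and the relabeled copies give the learner transitions with non-trivial reward signal.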

Modules, distributions, and policies

TorchRL modules are ordinary PyTorch modules with explicit input and output keys. The library provides actors, critics, actor-critic operators, recurrent modules, distribution wrappers, exploration modules, world models, decision transformers, robot-learning models, and helper utilities for inferring specs from environments.

A stochastic actor can be assembled from familiar PyTorch layers:

from tensordict.nn import TensorDictModule
from tensordict.nn.distributions import NormalParamExtractor
from torch import nn
from torchrl.modules import ProbabilisticActor, TanhNormal

params = TensorDictModule(
    nn.Sequential(
        nn.LazyLinear(256),
        nn.Tanh(),
        nn.Linear(256, 2),
        NormalParamExtractor(),
    ),
    in_keys=["observation"],
    out_keys=["loc", "scale"],
)

actor = ProbabilisticActor(
    params,
    in_keys=["loc", "scale"],
    out_keys=["action"],
    distribution_class=TanhNormal,
    distribution_kwargs={"low": -1.0, "high": 1.0},
    return_log_prob=True,
)

The explicit key contract makes it clear what data a module consumes and produces, and it allows losses, collectors, and transforms to be reconfigured without editing the model itself.
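The `TanhNormal` distribution used above squashes a Gaussian sample through `tanh` and rescales it into `[low, high]`, so its log-probability needs a change-of-variables correction: log p(a) = log N(u; loc, scale) - log(((high - low) / 2) * (1 - tanh(u)^2)). The stdlib sketch below samples one action dimension and computes that log-probability. Softplus is assumed here as the positivity mapping for the raw scale; TorchRL's `NormalParamExtractor` and `TanhNormal` may use different mappings and numerically safer formulas.

```python
import math
import random
from typing import Tuple


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return math.log1p(math.exp(-abs(x))) + max(x, 0.0)


def extract_params(raw_loc: float, raw_scale: float) -> Tuple[float, float]:
    """Split raw network outputs into a location and a strictly positive scale."""
    return raw_loc, softplus(raw_scale) + 1e-4


def tanh_normal_sample(loc: float, scale: float, low: float, high: float,
                       rng: random.Random) -> Tuple[float, float]:
    """Return (action, log_prob) for a tanh-squashed, rescaled Normal sample."""
    u = rng.gauss(loc, scale)
    y = math.tanh(u)
    half_range = 0.5 * (high - low)
    action = low + (y + 1.0) * half_range
    normal_log_prob = (-0.5 * ((u - loc) / scale) ** 2 - math.log(scale)
                       - 0.5 * math.log(2 * math.pi))
    # Change of variables: subtract log |d action / d u| = log(half_range * (1 - tanh(u)**2)).
    log_prob = normal_log_prob - math.log(half_range * (1.0 - y * y) + 1e-12)
    return action, log_prob


loc, scale = extract_params(0.3, -0.5)
action, log_prob = tanh_normal_sample(loc, scale, low=-1.0, high=1.0, rng=random.Random(0))
```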

Objectives, returns, and trainers

TorchRL objectives are loss modules that read TensorDict keys, compute losses, and expose configurable key mappings. They cover policy-gradient methods, actor-critic algorithms, Q-learning, offline RL, imitation learning, model-based RL, and multi-agent RL.

Examples include PPO, SAC, DQN, TD3, REDQ, IQL, CQL, Decision Transformer, Dreamer, CrossQ, GAIL, behavior cloning, ACT, MAPPO, IPPO, and QMIX/VDN. Value-estimator utilities provide GAE, TD(lambda), V-trace, lambda returns, multi-agent advantages, and vectorized return computation.
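GAE is a short backward recursion over a trajectory. The sketch below implements the standard definition in plain Python, assuming every `done` flag marks a true termination (so the bootstrap is cut there). It is meant to show what the estimator computes, not how TorchRL's vectorized implementation is written.

```python
from typing import List, Sequence, Tuple


def gae(rewards: Sequence[float], values: Sequence[float], next_values: Sequence[float],
        dones: Sequence[bool], gamma: float = 0.99,
        lmbda: float = 0.95) -> Tuple[List[float], List[float]]:
    """Return (advantages, value_targets) for one time-ordered trajectory.

    delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
    A_t     = delta_t + gamma * lmbda * (1 - done_t) * A_{t+1}
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        not_done = 0.0 if dones[t] else 1.0
        delta = rewards[t] + gamma * next_values[t] * not_done - values[t]
        running = delta + gamma * lmbda * not_done * running
        advantages[t] = running
    value_targets = [a + v for a, v in zip(advantages, values)]
    return advantages, value_targets


# With gamma = lmbda = 1 and zero values, advantages reduce to rewards-to-go.
adv, targets = gae([1.0, 1.0, 1.0], [0.0] * 3, [0.0] * 3, [False, False, True],
                   gamma=1.0, lmbda=1.0)  # adv == [3.0, 2.0, 1.0]
```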

from torchrl.objectives import ClipPPOLoss
from torchrl.objectives.value import GAE

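# critic is a state-value network, for example a ValueOperator writing "state_value".
loss = ClipPPOLoss(actor_network=actor, critic_network=critic)
advantage = GAE(value_network=critic, gamma=0.99, lmbda=0.95)

# data is a batch of collected transitions; GAE writes "advantage" and "value_target".
data = advantage(data)
losses = loss(data)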
loss_value = losses["loss_objective"] + losses["loss_critic"] + losses["loss_entropy"]
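The `loss_objective` term of a clipped PPO loss is the negated clipped surrogate. The following is a plain-Python sketch of that term for a batch of samples, assuming the old log-probabilities were stored at collection time (which is what `return_log_prob=True` on the actor above enables) and a clip range of 0.2. It omits the critic and entropy terms that the loss module also returns.

```python
import math
from typing import Sequence


def clipped_surrogate_loss(log_prob_new: Sequence[float], log_prob_old: Sequence[float],
                           advantages: Sequence[float], clip_epsilon: float = 0.2) -> float:
    """Negative mean of min(r * A, clip(r, 1 - eps, 1 + eps) * A), with r = exp(new - old)."""
    terms = []
    for new, old, adv in zip(log_prob_new, log_prob_old, advantages):
        ratio = math.exp(new - old)
        clipped = min(max(ratio, 1.0 - clip_epsilon), 1.0 + clip_epsilon)
        terms.append(min(ratio * adv, clipped * adv))
    return -sum(terms) / len(terms)


# Ratio 2.0 with a positive advantage is capped at 1.2; ratio 0.5 with a negative
# advantage takes the pessimistic clipped value 0.8 * -1, so loss = -(1.2 - 0.8) / 2.
loss = clipped_surrogate_loss([math.log(2.0), math.log(0.5)], [0.0, 0.0], [1.0, -1.0])
```

Taking the minimum means the objective never rewards moving the ratio further outside the clip range, which keeps each update close to the data-collecting policy.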

For higher-level workflows, TorchRL also provides trainer utilities and Hydra configuration dataclasses that assemble environments, networks, collectors, losses, optimizers, loggers, hooks, and schedules into reproducible recipes.

Multi-agent, model-based, and imitation learning

Multi-agent data is represented as TensorDict structure rather than a separate parallel convention. Agent observations, actions, rewards, masks, and shared state can live under nested keys such as ("agents", "observation"), while losses and modules declare which keys they use.

TorchRL supports multi-agent environments and algorithms through VMAS, PettingZoo, Melting Pot, SMACv2, OpenSpiel, multi-agent trainers, and dedicated objectives. The 0.13 line adds MAPPO, IPPO, MultiAgentGAE, ValueNorm, PopArtValueNorm, RunningValueNorm, and cross-agent critic utilities.
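Value-normalization utilities address the drift in return scale during training. The general idea is to track running statistics of value targets, train the critic on normalized targets, and de-normalize its predictions when computing advantages. The class below is a hypothetical stdlib illustration using Welford's running mean and variance; TorchRL's `ValueNorm`, `PopArtValueNorm`, and `RunningValueNorm` may use different update rules, such as moving averages or PopArt's output-layer rescaling.

```python
import math
from typing import Iterable


class RunningNorm:
    """Hypothetical running normalizer for scalar value targets (Welford's algorithm)."""

    def __init__(self, eps: float = 1e-8):
        self.count = 0
        self.mean = 0.0
        self.m2 = 0.0  # sum of squared deviations from the running mean
        self.eps = eps

    def update(self, values: Iterable[float]) -> None:
        for x in values:
            self.count += 1
            delta = x - self.mean
            self.mean += delta / self.count
            self.m2 += delta * (x - self.mean)

    @property
    def std(self) -> float:
        var = self.m2 / self.count if self.count > 0 else 1.0
        return math.sqrt(var + self.eps)

    def normalize(self, x: float) -> float:
        return (x - self.mean) / self.std

    def denormalize(self, x: float) -> float:
        return x * self.std + self.mean


norm = RunningNorm()
norm.update([10.0, 20.0, 30.0])  # e.g. value targets from one batch
target_for_critic = norm.normalize(30.0)  # the critic regresses onto normalized targets
value_estimate = norm.denormalize(target_for_critic)  # back to return scale for GAE
```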

The same component style also covers model-based and imitation-learning work: Dreamer/DreamerV3 objectives and RSSM modules, Decision Transformer components, behavior cloning losses, and ACT-style action chunking all share the same TensorDict and key-dispatch conventions.

Core symbols most depended-on inside this repo

keys · called by 1377 · torchrl/data/tensor_specs.py
zeros · called by 1057 · torchrl/data/tensor_specs.py
clone · called by 940 · torchrl/data/tensor_specs.py
get · called by 820 · torchrl/services/base.py
append · called by 697 · torchrl/data/llm/history.py
to · called by 679 · torchrl/envs/common.py
unsqueeze · called by 667 · torchrl/data/tensor_specs.py
set · called by 665 · torchrl/data/tensor_specs.py

Shape

Method 10,531
Function 1,781
Class 1,760
Route 319

Languages

Python 100%

Modules by API surface

torchrl/data/tensor_specs.py · 475 symbols
test/test_collectors.py · 317 symbols
test/test_specs.py · 274 symbols
torchrl/testing/mocking_classes.py · 199 symbols
test/transforms/test_action_transforms.py · 196 symbols
test/transforms/test_observation_transforms.py · 191 symbols
test/transforms/test_env_transforms.py · 183 symbols
torchrl/data/replay_buffers/storages.py · 161 symbols
torchrl/trainers/trainers.py · 156 symbols
torchrl/envs/llm/reward/ifeval/_instructions.py · 153 symbols
torchrl/envs/common.py · 151 symbols
torchrl/envs/transforms/_base.py · 148 symbols

Dependencies from manifests, versioned

Jinja2 3.1.4 · 1×
accelerate 1.7.0 · 1×
ale-py 0.9.0 · 1×
bitsandbytes 0.46.0 · 1×
datasets 3.6.0 · 1×
dm_control 1.0.41 · 1×
hoptorch 0.1.4 · 1×
hydra-core 1.3.2 · 1×
immutabledict 4.2.1 · 1×
langdetect 1.0.9 · 1×
mujoco 3.8.1 · 1×

For agents

$ claude mcp add rl \
  -- python -m otcore.mcp_server <graph>
