
Tiny Recursive Model (TRM)


Original paper by Alexia Jolicoeur-Martineau

I think this is a brilliant piece of work, especially when juxtaposed with HRM. The Zen of Python says, "Simple is better than complex. Complex is better than complicated." Well, TRM is a significant simplification of the ideas put forth in HRM. Hats off to the author, Alexia Jolicoeur-Martineau.

TLDR

If you remember three things:
  1. TRM replaces "fixed-point theory + 1-step grad" with "unroll recursion (final pass) + no-grad refinement passes."
  2. TRM replaces ACT/Q-learning halting with a simple BCE halting head (no extra forward pass).
  3. The whole model reduces to two states: current solution (y) + latent reasoning (z).
What to implement first

The core of TRM to grok is the pair latent_recursion and deep_recursion (see Figure 1). Pay close attention to where torch.no_grad() and .detach() are called.

Where folks might mess up

Mixing up detach points (accidentally backpropagating across deep supervision steps), mis-implementing the no-grad passes, or treating the halting logit incorrectly (e.g. thresholding sigmoid(q_hat) at 0 instead of q_hat itself).

Hierarchical Reasoning Model (HRM): Overview and issues

Refer to my blog post on HRM for more details on the concept.

Summary

  • Recursive refinement + deep supervision is the real juice
  • HRM's gradient story is DEQ-ish / 1-step-ish and arguably shaky
  • ACT adds complexity + extra forward pass

Recursive Refinement + Deep supervision

This is where all the magic in HRM happens.

Recursive refinement refines two latents, zL and zH, recursively using two 4-layer transformer networks, f_L and f_H.

Deep supervision iterates over the same batch up to N_sup times within the training loop: it repeatedly supervises the prediction, updates the latents zH and zL, and passes the updated (detached) latents to hrm on the next go-around. See Figure 1 in the HRM post.

Deep Equilibrium (DEQ) and 1-step-gradient approximation

All you need to know here is the following. DEQ is loosely the idea of applying the same function repeatedly until its output stops changing, i.e. until it reaches a fixed point z* = f(z*, x), called an equilibrium. The 1-step gradient approximation says that, at the equilibrium, you can train the network by backpropagating through only the last function evaluation and discard the rest of the iteration history. HRM uses DEQ and the 1-step gradient approximation to justify its architecture.
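To make the 1-step idea concrete, here is a minimal scalar sketch in plain Python. The toy update f(z) = tanh(w*z + x) and the names f, fixed_point, and grads_wrt_w are my own hypothetical choices, not from either paper. At a true fixed point, the exact gradient of the equilibrium z* with respect to w is (df/dw) / (1 - df/dz); the 1-step approximation keeps only df/dw, so it is accurate only when df/dz is small and, crucially, only when z really sits at the fixed point.

```python
import math

def f(z, x, w):
    # Hypothetical toy update; a contraction in z when |w| < 1
    return math.tanh(w * z + x)

def fixed_point(x, w, tol=1e-12, max_iter=10_000):
    # DEQ-style forward pass: iterate until z stops changing
    z = 0.0
    for _ in range(max_iter):
        z_next = f(z, x, w)
        if abs(z_next - z) < tol:  # explicit convergence check
            return z_next
        z = z_next
    raise RuntimeError("did not reach a fixed point")

def grads_wrt_w(x, w):
    z = fixed_point(x, w)
    s = 1.0 - z ** 2                # tanh'(w*z + x), since z = tanh(w*z + x) at equilibrium
    df_dw = s * z                   # partial of f w.r.t. w
    df_dz = s * w                   # partial of f w.r.t. z
    exact = df_dw / (1.0 - df_dz)   # implicit-function-theorem gradient dz*/dw
    one_step = df_dw                # 1-step approximation: drop the path through z
    return z, exact, one_step

z, exact, one_step = grads_wrt_w(x=0.5, w=0.5)
print(f"z*={z:.4f}  exact={exact:.4f}  one-step={one_step:.4f}")
```

Even in this friendly case the two gradients differ noticeably. Note the explicit convergence check in fixed_point: HRM runs no such check, so its 1-step gradient is taken at a point that may not be an equilibrium at all.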

Concern

The author of TRM points out that HRM's design doesn't guarantee a fixed point is ever reached: there aren't enough recursion steps, and there is no convergence check. TRM sidesteps this by backpropagating through the full unrolled recursion (for the final pass) while using no-grad passes for refinement.

ACT adds complexity

ACT (Adaptive Computation Time) is what HRM uses to stop deep supervision early. Without it, every sample runs all N_sup supervision steps during training, increasing training time.

Concern

The author of TRM notes that ACT is implemented by running an additional forward pass (call to hrm) on each deep supervision step. That's two calls to hrm every iteration.

Personally, I also find the ACT implementation hard to grok, especially ACT_continue, which TRM removes altogether.

Tiny Recursive Model (TRM)

Figure 1 (From the paper)

def latent_recursion(x, y, z, n=6):
  """
  y = zH
  z = zL
  """
  for i in range(n): # refine latent reasoning
    z = net(x, y, z)
  y = net(y, z) # refine answer
  return y, z

def deep_recursion(x, y, z, n=6, T=3):
  """
  The key idea: do multiple refinement passes without grads
  to improve state cheaply, then one final unrolled pass with grads to train.

  y = zH; z = zL
  """
  with torch.no_grad(): # No gradients
    for j in range(T-1):
      y, z = latent_recursion(x, y, z, n)
  y, z = latent_recursion(x, y, z, n) # gradients
  return (y.detach(), z.detach()), output_head(y), Q_head(y)

# Deep supervision
for x_input, y_true in train_dl:
  y, z = y_init, z_init
  for step in range(N_supervision):
    x = input_embedding(x_input)
    (y, z), y_hat, q_hat = deep_recursion(x, y, z)
    loss = softmax_cross_entropy(y_hat, y_true)
    loss += bce_with_logit(q_hat, (y_hat == y_true))

    loss.backward()
    opt.step()
    opt.zero_grad()

    # q_hat is a logit; q_hat > 0 ⇔ sigmoid(q_hat) > 0.5**
    if q_hat > 0:
      break

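**My Notes: Because q_hat is a logit, q_hat = 0 is exactly the decision boundary (sigmoid(0) = 0.5); other thresholds would just trade compute for confidence via a probability threshold p.

Say we define a threshold p such that we halt when $\sigma(\hat{q}) > p$. And since the sigmoid is monotonically increasing, with inverse $\sigma^{-1}(p) = \log\frac{p}{1-p}$, we get

$$\sigma(\hat{q}) > p \iff \hat{q} > \log\frac{p}{1-p}$$

So we can change our confidence by varying p: p = 0.5 recovers the q_hat > 0 rule, and a larger p means halting later, i.e. more supervision steps per sample.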
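The same rule in code, as a small plain-Python sketch (the helper names sigmoid, logit_threshold, and should_halt are hypothetical, not from the TRM repo):

```python
import math

def sigmoid(q):
    return 1.0 / (1.0 + math.exp(-q))

def logit_threshold(p):
    # Probability threshold p in (0, 1) -> equivalent logit threshold log(p / (1 - p))
    return math.log(p / (1.0 - p))

def should_halt(q_hat, p=0.5):
    # sigmoid(q_hat) > p  <=>  q_hat > logit_threshold(p), so no sigmoid is needed
    return q_hat > logit_threshold(p)

print(logit_threshold(0.5), round(logit_threshold(0.9), 3))
```

In a batched implementation q_hat holds one logit per sample, so I'd assume the halting decision is made per sample rather than with a single `if q_hat > 0`; I haven't checked how the repo handles this.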

Data

(Note: I'm inferring shapes in this section. I didn't confirm against the repo.)

Just a quick note on this. The datasets are Sudoku-Extreme, Maze-Hard, and ARC-AGI I & II. Consider them all square puzzle grids of initial shape MxM, flattened so that a batch looks like (B, L) with L = MxM.

  • Sudoku-Extreme: the grid is 9x9, so a batch looks like (B, 81)
  • Maze-Hard: the grid is 30x30, so a batch looks like (B, 900)
  • ARC-AGI: This one is trickier. The paper says a single training sample has multiple examples, which are serialized into a single input sequence. So I'm inferring L is the length of all the flattened examples concatenated together.

Then, if x_input is our (B, L) batch of token ids, x = input_embedding(x_input) (Figure 1) turns each token into a D-dimensional embedding, and the data becomes shape (B, L, D).
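Here is a shape-only sketch of that pipeline using plain Python lists instead of tensors. Everything in it is an assumption for illustration: the vocabulary size V, the tiny embedding width D, the random grids, and the helper names flatten_grid and embed (a stand-in for an nn.Embedding-style lookup).

```python
import random

def flatten_grid(grid):
    # (M, M) grid of token ids -> length L = M * M sequence, row-major order
    return [tok for row in grid for tok in row]

def embed(batch, table):
    # (B, L) token ids + (V, D) lookup table -> (B, L, D) embeddings
    return [[table[tok] for tok in seq] for seq in batch]

# Hypothetical sizes: 9x9 Sudoku grids, 10 digit tokens, a tiny embedding width
M, V, D, B = 9, 10, 4, 2
random.seed(0)
table = [[random.gauss(0.0, 1.0) for _ in range(D)] for _ in range(V)]
grids = [[[random.randrange(V) for _ in range(M)] for _ in range(M)] for _ in range(B)]

x_input = [flatten_grid(g) for g in grids]   # shape (B, L) = (2, 81)
x = embed(x_input, table)                    # shape (B, L, D) = (2, 81, 4)
print(len(x), len(x[0]), len(x[0][0]))
```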

Improvements

  • No fixed point theorem
  • ACT: No additional forward pass
  • No hierarchical features + single network

No fixed point theorem

HRM only backpropagates through the last two function evaluations (one of f_L and one of f_H). The author of TRM points out that zL and zH are highly unlikely to have reached a fixed point by then, so the equilibrium-based justification for backpropagating only at the end no longer holds.

TRM solves this in two ways: 1) backpropagating through all n recursion steps of latent_recursion, not just the last two function evaluations; and 2) calling latent_recursion T-1 times with no gradients, updating zL and zH each time, then calling latent_recursion one final time with gradients.

Backpropagating through all n steps of latent_recursion removes the need for a 1-step gradient approximation, and the T-1 extra calls to latent_recursion inside deep_recursion let zH and zL be improved recursively without backpropagation.

So latent_recursion is a module you can deploy with/without gradients attached to iteratively refine an answer. Call it as many times as you want without gradients to refine your answer. Then as soon as you want to update weights, attach gradients.
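Here is a pure-Python mock of that idea, not PyTorch. A global flag stands in for PyTorch's grad mode, a hypothetical no_grad() context manager imitates torch.no_grad(), and a toy net records its inputs on a "tape" only when gradients are enabled. The latent_recursion and deep_recursion bodies follow Figure 1, minus the output/Q heads and .detach().

```python
from contextlib import contextmanager

GRAD_ENABLED = True   # stand-in for PyTorch's global grad mode
TAPE = []             # stand-in for the activations autograd saves for backward
CALLS = 0             # total network evaluations

@contextmanager
def no_grad():
    # Mimics torch.no_grad(): stop recording inside the block, restore afterwards
    global GRAD_ENABLED
    prev = GRAD_ENABLED
    GRAD_ENABLED = False
    try:
        yield
    finally:
        GRAD_ENABLED = prev

def net(*inputs):
    # Hypothetical network: sums its inputs; records to the tape only when grads are on
    global CALLS
    CALLS += 1
    if GRAD_ENABLED:
        TAPE.append(inputs)
    return sum(inputs)

def latent_recursion(x, y, z, n=6):
    for _ in range(n):
        z = net(x, y, z)   # refine latent reasoning (x included)
    y = net(y, z)          # refine answer (x excluded)
    return y, z

def deep_recursion(x, y, z, n=6, T=3):
    with no_grad():
        for _ in range(T - 1):
            y, z = latent_recursion(x, y, z, n)   # cheap refinement, nothing recorded
    return latent_recursion(x, y, z, n)           # final pass, recorded for backprop

y, z = deep_recursion(x=1.0, y=0.0, z=0.0)
print(CALLS, len(TAPE))
```

With n=6 and T=3 the network runs 21 times, but only the 7 calls from the final pass end up on the tape. That is the whole trick: refinement is cheap, and memory is paid only for the last unrolled pass.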

ACT: no additional forward pass

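TRM removes the additional forward pass by learning only a halting probability: the halting head is trained with BCE to predict whether the current answer is correct. This drops the Q-learning continue target, and the extra pass needed to compute it, altogether.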
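A minimal sketch of that halting loss in plain Python. I'm assuming the target is 1 only when every token of the argmax prediction matches the label, which is my reading of `(y_hat == y_true)` in Figure 1 (y_hat holds logits, so it needs an argmax first). The helper names are hypothetical.

```python
import math

def bce_with_logits(q, target):
    # Numerically stable binary cross-entropy on a raw logit q (target in {0, 1})
    return max(q, 0.0) - q * target + math.log1p(math.exp(-abs(q)))

def halt_target(pred_tokens, true_tokens):
    # Assumed target: 1.0 only if every predicted token matches the label
    return 1.0 if all(p == t for p, t in zip(pred_tokens, true_tokens)) else 0.0

def halting_loss(q_hat, pred_tokens, true_tokens):
    # One logit per sample, one BCE term; no second forward pass is needed
    return bce_with_logits(q_hat, halt_target(pred_tokens, true_tokens))

print(halting_loss(2.0, [3, 1, 4], [3, 1, 4]), halting_loss(2.0, [3, 1, 5], [3, 1, 4]))
```

The stable form avoids computing log(sigmoid(q)) directly, which overflows or underflows for large |q|; PyTorch's built-in BCE-with-logits loss does the same job for tensors.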

No hierarchical features + single network

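TRM reinterprets zH as y (the current answer) and zL as z (the latent reasoning), removing the need to create two latent embeddings from separate networks (f_H and f_L) and reducing two networks down to a single network. The single network can tell which update it is performing from its inputs: z = net(x, y, z) includes the input x, while y = net(y, z) does not.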

Attention free architecture

For tasks where L ≤ D (sequence length no larger than the embedding dimension), like Sudoku-Extreme, TRM uses an MLP-based architecture instead of self-attention and improves performance from 74.7% to 87.4%.

Also, a quick implementation detail: an MLP-mixer-style token mixer, which could (I'm inferring, not quoting the paper) capture global context across tokens, would look something like this:

import torch.nn as nn
mix = nn.Linear(L, L)            # create once (e.g. in __init__), not on every forward pass
x_t = x.transpose(1, 2)          # (B, D, L) - tokens are now the last dim
x_t = mix(x_t)                   # mixes across tokens, independently per channel
x   = x_t.transpose(1, 2)        # back to (B, L, D)

TRM mentions that the MLP works for Sudoku because L ≤ D (L = 81), but produces suboptimal results on Maze-Hard, where L = 900 exceeds D. The reasoning: 900 tokens is very large compared to the small dataset size. For an MLP-mixer, the token-mixing weights form an L×L matrix, so the number of parameters grows with L², which can hurt the model's ability to generalize. TRM uses self-attention to mix tokens for large L because the number of parameters in self-attention scales with D² instead of L².
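A quick back-of-the-envelope check of the scaling argument, as a plain-Python sketch. I'm assuming the token mixer is a single nn.Linear(L, L) as in the snippet above, and counting a standard self-attention block as four nn.Linear(D, D) projections with bias; both are simplifications, not TRM's exact layer layout.

```python
def token_mixer_params(L):
    # nn.Linear(L, L) over the token axis: L*L weights + L biases
    return L * L + L

def self_attention_params(D):
    # Q, K, V, and output projections, each nn.Linear(D, D) with bias.
    # Note: no dependence on the sequence length L at all.
    return 4 * (D * D + D)

sudoku_L, maze_L = 81, 900   # 9x9 and 30x30 grids, flattened
growth = token_mixer_params(maze_L) / token_mixer_params(sudoku_L)
print(f"token-mixer params: {token_mixer_params(sudoku_L)} -> {token_mixer_params(maze_L)} (~{growth:.0f}x)")
```

The mixer grows by roughly 122x going from Sudoku to Maze, while self_attention_params doesn't even take L as an argument.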

Remaining improvements

  • Fewer layers: adding more layers degrades performance on Sudoku-Extreme, so TRM keeps its network tiny (2 layers).
  • EMA of weights: To mitigate the tendency of HRM to overfit on small datasets, TRM uses Exponential Moving Average (EMA) of the weights.
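A minimal EMA sketch over a flat list of parameters, in plain Python. The function name and the default decay of 0.999 are my assumptions, not values confirmed from the paper; in practice you'd keep the EMA copy next to the trained weights and typically evaluate with the EMA copy.

```python
def ema_update(ema_params, params, decay=0.999):
    # ema <- decay * ema + (1 - decay) * current, parameter by parameter
    return [decay * e + (1.0 - decay) * p for e, p in zip(ema_params, params)]

# Toy training loop: the raw weights jump every step, the EMA copy trails behind smoothly
weights = [0.0, 0.0]
ema = list(weights)
for step in range(3):
    weights = [w + 1.0 for w in weights]        # stand-in for an optimizer step
    ema = ema_update(ema, weights, decay=0.5)   # small decay so the effect is visible
print(weights, ema)
```

A decay close to 1 averages over roughly 1 / (1 - decay) recent steps, which damps the step-to-step noise that small datasets amplify.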

OOM Issue

The author notes time complexity and memory issues with TRM:

  1. Increasing T or n leads to "massive" slowdowns, because the recursion is nested inside deep supervision.
  2. Increasing n can lead to OOM errors, because TRM backpropagates through the full recursion graph of the final pass (see Figure 1).
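To see why n and T hit compute and memory differently, here is a small cost model derived from Figure 1. It counts network calls only (ignoring layers, batch size, and sequence length), and N_SUP = 16 is my assumption for illustration.

```python
def net_calls_per_sample(n, T, n_sup):
    # latent_recursion: n + 1 calls; deep_recursion: T of those; deep supervision: up to n_sup steps
    return n_sup * T * (n + 1)

def calls_kept_for_backward(n):
    # Only the final latent_recursion pass runs with gradients, so T doesn't appear
    return n + 1

N_SUP = 16  # assumed number of supervision steps, for illustration only
for n, T in [(6, 3), (12, 3), (6, 6)]:
    print(n, T, net_calls_per_sample(n, T, N_SUP), calls_kept_for_backward(n))
```

Doubling either n or T roughly doubles compute, but only doubling n grows the part of the graph held in memory for backprop, which matches where the author sees OOMs.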

Results

Improved

  • SOTA test accuracy on Sudoku-Extreme and Maze-Hard
  • ARC-AGI-1 from
  • ARC-AGI-2 from

My notes on the results:

  • Small datasets → overfitting pressure → why EMA & data augmentation matter
  • Maze/ARC need global context → why attention may dominate there
  • Sudoku has structure that MLP can exploit → why attention might be unnecessary for Sudoku

My Approach

If I were implementing TRM this week

  • Start with Figure 1 exactly: verify detach/no_grad boundaries before touching architecture.
  • Treat halting head output as a logit; threshold at 0 (0.5 prob).
  • Expect OOM as you increase n; consider gradient checkpointing or smaller batches.
  • Expect diminishing returns for bigger MLP depth on Sudoku-Extreme; keep it shallow.
  • Use EMA from day 1; small data will overfit fast without it.
  • For Maze/ARC, don't assume attention-free will hold; test attention early.

Ablations I'd run first (predictions)

These are the ablations I would run to get a feel for the model:

  • Remove deep supervision (N_sup=1) → big drop
  • Vary n vs T → which saturates first?
  • Remove EMA → instability / worse generalization
  • Swap MLP ↔ attention across Sudoku vs Maze/ARC → attention helps on larger-context tasks