vllm.model_executor.models.hrm_text
HRM-Text: Hierarchical Reasoning Model — Text variant.
Reference Hugging Face implementation: src/transformers/models/hrm_text/modeling_hrm_text.py
The model performs a hierarchical recurrent forward over two transformer stacks (H slow, L fast) inside nested loops. Each recurrence step gets its own KV cache slot via a unique vLLM-visible layer index. The PrefixLM attention pattern (prompt bidirectional, response causal) is realized by reusing EncoderOnlyAttention (which sets causal=False unconditionally on every metadata build) but with attn_type=DECODER so the KV cache is allocated; see HrmTextAttention for usage.
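Why non-causal attention still yields a causal response can be checked with a small mask sketch. It assumes the whole prompt is prefilled in a single pass and each response token is decoded alone against the KV cache; chunked prefill and multi-token decode are not modeled. Both function names are hypothetical and not part of vLLM.

```python
from typing import List


def prefix_lm_allowed(i: int, j: int, prefix_len: int) -> bool:
    """Reference PrefixLM rule: may query position i attend to key position j?"""
    # Every query sees the whole prompt (bidirectional prefix); beyond the
    # prompt, a key is visible only if it is not in the query's future.
    return j < prefix_len or j <= i


def effective_mask(prefix_len: int, total_len: int) -> List[List[bool]]:
    """Mask produced by causal=False attention under prefill-then-decode.

    Assumes the whole prompt is prefilled in one pass and every response
    token is decoded alone against the KV cache.
    """
    mask = [[False] * total_len for _ in range(total_len)]
    # Prefill: one non-causal pass, so each prompt token sees every prompt token.
    for i in range(prefix_len):
        for j in range(prefix_len):
            mask[i][j] = True
    # Decode: the cache only holds positions 0..i when token i is processed,
    # so non-causal attention over it is causal by construction.
    for i in range(prefix_len, total_len):
        for j in range(i + 1):
            mask[i][j] = True
    return mask


prompt_len, seq_len = 3, 5
reference = [[prefix_lm_allowed(i, j, prompt_len) for j in range(seq_len)]
             for i in range(seq_len)]
print(effective_mask(prompt_len, seq_len) == reference)  # True
```

The point of the design is that no per-request mask is needed: causality on the response side falls out of the decode schedule rather than out of the attention kernel.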
The on-disk attn.gqkv_proj.weight (rows concatenated as [gate | q | k | v]) is loaded by a single MergedColumnParallelLinear with four equal-sized output partitions; its weight loader auto-splits the fused tensor along the output dim by output_sizes (the same path used by Phi-3's fused gate_up_proj).
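The split by output_sizes amounts to slicing consecutive row blocks along the output dimension. The following plain-Python sketch uses lists of rows as a stand-in for the on-disk tensor. The helper split_fused_rows is hypothetical; inside vLLM the equivalent slicing happens on tensors within the linear layer's weight loader.

```python
from typing import List, Sequence

Row = List[float]


def split_fused_rows(fused: Sequence[Row],
                     output_sizes: Sequence[int]) -> List[List[Row]]:
    """Split a row-concatenated weight into consecutive output partitions."""
    if sum(output_sizes) != len(fused):
        raise ValueError("output_sizes must sum to the number of fused rows")
    parts: List[List[Row]] = []
    offset = 0
    for size in output_sizes:
        # Each partition is the next `size` rows along the output dim (dim=0).
        parts.append([list(row) for row in fused[offset:offset + size]])
        offset += size
    return parts


# Toy fused weight: 4 partitions of 2 rows each, [gate | q | k | v].
fused = [[float(r), float(r) + 0.5] for r in range(8)]
gate, q, k, v = split_fused_rows(fused, [2, 2, 2, 2])
print(q)  # [[2.0, 2.5], [3.0, 3.5]]
```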
Classes:
- HrmTextAttention – One self-attention block; weights shared across recurrence steps.
- HrmTextForCausalLM – Hierarchical Reasoning Model — Text variant, causal LM.
- HrmTextModel – Hierarchical recurrent transformer body.
- HrmTextStack – A single transformer stack — used twice (H and L).
HrmTextAttention
Bases: Module
One self-attention block; weights shared across recurrence steps.
HF transformers writes a single fused attn.gqkv_proj.weight on disk (per transformers/conversion_mapping.py "hrm_text" mapping; rows are concatenated as [gate | q | k | v] along dim=0). We mirror that on the model side with a single MergedColumnParallelLinear whose four equal output partitions are sharded along the head axis under TP; its weight loader auto-splits the fused tensor (same path used by Phi-3's fused gate_up_proj). HF's runtime config currently hardcodes MHA (num_key_value_groups=1); GQA would require QKVParallelLinear semantics for q/k/v shard replication and is left for a follow-up if/when HF adds it.
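Sharding "along the head axis" can be illustrated by listing which fused rows a tensor-parallel rank keeps. This sketch assumes that rows inside each partition are head-major (head h occupies rows h*head_dim to (h+1)*head_dim - 1) and that num_heads divides evenly by tp_size, so each rank's contiguous slice of each partition is a whole number of heads. The helper local_shard_rows is hypothetical.

```python
from typing import List, Sequence


def local_shard_rows(output_sizes: Sequence[int], tp_rank: int,
                     tp_size: int) -> List[int]:
    """Global fused-row indices held by one TP rank, in local storage order."""
    rows: List[int] = []
    offset = 0
    for size in output_sizes:
        if size % tp_size:
            raise ValueError("each partition must divide evenly across TP ranks")
        shard = size // tp_size
        # This rank owns the tp_rank-th contiguous slice of every partition.
        start = offset + tp_rank * shard
        rows.extend(range(start, start + shard))
        offset += size
    return rows


num_heads, head_dim, tp_size = 4, 2, 2
sizes = [num_heads * head_dim] * 4  # gate, q, k, v
per_rank = sizes[1] // tp_size
rank0 = local_shard_rows(sizes, tp_rank=0, tp_size=tp_size)
# Within the q partition (rows 8..15), rank 0 keeps rows 8..11: heads 0 and 1.
print(rank0[per_rank:2 * per_rank])  # [8, 9, 10, 11]
```

Each rank therefore stores a local [gate | q | k | v] block of its own heads, which is why MHA shards cleanly; with GQA, KV heads may need replicating across ranks, which is the QKVParallelLinear case mentioned above.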
Holds:
- parameters: gqkv_proj, o_proj, rotary_emb (shared across cycles).
- attn_per_step: an nn.ModuleDict keyed by recurrence step (as a string), each value an EncoderOnlyAttention (with attn_type=DECODER so the KV cache is allocated; the EncoderOnlyAttention wrapper sets causal=False on every metadata build). The L stack steps are [high_cycle_idx*(L_cycles+1)+low_cycle_idx] and the H stack steps are [high_cycle_idx*(L_cycles+1)+L_cycles]; the two ranges are disjoint, so each instance registers a unique vLLM layer_name (model.{H,L}_module.layers.{global_idx}.self_attn) and gets its own KV cache slot. The global layer index per recurrence step is step * num_layers_per_stack + layer_idx_in_stack, matching the HF transformers cycle_offset formula in modeling_hrm_text.py.
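The step and layer-name bookkeeping above can be enumerated directly to confirm that no two attention instances collide. This is a sketch that reimplements only the formulas quoted in the docstring; the function names recurrence_steps and attn_layer_names are hypothetical.

```python
from typing import List, Tuple


def recurrence_steps(h_cycles: int, l_cycles: int) -> Tuple[List[int], List[int]]:
    """Return (L-stack steps, H-stack steps) in execution order."""
    l_steps: List[int] = []
    h_steps: List[int] = []
    for high_cycle_idx in range(h_cycles):
        for low_cycle_idx in range(l_cycles):
            l_steps.append(high_cycle_idx * (l_cycles + 1) + low_cycle_idx)
        h_steps.append(high_cycle_idx * (l_cycles + 1) + l_cycles)
    return l_steps, h_steps


def attn_layer_names(stack: str, steps: List[int],
                     num_layers_per_stack: int) -> List[str]:
    """One vLLM layer_name per (step, layer) pair of a stack."""
    return [
        f"model.{stack}_module.layers."
        f"{step * num_layers_per_stack + layer_idx}.self_attn"
        for step in steps
        for layer_idx in range(num_layers_per_stack)
    ]


l_steps, h_steps = recurrence_steps(h_cycles=2, l_cycles=2)
print(l_steps, h_steps)  # [0, 1, 3, 4] [2, 5]
print(attn_layer_names("H", h_steps[:1], num_layers_per_stack=2))
# ['model.H_module.layers.4.self_attn', 'model.H_module.layers.5.self_attn']
```

Because the step ranges of the two stacks are disjoint, the global indices are also disjoint, so the H and L stacks together cover H_cycles * (L_cycles + 1) * num_layers_per_stack distinct KV cache slots.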
Source code in vllm/model_executor/models/hrm_text.py
HrmTextForCausalLM
Bases: Module
Hierarchical Reasoning Model — Text variant, causal LM.
Reference: src/transformers/models/hrm_text/modeling_hrm_text.py
Source code in vllm/model_executor/models/hrm_text.py
HrmTextModel
Bases: Module
Hierarchical recurrent transformer body.
Forward (matches HF main exactly, src/transformers/models/hrm_text/modeling_hrm_text.py:495-547):
    hidden_states_high_cycle = embed(input_ids) * embedding_scale
    hidden_states_low_cycle = z_L_init.expand_as(hidden_states_high_cycle)
    for high_cycle_idx in range(H_cycles):
        for low_cycle_idx in range(L_cycles):
            step = high_cycle_idx * (L_cycles + 1) + low_cycle_idx
            hidden_states_low_cycle = L_module(
                hidden_states_low_cycle + hidden_states_high_cycle,
                current_step=step,
            )
        step = high_cycle_idx * (L_cycles + 1) + L_cycles
        hidden_states_high_cycle = H_module(
            hidden_states_high_cycle + hidden_states_low_cycle,
            current_step=step,
        )
    return hidden_states_high_cycle
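To see the data flow and step ordering of this loop in isolation, here is a runnable sketch on scalars. Floats stand in for hidden-state tensors, the embedding lookup and embedding_scale are assumed already applied to x_high, and z_L_init.expand_as(...) is replaced by a scalar. The names hrm_forward and recorder are hypothetical.

```python
from typing import Callable, List, Tuple

# Hypothetical stand-in for a stack: (hidden_state, current_step) -> hidden_state.
Stack = Callable[..., float]


def hrm_forward(x_high: float, z_l_init: float, h_cycles: int, l_cycles: int,
                l_module: Stack, h_module: Stack) -> float:
    """Nested H/L recurrence from the HrmTextModel docstring, on scalars."""
    x_low = z_l_init  # stands in for z_L_init.expand_as(...)
    for high_cycle_idx in range(h_cycles):
        for low_cycle_idx in range(l_cycles):
            step = high_cycle_idx * (l_cycles + 1) + low_cycle_idx
            x_low = l_module(x_low + x_high, current_step=step)
        step = high_cycle_idx * (l_cycles + 1) + l_cycles
        x_high = h_module(x_high + x_low, current_step=step)
    return x_high


def recorder(name: str, trace: List[Tuple[str, int]]) -> Stack:
    """Identity stack that logs which stack ran at which step."""
    def module(x: float, current_step: int) -> float:
        trace.append((name, current_step))
        return x
    return module


trace: List[Tuple[str, int]] = []
out = hrm_forward(1.0, 0.0, h_cycles=2, l_cycles=1,
                  l_module=recorder("L", trace), h_module=recorder("H", trace))
print(out, trace)  # 5.0 [('L', 0), ('H', 1), ('L', 2), ('H', 3)]
```

The trace shows that steps are consumed in strictly increasing order across both stacks, one step per stack call.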
Source code in vllm/model_executor/models/hrm_text.py
HrmTextStack
Bases: Module
A single transformer stack — used twice (H and L).