vllm_gaudi.ops.pytorch_implementation
new_chunk_cumsum
new_chunk_cumsum(
dt,
A,
chunk_size,
dt_bias=None,
dt_softplus=False,
dt_limit=(0.0, float("inf")),
padding_mask=None,
)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dt` | Tensor | (seqlen, nheads) | required |
| `A` | Tensor | (nheads) | required |
| `chunk_size` | int | | required |
| `dt_bias` | Optional Tensor | (nheads) | `None` |
| `dt_softplus` | bool | | `False` |
| `dt_limit` | tuple | (min: float, max: float) | `(0.0, float('inf'))` |
| `padding_mask` | Optional Tensor | (seqlen, 1) or (seqlen,) | `None` |

Returns:

- `dA_cumsum`: Tensor - (nheads, nchunks, chunk_size)
- `dt_out`: Tensor - (nheads, nchunks, chunk_size)
Source code in vllm_gaudi/ops/pytorch_implementation.py
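The following is a minimal PyTorch sketch of what a chunked `dt` cumulative sum produces, assuming `new_chunk_cumsum` follows the Mamba-2 SSD reference semantics (bias, optional softplus, optional clamp, then an inclusive per-chunk cumsum of `dt * A`). It is not the vllm_gaudi source: `chunk_cumsum_sketch` is a hypothetical name, and the handling of `padding_mask` (padded tokens get `dt = 0`) is an assumption.

```python
import math
import torch
import torch.nn.functional as F

def chunk_cumsum_sketch(dt, A, chunk_size, dt_bias=None, dt_softplus=False,
                        dt_limit=(0.0, float("inf")), padding_mask=None):
    """Hypothetical reference sketch, not the vllm_gaudi implementation."""
    seqlen, nheads = dt.shape
    dt = dt.float()
    if dt_bias is not None:
        dt = dt + dt_bias.float()
    if dt_softplus:
        # PyTorch's default threshold=20 returns the input unchanged for large values
        dt = F.softplus(dt)
    if tuple(dt_limit) != (0.0, float("inf")):
        # Only clamp when a non-default limit is given (as in the mamba_ssm kernels)
        dt = dt.clamp(min=dt_limit[0], max=dt_limit[1])
    if padding_mask is not None:
        # Assumption: padded tokens get dt = 0, so they add no input and no decay
        dt = dt * padding_mask.reshape(seqlen, 1).to(dt.dtype)
    nchunks = math.ceil(seqlen / chunk_size)
    # Zero-pad the sequence dimension up to a multiple of chunk_size
    dt = F.pad(dt, (0, 0, 0, nchunks * chunk_size - seqlen))
    dt_out = dt.reshape(nchunks, chunk_size, nheads).permute(2, 0, 1)  # (nheads, nchunks, chunk_size)
    # Inclusive cumulative sum of dt * A inside each chunk
    dA_cumsum = torch.cumsum(dt_out * A.float()[:, None, None], dim=-1)
    return dA_cumsum, dt_out
```

Because `dt` is zero-padded, the padded tail of the last chunk contributes nothing to `dA_cumsum`, so its last element still equals the total log-decay of the real tokens in that chunk.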
new_chunk_scan
new_chunk_scan(
cb,
x_chunked,
dt_t,
dA_cumsum_t,
C,
states,
output,
D=None,
z=None,
initial_states=None,
)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `cb` | Tensor | (nchunks, ngroups, chunk_size, chunk_size), already causally masked | required |
| `x_chunked` | Tensor | Pre-chunked x, (nchunks, chunk_size, nheads, hdim) | required |
| `dt_t` | Tensor (float32) | Pre-transposed dt, (nchunks, nheads, chunk_size) | required |
| `dA_cumsum_t` | Tensor (float32) | Pre-transposed dA_cumsum, (nchunks, nheads, chunk_size) | required |
| `C` | Tensor | (seqlen, ngroups, dstate) | required |
| `states` | Tensor | (nchunks, nheads, hdim, dstate) | required |
| `output` | Tensor | (seqlen, nheads, hdim) | required |
| `D` | Optional Tensor | (nheads, hdim) or (nheads) | `None` |
| `z` | Optional Tensor | (seqlen, nheads, hdim) | `None` |
| `initial_states` | Optional Tensor | (1, nheads, hdim, dstate) | `None` |

Returns:

- `output`: Tensor - (seqlen, nheads, hdim)
Source code in vllm_gaudi/ops/pytorch_implementation.py
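Below is a minimal PyTorch sketch of a chunked scan, assuming Mamba-2 SSD semantics: each output is an inter-chunk term (the incoming state read out by `C` and decayed by `exp(dA_cumsum)`) plus an intra-chunk term (causal `cb` weighted by `exp(dA_l - dA_s) * dt_s`), then the optional `D` skip and `silu(z)` gate. `chunk_scan_sketch` is a hypothetical name; the convention that `states[c]` is the state at the end of chunk `c` (so chunk `c` starts from `states[c-1]`, and chunk 0 from `initial_states`) and the in-place write to `output` are assumptions, not taken from the vllm_gaudi source.

```python
import torch
import torch.nn.functional as F

def chunk_scan_sketch(cb, x_chunked, dt_t, dA_cumsum_t, C, states, output,
                      D=None, z=None, initial_states=None):
    """Hypothetical reference sketch, not the vllm_gaudi implementation."""
    nchunks, chunk_size, nheads, hdim = x_chunked.shape
    seqlen, ngroups, dstate = C.shape
    rep = nheads // ngroups  # heads per group
    # Assumption: states[c] is the state at the END of chunk c
    first = initial_states if initial_states is not None else states.new_zeros(1, nheads, hdim, dstate)
    prev = torch.cat([first.float(), states[:-1].float()], dim=0)  # (nchunks, nheads, hdim, dstate)
    # Pad C to whole chunks and broadcast groups to heads
    C_h = F.pad(C.float(), (0, 0, 0, 0, 0, nchunks * chunk_size - seqlen))
    C_h = C_h.reshape(nchunks, chunk_size, ngroups, dstate).repeat_interleave(rep, dim=2)
    cb_h = cb.float().repeat_interleave(rep, dim=1)  # (nchunks, nheads, L, L)
    x = x_chunked.float()
    # Inter-chunk term: C_l . prev_state, decayed from the chunk start to position l
    decay_in = torch.exp(dA_cumsum_t).permute(0, 2, 1)[..., None]  # (nchunks, L, nheads, 1)
    y_off = torch.einsum("clhn,chpn->clhp", C_h, prev) * decay_in
    # Intra-chunk term: weight cb[l, s] by exp(dA_l - dA_s) * dt_s for s <= l
    seg = dA_cumsum_t[..., :, None] - dA_cumsum_t[..., None, :]
    causal = torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=x.device).tril()
    w = cb_h * torch.exp(seg.masked_fill(~causal, float("-inf"))) * dt_t[..., None, :]
    y_diag = torch.einsum("chls,cshp->clhp", w, x)
    y = (y_off + y_diag).reshape(nchunks * chunk_size, nheads, hdim)[:seqlen]
    if D is not None:
        Dv = D.float() if D.dim() == 2 else D.float()[:, None]
        y = y + x.reshape(-1, nheads, hdim)[:seqlen] * Dv
    if z is not None:
        y = y * F.silu(z.float())  # gated output: y * z * sigmoid(z)
    output.copy_(y.to(output.dtype))
    return output
```

Splitting into `y_off` and `y_diag` is what lets the sequential recurrence be replaced by dense per-chunk matmuls; only the `states` hand-off between chunks carries information across chunk boundaries.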
new_chunk_state
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `B_expanded` | Tensor | Pre-expanded B, (nchunks, chunk_size, nheads, dstate) | required |
| `x_chunked` | Tensor | Pre-chunked x, (nchunks, chunk_size, nheads, hdim) | required |
| `dt_t` | Tensor (float32) | Pre-transposed dt, (nchunks, nheads, chunk_size) | required |
| `dA_cumsum_t` | Tensor (float32) | Pre-transposed dA_cumsum, (nchunks, nheads, chunk_size) | required |
| `states_in_fp32` | bool | | `True` |

Returns:

- `states`: Tensor - (nchunks, nheads, hdim, dstate)
Source code in vllm_gaudi/ops/pytorch_implementation.py
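A minimal PyTorch sketch of the per-chunk state, assuming Mamba-2 SSD semantics: each chunk's state is the sum over its positions `l` of `x_l B_l^T`, weighted by `dt_l` and by the decay `exp(dA_cumsum[-1] - dA_cumsum[l])` from position `l` to the end of the chunk. `chunk_state_sketch` is a hypothetical name, and the cast back to the input dtype when `states_in_fp32=False` is an assumption.

```python
import torch

def chunk_state_sketch(B_expanded, x_chunked, dt_t, dA_cumsum_t, states_in_fp32=True):
    """Hypothetical reference sketch, not the vllm_gaudi implementation."""
    # Decay from each position l to the last position of its chunk: (nchunks, nheads, chunk_size)
    decay = torch.exp(dA_cumsum_t[..., -1:] - dA_cumsum_t)
    weights = (decay * dt_t).permute(0, 2, 1)  # (nchunks, chunk_size, nheads)
    xw = x_chunked.float() * weights.unsqueeze(-1)  # (nchunks, chunk_size, nheads, hdim)
    # Sum of outer products x_l B_l^T over the chunk: (nchunks, nheads, hdim, dstate)
    states = torch.einsum("clhp,clhn->chpn", xw, B_expanded.float())
    # Assumption: when not kept in fp32, cast back to the input dtype
    return states if states_in_fp32 else states.to(x_chunked.dtype)
```

These states are local to each chunk (they assume a zero incoming state); combining them across chunks is the job of the state-passing step.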
new_ssd_bmm
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `a` | Tensor | (seqlen, ngroups, k) | required |
| `b` | Tensor | (seqlen, ngroups, k) | required |
| `chunk_size` | int | | required |
| `causal` | bool | | `False` |
| `out_dtype` | Optional dtype | | required |

Returns:

- `output`: Tensor - (nchunks, ngroups, chunk_size, chunk_size)
Source code in vllm_gaudi/ops/pytorch_implementation.py
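A minimal PyTorch sketch of a chunked batched matmul: within each chunk and group, `out[i, j] = a_i . b_j`, which is how the `cb` input of the chunk scan is formed from `C` and `B`. `ssd_bmm_sketch` is a hypothetical name; zeroing the strictly upper triangle when `causal=True` and defaulting to `a.dtype` when `out_dtype` is `None` are assumptions about this function's behaviour.

```python
import math
import torch
import torch.nn.functional as F

def ssd_bmm_sketch(a, b, chunk_size, causal=False, out_dtype=None):
    """Hypothetical reference sketch, not the vllm_gaudi implementation."""
    seqlen, ngroups, k = a.shape
    nchunks = math.ceil(seqlen / chunk_size)
    pad = nchunks * chunk_size - seqlen
    # Zero-pad the sequence dimension and split it into chunks
    a_c = F.pad(a, (0, 0, 0, 0, 0, pad)).reshape(nchunks, chunk_size, ngroups, k)
    b_c = F.pad(b, (0, 0, 0, 0, 0, pad)).reshape(nchunks, chunk_size, ngroups, k)
    # out[c, g, i, j] = sum_k a[c, i, g, k] * b[c, j, g, k]
    out = torch.einsum("cigk,cjgk->cgij", a_c.float(), b_c.float())
    if causal:
        # Assumption: entries with j > i are set to zero
        mask = torch.ones(chunk_size, chunk_size, dtype=torch.bool, device=out.device).tril()
        out = out.masked_fill(~mask, 0.0)
    return out.to(out_dtype if out_dtype is not None else a.dtype)
```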
new_ssd_state_passing
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `states` | Tensor | (nchunks, nheads, hdim) | required |
| `dA_cumsum` | Tensor | (nheads, nchunks, chunk_size) | required |
| `initial_states` | Optional Tensor | (1, nheads, hdim) | `None` |
| `out_dtype` | Optional dtype | | `None` |

Returns:

- `output`: Tensor - (nchunks, nheads, hdim)
Note
This implementation computes the inter-chunk recurrence as a parallel prefix sum: it builds a full (nheads, nchunks+1, nchunks+1) decay matrix and applies it with a single batched matmul. This trades O(nheads · nchunks²) memory for O(1) sequential depth (fully parallel), which is intentional for performance. For very large nchunks the decay matrix can dominate memory; in that case, consider splitting the sequence into smaller segments and processing them sequentially, carrying the state from one segment to the next.
Source code in vllm_gaudi/ops/pytorch_implementation.py
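The decay-matrix trick described in the note can be sketched as follows. With `P_k` the prefix sum of per-chunk log-decays (`dA_cumsum[:, c, -1]`), the state after chunk `k` is `S_k = sum_{j<=k} exp(P_k - P_j) * u_j`, where `u_0` is the initial state and `u_j = states[j-1]`; this is a lower-triangular matrix applied with `torch.bmm`. Both function names are hypothetical, and returning the states at the end of each chunk (dropping the initial state) is an assumption about the output convention, not taken from the vllm_gaudi source. The sequential loop is included only to show that the two formulations agree.

```python
import torch
import torch.nn.functional as F

def state_passing_parallel(states, dA_cumsum, initial_states=None, out_dtype=None):
    """Hypothetical sketch of the parallel (decay-matrix) formulation."""
    nchunks, nheads, hdim = states.shape
    chunk_decay = dA_cumsum[:, :, -1].float()  # (nheads, nchunks): log-decay across each chunk
    prefix = F.pad(torch.cumsum(chunk_decay, dim=-1), (1, 0))  # (nheads, nchunks + 1), prefix[:, 0] = 0
    diff = prefix[:, :, None] - prefix[:, None, :]  # diff[h, k, j] = P_k - P_j
    causal = torch.ones(nchunks + 1, nchunks + 1, dtype=torch.bool, device=states.device).tril()
    # Mask before exp so the (positive) upper-triangle exponents cannot overflow
    decay = torch.exp(diff.masked_fill(~causal, float("-inf")))  # (nheads, nchunks + 1, nchunks + 1)
    if initial_states is None:
        initial_states = states.new_zeros(1, nheads, hdim)
    inputs = torch.cat([initial_states.float(), states.float()], dim=0)  # (nchunks + 1, nheads, hdim)
    all_states = torch.bmm(decay, inputs.transpose(0, 1))  # (nheads, nchunks + 1, hdim)
    # Assumption: output[c] is the state after chunk c, so the initial state is dropped
    out = all_states.transpose(0, 1)[1:]
    return out.to(out_dtype if out_dtype is not None else states.dtype)

def state_passing_sequential(states, dA_cumsum, initial_states=None):
    """Hypothetical O(nchunks)-depth loop computing the same recurrence."""
    nchunks, nheads, hdim = states.shape
    s = initial_states[0].float() if initial_states is not None else torch.zeros(nheads, hdim)
    out = []
    for c in range(nchunks):
        s = torch.exp(dA_cumsum[:, c, -1].float())[:, None] * s + states[c].float()
        out.append(s)
    return torch.stack(out)
```

The loop needs only O(nheads · hdim) extra memory but runs nchunks dependent steps; the matmul version runs in one step at the cost of the (nchunks+1)² matrix per head.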
selective_state_update_ref
selective_state_update_ref(
state,
x,
dt,
A,
B,
C,
D=None,
z=None,
dt_bias=None,
dt_softplus=False,
softplus_thres=20.0,
)
Arguments:

- `state`: (batch, dim, dstate) or (batch, nheads, dim, dstate)
- `x`: (batch, dim) or (batch, nheads, dim)
- `dt`: (batch, dim) or (batch, nheads, dim)
- `A`: (dim, dstate) or (nheads, dim, dstate)
- `B`: (batch, dstate) or (batch, ngroups, dstate)
- `C`: (batch, dstate) or (batch, ngroups, dstate)
- `D`: (dim,) or (nheads, dim)
- `z`: (batch, dim) or (batch, nheads, dim)
- `dt_bias`: (dim,) or (nheads, dim)

Returns:

- `out`: (batch, dim) or (batch, nheads, dim)
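A minimal PyTorch sketch of a single-token (decode) selective state update, assuming the semantics of the `mamba_ssm` reference: `state <- exp(dt * A) * state + dt * B * x`, then `out = C . state + D * x`, optionally gated by `silu(z)`. `selective_state_update_sketch` is a hypothetical name; updating `state` in place and passing `softplus_thres` as the `threshold` of `F.softplus` are assumptions, and mixed-dtype handling is omitted.

```python
import torch
import torch.nn.functional as F

def selective_state_update_sketch(state, x, dt, A, B, C, D=None, z=None,
                                  dt_bias=None, dt_softplus=False, softplus_thres=20.0):
    """Hypothetical reference sketch, not the vllm_gaudi implementation."""
    has_heads = state.dim() == 4
    if not has_heads:
        # Promote the headless layout to nheads = 1 (unsqueeze returns views, so state stays in-place)
        state, x, dt, A = state.unsqueeze(1), x.unsqueeze(1), dt.unsqueeze(1), A.unsqueeze(0)
        B, C = B.unsqueeze(1), C.unsqueeze(1)
        D = D.unsqueeze(0) if D is not None else None
        z = z.unsqueeze(1) if z is not None else None
        dt_bias = dt_bias.unsqueeze(0) if dt_bias is not None else None
    nheads = state.shape[1]
    if dt_bias is not None:
        dt = dt + dt_bias
    if dt_softplus:
        # Assumption: above softplus_thres, dt is passed through unchanged
        dt = F.softplus(dt, threshold=softplus_thres)
    rep = nheads // B.shape[1]  # heads per group
    B = B.repeat_interleave(rep, dim=1)  # (batch, nheads, dstate)
    C = C.repeat_interleave(rep, dim=1)
    dA = torch.exp(dt.unsqueeze(-1) * A)  # (batch, nheads, dim, dstate)
    dBx = dt.unsqueeze(-1) * B.unsqueeze(2) * x.unsqueeze(-1)  # (batch, nheads, dim, dstate)
    state.copy_(state * dA + dBx)  # in-place recurrent update
    out = torch.einsum("bhdn,bhn->bhd", state, C)
    if D is not None:
        out = out + x * D
    if z is not None:
        out = out * F.silu(z)
    return out if has_heads else out.squeeze(1)
```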