class HPUBucketingManager():
_instance = None
prompt_buckets: List[Tuple[int, int, int]] = []
decode_buckets: List[Tuple[int, int, int]] = []
# Seed buckets are the buckets originally generated from bucketing configuration
# Spec decode may automatically add new buckets based on the seed buckets
seed_decode_buckets: List[Tuple[int, int, int]] = None
initialized = False
def __new__(cls, *args, **kwargs):
if not cls._instance:
cls._instance = super(HPUBucketingManager, cls).__new__(cls)
return cls._instance
def initialize(self,
max_num_seqs,
max_num_prefill_seqs,
block_size,
max_num_batched_tokens,
max_model_len,
num_speculative_tokens=0,
mamba_chunk_size=0,
mamba_chunk_size_is_explicit=False):
self.max_num_seqs = max_num_seqs
self.max_num_prefill_seqs = max_num_prefill_seqs
self.block_size = block_size
self.max_num_batched_tokens = max_num_batched_tokens
self.num_hpu_blocks = None
self._fallback_max_ctx = 0
self.max_model_len = max_model_len
self.num_speculative_tokens = num_speculative_tokens
self.mamba_chunk_size = mamba_chunk_size
self.mamba_chunk_size_is_explicit = mamba_chunk_size_is_explicit
self.initialized = True
self.fallback_bs_base_step = 2
self.fallback_seq_base_step = max(32, mamba_chunk_size)
self.fallback_blocks_base_step = 32
if mamba_chunk_size > 0 and self.max_num_batched_tokens % mamba_chunk_size != 0:
raise ValueError(f"max_num_batched_tokens ({self.max_num_batched_tokens}) must be "
f"divisible by mamba_chunk_size ({mamba_chunk_size})")
self.use_sliding_window = get_config().PT_HPU_SDPA_QKV_SLICE_MODE_FWD
if self.use_sliding_window:
self.slice_size = get_config().PT_HPU_SDPA_BC_FACTOR if \
get_config().PT_HPU_SDPA_BC_FACTOR is not None else 1024
self.slice_thld = get_config().VLLM_FUSEDSDPA_SLIDE_THLD if \
get_config().VLLM_FUSEDSDPA_SLIDE_THLD is not None else 8192
msg = (
f"use_sliding_window {self.use_sliding_window}, slice_size {self.slice_size}, threshold {self.slice_thld}"
)
logger().info(msg)
### GENERATE BUCKETS FUNCTIONS ###
def read_from_file(self, is_prompt):
file_name = get_config().VLLM_BUCKETING_FROM_FILE
strategy = FileBucketingStrategy()
return strategy.get_buckets(file_name, is_prompt)
def get_bucketing_strategy(self):
# TODO - we can use different strategies for decode and prompt
bucketing_strategy = get_config().bucketing_strategy
if bucketing_strategy == 'exp':
strategy = ExponentialBucketingStrategy()
elif bucketing_strategy == 'lin':
strategy = LinearBucketingStrategy()
elif bucketing_strategy == 'pad':
strategy = PaddingAwareBucketingStrategy()
else:
raise ValueError(
f"Invalid bucketing strategy: {bucketing_strategy}, please choose from ['exp', 'lin', 'pad']")
# for backward compatibility - if VLLM_EXPONENTIAL_BUCKETING is set, it will override the bucketing strategy
exp_bucketing_env = os.getenv('VLLM_EXPONENTIAL_BUCKETING', None)
if exp_bucketing_env is not None:
logger().warning(
"VLLM_EXPONENTIAL_BUCKETING is deprecated and will be removed in a future release. Use VLLM_BUCKETING_STRATEGY='exp'|'lin'|'pad' instead."
)
use_exp_bucketing = boolean(exp_bucketing_env)
if use_exp_bucketing:
override_strategy = ExponentialBucketingStrategy()
else:
override_strategy = LinearBucketingStrategy()
if override_strategy.__class__ != strategy.__class__:
logger().warning(
f"Overriding bucketing strategy {strategy.__class__.__name__} with {override_strategy.__class__.__name__} due to VLLM_EXPONENTIAL_BUCKETING={exp_bucketing_env}"
)
strategy = override_strategy
return strategy
def generate_prompt_buckets(self):
if self.initialized:
buckets_from_file = None
bs_range = []
query_range = []
ctx_range = []
if get_config().VLLM_BUCKETING_FROM_FILE:
buckets_from_file = self.read_from_file(is_prompt=True)
else:
strategy = self.get_bucketing_strategy()
bs_cfg, query_cfg, ctx_cfg = strategy.get_prompt_cfgs(
max_num_prefill_seqs=self.max_num_prefill_seqs,
block_size=self.block_size,
max_num_batched_tokens=self.max_num_batched_tokens,
max_model_len=self.max_model_len)
bs_range = strategy.get_range(bs_cfg)
query_range = strategy.get_range(query_cfg)
ctx_range = strategy.get_range(ctx_cfg)
logger().debug(f"Prompt BS range: {bs_range}")
logger().debug(f"Prompt query range: {query_range}")
logger().debug(f"Prompt context range: {ctx_range}")
self.prompt_buckets = generate_buckets(bs_range, query_range, ctx_range, True, self.max_model_len,
self.max_num_seqs, self.max_num_prefill_seqs,
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
buckets_from_file, self.mamba_chunk_size,
self.mamba_chunk_size_is_explicit)
self.log_generate_info(True)
if self.use_sliding_window:
self.prompt_buckets = [
t for t in self.prompt_buckets
if t[2] != 0 or (t[2] == 0 and (t[1] < self.slice_thld or
(t[1] >= self.slice_thld and t[1] % self.slice_size == 0)))
]
self.log_generate_info(True)
else:
logger().info("Bucketing is off - skipping prompt buckets generation")
self.prompt_buckets = []
return
def generate_decode_buckets(self):
if self.initialized:
buckets_from_file = None
bs_range = []
query_range = []
ctx_range = []
if get_config().VLLM_BUCKETING_FROM_FILE:
buckets_from_file = self.read_from_file(is_prompt=False)
else:
strategy = self.get_bucketing_strategy()
bs_cfg, query_cfg, ctx_cfg = strategy.get_decode_cfgs(
max_num_seqs=self.max_num_seqs,
block_size=self.block_size,
max_num_batched_tokens=self.max_num_batched_tokens,
max_model_len=self.max_model_len,
max_blocks=self.num_hpu_blocks)
bs_range = strategy.get_range(bs_cfg)
query_range = strategy.get_range(query_cfg)
ctx_range = strategy.get_range(ctx_cfg)
if get_config().use_contiguous_pa and ctx_range[-1] < self.num_hpu_blocks:
ctx_range.append(self.num_hpu_blocks)
logger().debug(f"Decode BS range: {bs_range}")
logger().debug(f"Decode query range: {query_range}")
logger().debug(f"Decode context range: {ctx_range}")
self.decode_buckets = generate_buckets(bs_range, query_range, ctx_range, False, self.max_model_len,
self.max_num_seqs, self.max_num_prefill_seqs,
self.max_num_batched_tokens, self.block_size, self.num_hpu_blocks,
buckets_from_file, self.mamba_chunk_size,
self.mamba_chunk_size_is_explicit)
if self.num_speculative_tokens:
# The existing buckets are used as seed decode buckets
self.seed_decode_buckets = self.decode_buckets
# More buckets are added automatically for spec decode
self.decode_buckets = self.generate_spec_decode_buckets(self.decode_buckets)
# Safety cap for fallback: max ctx from ALL prepared decode buckets
# (including spec decode expansions). Prevents catastrophic
# allocations from corrupt batch data while allowing
# calc_fallback_value to handle moderate overflow.
self._fallback_max_ctx = max((ctx for _, _, ctx in self.decode_buckets), default=0)
self.log_generate_info(False)
else:
logger().info("Bucketing is off - skipping decode buckets generation")
self.decode_buckets = []
return
def log_generate_info(self, is_prompt=False):
phase = 'prompt' if is_prompt else 'decode'
buckets = self.prompt_buckets if is_prompt else self.decode_buckets
msg = (f"Generated {len(buckets)} "
f"{phase} buckets [bs, query, num_blocks]: "
f"{list(buckets)}")
logger().info(msg)
### RETRIEVE BUCKETS FUNCTIONS ###
def generate_fallback_bucket(self, batch_size, seq_len, ctx, is_prompt=False):
assert self.max_num_batched_tokens is not None
new_batch_size = calc_fallback_value(batch_size, self.fallback_bs_base_step)
if new_batch_size > self.max_num_seqs:
new_batch_size = self.max_num_seqs
if self.use_sliding_window and seq_len >= self.slice_thld:
new_seq_len = math.ceil(seq_len / self.slice_size) * self.slice_size
else:
new_seq_len = min(calc_fallback_value(seq_len, self.fallback_seq_base_step), self.max_num_batched_tokens)
if self.num_hpu_blocks is None:
new_ctx = 0
else:
new_ctx = calc_fallback_value(ctx, self.fallback_blocks_base_step)
if is_prompt:
# For prompt buckets, cap to the theoretical max number of
# context blocks derived from max_model_len.
max_prompt_ctx = math.ceil((self.max_model_len - 1) / self.block_size)
if new_ctx > max_prompt_ctx:
logger().warning(f"Fallback prompt ctx {new_ctx} exceeds theoretical "
f"max {max_prompt_ctx} (max_model_len="
f"{self.max_model_len}, block_size={self.block_size}),"
f" capping.")
new_ctx = max_prompt_ctx
else:
# Safety cap: limit to max prepared decode bucket ctx to prevent
# catastrophic graph compilation from corrupted batch data.
if self._fallback_max_ctx > 0 and new_ctx > self._fallback_max_ctx:
logger().warning(f"Fallback ctx {new_ctx} exceeds max prepared "
f"decode bucket ctx {self._fallback_max_ctx}, capping.")
new_ctx = self._fallback_max_ctx
return (new_batch_size, new_seq_len, new_ctx)
def find_prompt_bucket(self, batch_size, seq_len, ctx=0):
if self.initialized:
found_bucket = find_equal_or_closest_greater_config(self.prompt_buckets, (batch_size, seq_len, ctx))
if found_bucket is None:
new_bucket = self.generate_fallback_bucket(batch_size, seq_len, ctx, is_prompt=True)
logger().warning(f"Prompt bucket for {batch_size, seq_len, ctx}"
f" was not prepared. Adding new bucket: {new_bucket}")
self.prompt_buckets.append(new_bucket)
self.prompt_buckets.sort()
return new_bucket
return found_bucket
return (batch_size, seq_len, ctx)
def find_decode_bucket(self, batch_size, num_blocks, seed_buckets: bool = False):
if self.initialized:
# Cap num_blocks to the max prepared ctx so that previously-added
# capped fallback buckets are found on subsequent lookups, avoiding
# an infinite loop of "was not prepared" warnings.
if self._fallback_max_ctx > 0 and num_blocks > self._fallback_max_ctx:
num_blocks = self._fallback_max_ctx
if seed_buckets and self.seed_decode_buckets is not None:
found_bucket = find_equal_or_closest_greater_config(self.seed_decode_buckets,
(batch_size, 1, num_blocks))
if found_bucket is not None:
return found_bucket
found_bucket = find_equal_or_closest_greater_config(self.decode_buckets, (batch_size, 1, num_blocks))
if found_bucket is None:
new_bucket = self.generate_fallback_bucket(batch_size, 1, num_blocks)
if new_bucket not in self.decode_buckets:
logger().warning(f"Decode bucket for {batch_size, 1, num_blocks}"
f" was not prepared. Adding new bucket: {new_bucket}")
self.decode_buckets.append(new_bucket)
self.decode_buckets.sort()
return new_bucket
return found_bucket
return (batch_size, 1, num_blocks)
def get_max_prompt_shape(self):
return max(b[1] for b in self.prompt_buckets) \
if len(self.prompt_buckets) > 0 else self.max_model_len
def generate_spec_decode_buckets(self, seed_decode_buckets):
max_model_len = self.max_model_len
block_size = self.block_size
def no_corrections(bs, query, ctx):
return (bs, query, ctx)
def correct_for_max_model_len(bs, query, ctx):
return (bs, query, min(ctx, bs * math.ceil(max_model_len / block_size)))
def get_corrector(use_contiguous_pa):
if use_contiguous_pa:
return no_corrections
else:
return correct_for_max_model_len
use_contiguous_pa = get_config().use_contiguous_pa
corrector = get_corrector(use_contiguous_pa)
# If spec decode enabled, generate buckets for batch_size * (1 + num_speculative_tokens)
num_tokens = 1 + self.num_speculative_tokens
buckets = set()
for bucket in seed_decode_buckets:
buckets.add(bucket)
bs, query, ctx = bucket
spec_decode_bs = bs * num_tokens
if spec_decode_bs <= ctx:
# Add a bucket with (batch_size * num_tokens, query, ctx)
buckets.add(corrector(spec_decode_bs, query, ctx))
# Add a bucket with (batch_size * num_tokens, query, ctx * num_tokens)
buckets.add(corrector(spec_decode_bs, query, ctx * num_tokens))
# Log the new generated spec decode buckets
new_buckets = sorted(buckets - set(seed_decode_buckets))
msg = (f"Generated {len(new_buckets)} "
f"spec decode buckets [bs, query, num_blocks]: {list(new_buckets)}")
logger().info(msg)
return sorted(buckets)
@classmethod
def get_instance(cls):
"""
Retrieve the singleton instance of the class.
"""
return cls._instance