Skip to content

vllm.model_executor.models.transformers.causal

Transformers modeling backend mixin for causal language models.

Classes:

CausalMixin

Bases: VllmModelForTextGeneration

Methods:

  • load_weights

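    A thin wrapper around Base.load_weights. On the last pipeline-parallel rank, it loads any lm_head bias found in the checkpoint directly into lm_head; all other weights are forwarded to Base.load_weights.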

Source code in vllm/model_executor/models/transformers/causal.py
class CausalMixin(VllmModelForTextGeneration):
    """Mixin that equips a Transformers-backend model with a causal LM head.

    On the last pipeline-parallel rank it builds a ``ParallelLMHead`` (tied to
    the input embeddings when the config asks for it) and a ``LogitsProcessor``.
    On every other rank ``lm_head`` is a ``PPMissingLayer`` placeholder and no
    logits processor is created.
    """

    def __init__(self, *, vllm_config: "VllmConfig", prefix: str = ""):
        # Start the MRO lookup *after* VllmModelForTextGeneration so that its
        # __init__ is bypassed and the next class in line does the setup.
        super(VllmModelForTextGeneration, self).__init__(
            vllm_config=vllm_config, prefix=prefix
        )

        tied = self._get_tie_word_embeddings()
        if tied:
            # A tied head reuses the embedding weights, so `Base.load_weights`
            # is told to skip anything under the `lm_head.` prefix.
            self.skip_prefixes.append("lm_head.")

        if not self.pp_group.is_last_rank:
            # Only the last pipeline stage owns the head; elsewhere it is a stub.
            self.lm_head = PPMissingLayer()
            return

        self.lm_head = ParallelLMHead(
            self.text_config.vocab_size,
            self.text_config.hidden_size,
            quant_config=self.quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if tied:
            embeddings = self.model.get_input_embeddings()
            self.lm_head = self.lm_head.tie_weights(embeddings)

        # Configs without a `logit_scale` attribute fall back to a scale of 1.0.
        scale = getattr(self.text_config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(
            self.text_config.vocab_size, scale=scale
        )

    def load_weights(self, weights: Iterable[tuple[str, "torch.Tensor"]]) -> set[str]:
        """Load checkpoint weights, consuming any ``lm_head`` bias locally.

        Every ``(name, tensor)`` pair whose name ends with ``"lm_head.bias"`` is
        handled here, but only on the last pipeline-parallel rank. The handling
        calls ``lm_head._register_bias()`` and then loads the tensor through the
        bias parameter's ``weight_loader``. All other pairs are streamed lazily
        to ``Base.load_weights``.

        Args:
            weights: Iterable of ``(parameter_name, tensor)`` pairs.

        Returns:
            The names reported by ``Base.load_weights`` plus the names of any
            bias tensors consumed here.
        """
        consumed = set()

        def strip_lm_head_bias(stream):
            for name, tensor in stream:
                if name.endswith("lm_head.bias") and self.pp_group.is_last_rank:
                    self.lm_head._register_bias()
                    bias = self.lm_head.bias
                    bias.weight_loader(bias, tensor)
                    consumed.add(name)
                    continue
                yield name, tensor

        # The generator is drained inside Base.load_weights, so `consumed` is
        # fully populated by the time the union below is evaluated.
        loaded = super().load_weights(strip_lm_head_bias(weights))
        return loaded | consumed

    def compute_logits(self, hidden_states: "torch.Tensor") -> "torch.Tensor | None":
        """Turn ``hidden_states`` into logits using ``lm_head`` and its bias.

        ``self.logits_processor`` is created only on the last pipeline-parallel
        rank, so this method is meant to be called there.
        """
        return self.logits_processor(self.lm_head, hidden_states, self.lm_head.bias)

load_weights(weights)

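A thin wrapper around Base.load_weights. On the last pipeline-parallel rank, it loads any lm_head bias found in the checkpoint directly into lm_head; all other weights are forwarded to Base.load_weights.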

Source code in vllm/model_executor/models/transformers/causal.py
def load_weights(self, weights: Iterable[tuple[str, "torch.Tensor"]]) -> set[str]:
    """Load checkpoint weights, consuming any ``lm_head`` bias locally.

    Every ``(name, tensor)`` pair whose name ends with ``"lm_head.bias"`` is
    handled here, but only on the last pipeline-parallel rank. The handling
    calls ``lm_head._register_bias()`` and then loads the tensor through the
    bias parameter's ``weight_loader``. All other pairs are streamed lazily
    to ``Base.load_weights``.

    Args:
        weights: Iterable of ``(parameter_name, tensor)`` pairs.

    Returns:
        The names reported by ``Base.load_weights`` plus the names of any
        bias tensors consumed here.
    """
    consumed = set()

    def strip_lm_head_bias(stream):
        for name, tensor in stream:
            if name.endswith("lm_head.bias") and self.pp_group.is_last_rank:
                self.lm_head._register_bias()
                bias = self.lm_head.bias
                bias.weight_loader(bias, tensor)
                consumed.add(name)
                continue
            yield name, tensor

    # The generator is drained inside Base.load_weights, so `consumed` is
    # fully populated by the time the union below is evaluated.
    loaded = super().load_weights(strip_lm_head_bias(weights))
    return loaded | consumed