vllm.v1.sample.sampler

A layer that samples the next tokens from the model's outputs.

_SAMPLING_EPS module-attribute

_SAMPLING_EPS = 1e-05

Sampler

Bases: Module

A layer that samples the next tokens from the model's outputs with the following steps, in order (a short illustrative sketch of the random-sampling path follows this list):

  1. If logprobs are requested:
     a) If logprobs_mode is raw_logprobs, compute logprobs as the final logprobs to return.
     b) If logprobs_mode is raw_logits, clone the logits as the final logprobs to return.
  2. Convert logits to float32.
  3. Apply the allowed token ids whitelist.
  4. Apply bad words exclusion.
  5. Apply logit processors which are not argmax-invariant, i.e. that can impact greedy sampling:
     a) Min tokens processor
     b) Logit bias processor
  6. Apply penalties:
     a) Repetition penalty
     b) Frequency penalty
     c) Presence penalty
  7. Sample the next tokens. The sample method performs the following steps:
     a) If not all_random, perform greedy sampling. If all_greedy, return the greedily sampled tokens and final logprobs if requested.
     b) Apply temperature.
     c) Apply logit processors which are argmax-invariant, by default the min_p processor.
     d) Apply top_k and/or top_p.
     e) Sample the next tokens from the probability distribution.
     f) If all_random or temperature >= epsilon (1e-5), return the randomly sampled tokens and final logprobs if requested. Otherwise, return the greedily sampled tokens and logprobs if requested.
  8. Gather the logprobs of the top max_num_logprobs and the sampled token (if requested). Note that if the sampled token is within the top max_num_logprobs, its logprob will eventually be merged in LogprobsProcessor during output processing, so the final output may contain either max_num_logprobs + 1 or max_num_logprobs logprobs.
  9. Return the final SamplerOutput.
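
The random-sampling path of step 7 reduces to temperature scaling, top-k/top-p filtering, and a multinomial draw. The sketch below illustrates that flow in plain PyTorch, outside of vLLM; the simple mask-based top-k/top-p filtering is an illustrative stand-in for TopKTopPSampler, not its implementation.

import torch

torch.manual_seed(0)

num_reqs, vocab_size = 2, 8
logits = torch.randn(num_reqs, vocab_size, dtype=torch.float32)

temperature = torch.tensor([0.7, 1.2])  # one temperature per request
top_k = 4
top_p = 0.9

# 7b) Temperature scaling (cf. apply_temperature, which divides in place).
logits = logits / temperature.unsqueeze(1)

# 7d) Top-k: keep only the k largest logits per row.
kth_largest = torch.topk(logits, top_k, dim=-1).values[:, -1:]
logits = logits.masked_fill(logits < kth_largest, float("-inf"))

# 7d) Top-p: drop the low-probability tail outside the nucleus.
probs = logits.softmax(dim=-1)
sorted_probs, sorted_idx = probs.sort(dim=-1, descending=True)
exclusive_cumsum = sorted_probs.cumsum(dim=-1) - sorted_probs
sorted_mask = exclusive_cumsum > top_p  # mass before this token already exceeds top_p
mask = sorted_mask.scatter(-1, sorted_idx, sorted_mask)
logits = logits.masked_fill(mask, float("-inf"))

# 7e) Sample from the resulting distribution.
probs = logits.softmax(dim=-1)
sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)
print(sampled.shape)  # torch.Size([2]): one sampled token id per request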
Source code in vllm/v1/sample/sampler.py
class Sampler(nn.Module):
    """
    A layer that samples the next tokens from the model's outputs
    with the following steps in order:

    1. If logprobs are requested:
        a) If `logprobs_mode` is `raw_logprobs`, compute logprobs
           as the final logprobs to return.
        b) If `logprobs_mode` is `raw_logits`, clone the logits
           as the final logprobs to return.
    2. Convert logits to float32.
    3. Apply allowed token ids whitelist.
    4. Apply bad words exclusion.
    5. Apply logit processors which are not argmax-invariant,
       i.e. that can impact greedy sampling.
        a) Min tokens processor
        b) Logit bias processor
    6. Apply penalties
        a) Repetition penalty
        b) Frequency penalty
        c) Presence penalty
    7. Sample the next tokens. `sample` method performs the following steps:
        a) If not `all_random`, perform greedy sampling. If `all_greedy`,
           return the greedily sampled tokens and final logprobs if requested.
        b) Apply temperature.
        c) Apply logit processors which are argmax-invariant, by default
           the min_p processor.
        d) Apply top_k and/or top_p.
        e) Sample the next tokens with the probability distribution.
        f) If `all_random` or temperature >= epsilon (1e-5), return the
           randomly sampled tokens and final logprobs if requested. Else,
           return the greedily sampled tokens and logprobs if requested.
    8. Gather the logprobs of the top `max_num_logprobs` and sampled token
       (if requested). Note that if the sampled token is within the top
       `max_num_logprobs`, the logprob will be eventually merged in
       `LogprobsProcessor` during output processing. Therefore, the
       final output may contain either `max_num_logprobs + 1` or
       `max_num_logprobs` logprobs.
    9. Return the final `SamplerOutput`.
    """

    def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
        super().__init__()
        self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
        self.pin_memory = is_pin_memory_available()
        self.logprobs_mode = logprobs_mode

    def forward(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool = False,
    ) -> SamplerOutput:
        # NOTE(woosuk): Use the original logits (before any penalties or
        # temperature scaling) for the top-k logprobs.
        # This is different from the V0 sampler, which uses the logits that
        # are used for sampling (after penalties and temperature scaling).
        num_logprobs = sampling_metadata.max_num_logprobs
        if num_logprobs is not None:
            if self.logprobs_mode == "raw_logprobs":
                raw_logprobs = self.compute_logprobs(logits)
            elif self.logprobs_mode == "raw_logits":
                raw_logprobs = logits.clone()

        # Use float32 for the logits.
        logits = logits.to(torch.float32)

        logits = self.apply_logits_processors(
            logits, sampling_metadata, predict_bonus_token
        )
        # Sample the next token.
        sampled, processed_logprobs = self.sample(logits, sampling_metadata)
        if processed_logprobs is not None:
            raw_logprobs = processed_logprobs
        # Convert sampled token ids to int64 (long) type to ensure compatibility
        # with subsequent operations that may use these values as indices.
        # This conversion is necessary because FlashInfer sampling operations
        # return int32 (while PyTorch argmax and topk return int64).
        sampled = sampled.long()

        # Gather the logprobs of the topk and sampled token (if requested).
        # Get logprobs and rank tensors (if requested)
        logprobs_tensors = (
            None
            if num_logprobs is None
            else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
        )

        # Use int32 to reduce the tensor size.
        sampled = sampled.to(torch.int32)

        # These are GPU tensors.
        sampler_output = SamplerOutput(
            # The sampled tokens are expanded to 2D tensor with shape
            # [num_requests, 1], where each row represents one generated
            # token per request.
            sampled_token_ids=sampled.unsqueeze(-1),
            logprobs_tensors=logprobs_tensors,
        )
        return sampler_output

    def apply_temperature(
        self,
        logits: torch.Tensor,
        temp: torch.Tensor,
        all_random: bool,
    ) -> torch.Tensor:
        # Use in-place division to avoid creating a new tensor.
        # Avoid division by zero if there are greedy requests.
        if not all_random:
            temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
        return logits.div_(temp.unsqueeze(dim=1))

    def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.argmax(dim=-1).view(-1)

    def sample(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Sample logits based on sampling metadata.

        The various logits processing functions called in this method
        may update the logits tensor in-place.
        """

        assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
        if sampling_metadata.all_random:
            greedy_sampled = None
        else:
            greedy_sampled = self.greedy_sample(logits)
            if sampling_metadata.all_greedy:
                processed_logprobs = None
                if sampling_metadata.max_num_logprobs is not None:
                    if self.logprobs_mode == "processed_logits":
                        processed_logprobs = logits
                    elif self.logprobs_mode == "processed_logprobs":
                        processed_logprobs = self.compute_logprobs(logits)
                return greedy_sampled, processed_logprobs

        assert sampling_metadata.temperature is not None

        # Apply temperature.
        logits = self.apply_temperature(
            logits, sampling_metadata.temperature, sampling_metadata.all_random
        )

        # Apply logits processors that only apply to random sampling
        # (argmax invariant)
        for processor in sampling_metadata.logitsprocs.argmax_invariant:
            logits = processor.apply(logits)

        # Apply top_k and/or top_p.
        random_sampled, processed_logprobs = self.topk_topp_sampler(
            logits,
            sampling_metadata.generators,
            sampling_metadata.top_k,
            sampling_metadata.top_p,
        )

        if greedy_sampled is None:
            return random_sampled, processed_logprobs

        sampled = torch.where(
            sampling_metadata.temperature < _SAMPLING_EPS,
            greedy_sampled,
            random_sampled,
            out=greedy_sampled,  # Reuse tensor
        )
        return sampled, processed_logprobs

    def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
        return logits.log_softmax(dim=-1, dtype=torch.float32)

    def gather_logprobs(
        self,
        logprobs: torch.Tensor,
        num_logprobs: int,
        token_ids: torch.Tensor,
    ) -> LogprobsTensors:
        """
        Gather logprobs for topk and sampled/prompt token.

        Args:
          logprobs: (num tokens) x (vocab) tensor
          num_logprobs: minimum number of logprobs to
                        retain per token
          token_ids: prompt tokens (if prompt logprobs)
                     or sampled tokens (if sampled
                     logprobs); 1D token ID tensor
                     with (num tokens) elements
                     Must be int64.

        Returns:
          Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
          Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
          Sampled token rank tensor, (num tokens)
        """
        assert token_ids.dtype == torch.int64
        # Find the topK values.
        topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)

        # Gather the logprob of the prompt or sampled token.
        token_ids = token_ids.unsqueeze(-1)
        token_logprobs = logprobs.gather(-1, token_ids)

        # Compute the ranks of the actual token.
        token_ranks = batched_count_greater_than(logprobs, token_logprobs)

        # Concatenate together with the topk.
        indices = torch.cat((token_ids, topk_indices), dim=1)
        logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

        # Use int32 to reduce the tensor size.
        indices = indices.to(torch.int32)

        return LogprobsTensors(indices, logprobs, token_ranks)

    def _combine_outputs_with_spec_tokens(
        self,
        output_token_ids: list[list[int]],
        spec_token_ids: Optional[list[list[int]]] = None,
    ) -> list[list[int]]:
        if spec_token_ids is None:
            return output_token_ids

        return [
            [*out, *spec] if spec else out
            for out, spec in zip(output_token_ids, spec_token_ids)
        ]

    def apply_logits_processors(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        predict_bonus_token: bool,
    ) -> torch.Tensor:
        any_penalties_or_bad_words = (
            sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
        )

        output_token_ids = sampling_metadata.output_token_ids
        if predict_bonus_token and any_penalties_or_bad_words:
            # Combine base outputs with spec tokens when speculative decoding
            # is enabled.
            output_token_ids = self._combine_outputs_with_spec_tokens(
                sampling_metadata.output_token_ids,
                sampling_metadata.spec_token_ids,
            )

        # Apply allowed token ids.
        if sampling_metadata.allowed_token_ids_mask is not None:
            logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))

        # Apply bad words exclusion.
        if sampling_metadata.bad_words_token_ids:
            apply_bad_words(
                logits,
                sampling_metadata.bad_words_token_ids,
                output_token_ids
                if output_token_ids is not None
                else sampling_metadata.output_token_ids,
            )

        # Apply logits processors which can impact greedy sampling.
        for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
            logits = processor.apply(logits)

        # Apply penalties (e.g., freq_penalties).
        logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
        return logits

    def apply_penalties(
        self,
        logits: torch.Tensor,
        sampling_metadata: SamplingMetadata,
        output_token_ids: Optional[list[list[int]]] = None,
    ) -> torch.Tensor:
        if not sampling_metadata.no_penalties:
            assert sampling_metadata.prompt_token_ids is not None
            logits = apply_all_penalties(
                logits,
                sampling_metadata.prompt_token_ids,
                sampling_metadata.presence_penalties,
                sampling_metadata.frequency_penalties,
                sampling_metadata.repetition_penalties,
                output_token_ids
                if output_token_ids is not None
                else sampling_metadata.output_token_ids,
            )
        return logits

logprobs_mode instance-attribute

logprobs_mode = logprobs_mode

pin_memory instance-attribute

pin_memory = is_pin_memory_available()

topk_topp_sampler instance-attribute

topk_topp_sampler = TopKTopPSampler(logprobs_mode)

__init__

__init__(logprobs_mode: LogprobsMode = 'raw_logprobs')
Source code in vllm/v1/sample/sampler.py
def __init__(self, logprobs_mode: LogprobsMode = "raw_logprobs"):
    super().__init__()
    self.topk_topp_sampler = TopKTopPSampler(logprobs_mode)
    self.pin_memory = is_pin_memory_available()
    self.logprobs_mode = logprobs_mode
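
Assuming a working vLLM installation, construction is a one-liner; the logprobs_mode values shown are the raw modes referenced in this module's docstring.

from vllm.v1.sample.sampler import Sampler

# Default mode: top-k logprobs are computed from the raw (pre-penalty) logits.
sampler = Sampler()

# Alternative: report the raw logits themselves in place of log-probabilities.
raw_logits_sampler = Sampler(logprobs_mode="raw_logits")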

_combine_outputs_with_spec_tokens

_combine_outputs_with_spec_tokens(
    output_token_ids: list[list[int]],
    spec_token_ids: Optional[list[list[int]]] = None,
) -> list[list[int]]
Source code in vllm/v1/sample/sampler.py
def _combine_outputs_with_spec_tokens(
    self,
    output_token_ids: list[list[int]],
    spec_token_ids: Optional[list[list[int]]] = None,
) -> list[list[int]]:
    if spec_token_ids is None:
        return output_token_ids

    return [
        [*out, *spec] if spec else out
        for out, spec in zip(output_token_ids, spec_token_ids)
    ]
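
A small illustration of the combination logic with plain Python lists (the token id values are made up):

output_token_ids = [[11, 12], [21], [31, 32, 33]]
spec_token_ids = [[13, 14], [], [34]]

combined = [
    [*out, *spec] if spec else out
    for out, spec in zip(output_token_ids, spec_token_ids)
]
print(combined)  # [[11, 12, 13, 14], [21], [31, 32, 33, 34]]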

apply_logits_processors

apply_logits_processors(
    logits: Tensor,
    sampling_metadata: SamplingMetadata,
    predict_bonus_token: bool,
) -> Tensor
Source code in vllm/v1/sample/sampler.py
def apply_logits_processors(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    predict_bonus_token: bool,
) -> torch.Tensor:
    any_penalties_or_bad_words = (
        sampling_metadata.bad_words_token_ids or not sampling_metadata.no_penalties
    )

    output_token_ids = sampling_metadata.output_token_ids
    if predict_bonus_token and any_penalties_or_bad_words:
        # Combine base outputs with spec tokens when speculative decoding
        # is enabled.
        output_token_ids = self._combine_outputs_with_spec_tokens(
            sampling_metadata.output_token_ids,
            sampling_metadata.spec_token_ids,
        )

    # Apply allowed token ids.
    if sampling_metadata.allowed_token_ids_mask is not None:
        logits.masked_fill_(sampling_metadata.allowed_token_ids_mask, float("-inf"))

    # Apply bad words exclusion.
    if sampling_metadata.bad_words_token_ids:
        apply_bad_words(
            logits,
            sampling_metadata.bad_words_token_ids,
            output_token_ids
            if output_token_ids is not None
            else sampling_metadata.output_token_ids,
        )

    # Apply logits processors which can impact greedy sampling.
    for processor in sampling_metadata.logitsprocs.non_argmax_invariant:
        logits = processor.apply(logits)

    # Apply penalties (e.g., freq_penalties).
    logits = self.apply_penalties(logits, sampling_metadata, output_token_ids)
    return logits
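
As consumed above, allowed_token_ids_mask is a boolean tensor over the vocabulary in which True marks token ids that must be suppressed; those positions are filled with -inf in place. A minimal sketch with a hypothetical mask:

import torch

vocab_size = 6
logits = torch.randn(2, vocab_size)

# Hypothetical mask: True marks token ids that are NOT allowed for a request.
allowed_token_ids_mask = torch.ones(2, vocab_size, dtype=torch.bool)
allowed_token_ids_mask[0, [1, 3]] = False  # request 0 may only emit tokens 1 and 3
allowed_token_ids_mask[1, :] = False       # request 1 is unrestricted

logits.masked_fill_(allowed_token_ids_mask, float("-inf"))
print(logits[0])  # only positions 1 and 3 remain finite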

apply_penalties

apply_penalties(
    logits: Tensor,
    sampling_metadata: SamplingMetadata,
    output_token_ids: Optional[list[list[int]]] = None,
) -> Tensor
Source code in vllm/v1/sample/sampler.py
def apply_penalties(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    output_token_ids: Optional[list[list[int]]] = None,
) -> torch.Tensor:
    if not sampling_metadata.no_penalties:
        assert sampling_metadata.prompt_token_ids is not None
        logits = apply_all_penalties(
            logits,
            sampling_metadata.prompt_token_ids,
            sampling_metadata.presence_penalties,
            sampling_metadata.frequency_penalties,
            sampling_metadata.repetition_penalties,
            output_token_ids
            if output_token_ids is not None
            else sampling_metadata.output_token_ids,
        )
    return logits
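
apply_all_penalties (imported from vLLM's penalty utilities) does the actual math across the batch. The sketch below reimplements the usual repetition/frequency/presence definitions for a single request to show the intended effect; it is a simplified illustration (for one thing, it applies all three penalties to one combined token history, while the real implementation distinguishes prompt tokens from output tokens), not the vLLM implementation.

import torch

vocab_size = 6
logits = torch.randn(vocab_size)
seen = torch.tensor([2, 2, 5])  # token ids already seen for this request

presence_penalty = 0.5
frequency_penalty = 0.2
repetition_penalty = 1.2

counts = torch.bincount(seen, minlength=vocab_size).float()
appeared = counts > 0

# Repetition penalty: shrink positive logits / amplify negative logits of seen tokens.
logits = torch.where(
    appeared,
    torch.where(logits > 0, logits / repetition_penalty, logits * repetition_penalty),
    logits,
)
# Frequency penalty: subtract in proportion to how often each token appeared.
logits -= frequency_penalty * counts
# Presence penalty: flat subtraction for any token that appeared at least once.
logits -= presence_penalty * appeared.float()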

apply_temperature

apply_temperature(
    logits: Tensor, temp: Tensor, all_random: bool
) -> Tensor
Source code in vllm/v1/sample/sampler.py
def apply_temperature(
    self,
    logits: torch.Tensor,
    temp: torch.Tensor,
    all_random: bool,
) -> torch.Tensor:
    # Use in-place division to avoid creating a new tensor.
    # Avoid division by zero if there are greedy requests.
    if not all_random:
        temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
    return logits.div_(temp.unsqueeze(dim=1))
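
The epsilon guard is what keeps greedy rows (temperature ~ 0) from dividing by zero: their temperature is swapped for 1.0, which leaves those rows, and hence their argmax, unchanged. A small sketch using the same 1e-5 threshold:

import torch

_SAMPLING_EPS = 1e-5

logits = torch.randn(3, 8)
temp = torch.tensor([0.0, 0.8, 1.5])  # request 0 is greedy

safe_temp = torch.where(temp < _SAMPLING_EPS, 1.0, temp)
scaled = logits / safe_temp.unsqueeze(1)  # out-of-place here; the method divides in place

assert torch.equal(scaled[0], logits[0])  # the greedy row is untouched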

compute_logprobs

compute_logprobs(logits: Tensor) -> Tensor
Source code in vllm/v1/sample/sampler.py
def compute_logprobs(self, logits: torch.Tensor) -> torch.Tensor:
    return logits.log_softmax(dim=-1, dtype=torch.float32)
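
The log-softmax is taken in float32 regardless of the input dtype, and exponentiating the result recovers a normalized distribution:

import torch

logits = torch.randn(2, 8).to(torch.bfloat16)
logprobs = logits.log_softmax(dim=-1, dtype=torch.float32)

print(logprobs.dtype)              # torch.float32
print(logprobs.exp().sum(dim=-1))  # ~1.0 per row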

forward

forward(
    logits: Tensor,
    sampling_metadata: SamplingMetadata,
    predict_bonus_token: bool = False,
) -> SamplerOutput
Source code in vllm/v1/sample/sampler.py
def forward(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
    predict_bonus_token: bool = False,
) -> SamplerOutput:
    # NOTE(woosuk): Use the original logits (before any penalties or
    # temperature scaling) for the top-k logprobs.
    # This is different from the V0 sampler, which uses the logits that
    # are used for sampling (after penalties and temperature scaling).
    num_logprobs = sampling_metadata.max_num_logprobs
    if num_logprobs is not None:
        if self.logprobs_mode == "raw_logprobs":
            raw_logprobs = self.compute_logprobs(logits)
        elif self.logprobs_mode == "raw_logits":
            raw_logprobs = logits.clone()

    # Use float32 for the logits.
    logits = logits.to(torch.float32)

    logits = self.apply_logits_processors(
        logits, sampling_metadata, predict_bonus_token
    )
    # Sample the next token.
    sampled, processed_logprobs = self.sample(logits, sampling_metadata)
    if processed_logprobs is not None:
        raw_logprobs = processed_logprobs
    # Convert sampled token ids to int64 (long) type to ensure compatibility
    # with subsequent operations that may use these values as indices.
    # This conversion is necessary because FlashInfer sampling operations
    # return int32 (while PyTorch argmax and topk return int64).
    sampled = sampled.long()

    # Gather the logprobs of the topk and sampled token (if requested).
    # Get logprobs and rank tensors (if requested)
    logprobs_tensors = (
        None
        if num_logprobs is None
        else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=sampled)
    )

    # Use int32 to reduce the tensor size.
    sampled = sampled.to(torch.int32)

    # These are GPU tensors.
    sampler_output = SamplerOutput(
        # The sampled tokens are expanded to 2D tensor with shape
        # [num_requests, 1], where each row represents one generated
        # token per request.
        sampled_token_ids=sampled.unsqueeze(-1),
        logprobs_tensors=logprobs_tensors,
    )
    return sampler_output
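
The NOTE above captures the key subtlety: in the raw_* modes, the reported top-k logprobs come from the logits as they left the model, before penalties and temperature, so they can differ from the distribution actually sampled from. A toy contrast (not the vLLM code path; the penalty applied below is made up):

import torch

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

raw_logprobs = logits.log_softmax(dim=-1)            # what raw_logprobs mode reports

penalized = logits.clone()
penalized[0, 0] -= 1.5                               # e.g. a repetition-penalized token
processed_logprobs = penalized.log_softmax(dim=-1)   # what processed_logprobs would reflect

print(raw_logprobs[0, 0].item(), processed_logprobs[0, 0].item())  # the values differ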

gather_logprobs

gather_logprobs(
    logprobs: Tensor, num_logprobs: int, token_ids: Tensor
) -> LogprobsTensors

Gather logprobs for topk and sampled/prompt token.

Parameters:

  logprobs (Tensor, required): (num tokens) x (vocab) tensor
  num_logprobs (int, required): minimum number of logprobs to retain per token
  token_ids (Tensor, required): prompt tokens (if prompt logprobs) or sampled tokens (if sampled logprobs); 1D token ID tensor with (num tokens) elements. Must be int64.

Returns:

  LogprobsTensors:
    Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
    Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
    Sampled token rank tensor, (num tokens)

Source code in vllm/v1/sample/sampler.py
def gather_logprobs(
    self,
    logprobs: torch.Tensor,
    num_logprobs: int,
    token_ids: torch.Tensor,
) -> LogprobsTensors:
    """
    Gather logprobs for topk and sampled/prompt token.

    Args:
      logprobs: (num tokens) x (vocab) tensor
      num_logprobs: minimum number of logprobs to
                    retain per token
      token_ids: prompt tokens (if prompt logprobs)
                 or sampled tokens (if sampled
                 logprobs); 1D token ID tensor
                 with (num tokens) elements
                 Must be int64.

    Returns:
      Top-k int indices tensor, (num tokens) x (num_logprobs + 1)
      Top-k float logprobs tensor, (num tokens) x (num_logprobs + 1)
      Sampled token rank tensor, (num tokens)
    """
    assert token_ids.dtype == torch.int64
    # Find the topK values.
    topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)

    # Gather the logprob of the prompt or sampled token.
    token_ids = token_ids.unsqueeze(-1)
    token_logprobs = logprobs.gather(-1, token_ids)

    # Compute the ranks of the actual token.
    token_ranks = batched_count_greater_than(logprobs, token_logprobs)

    # Concatenate together with the topk.
    indices = torch.cat((token_ids, topk_indices), dim=1)
    logprobs = torch.cat((token_logprobs, topk_logprobs), dim=1)

    # Use int32 to reduce the tensor size.
    indices = indices.to(torch.int32)

    return LogprobsTensors(indices, logprobs, token_ranks)
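
A shape-level sketch of the gather, with the vLLM helper batched_count_greater_than replaced by a simple comparison-based stand-in (assumed here to count strictly larger entries):

import torch

num_tokens, vocab_size, num_logprobs = 2, 6, 2
logprobs = torch.randn(num_tokens, vocab_size).log_softmax(dim=-1)
token_ids = torch.tensor([4, 1])  # sampled token per position; must be int64

topk_logprobs, topk_indices = torch.topk(logprobs, num_logprobs, dim=-1)
token_logprobs = logprobs.gather(-1, token_ids.unsqueeze(-1))
# Stand-in for batched_count_greater_than.
token_ranks = (logprobs > token_logprobs).sum(dim=-1)

indices = torch.cat((token_ids.unsqueeze(-1), topk_indices), dim=1)
values = torch.cat((token_logprobs, topk_logprobs), dim=1)
print(indices.shape, values.shape, token_ranks.shape)
# torch.Size([2, 3]) torch.Size([2, 3]) torch.Size([2])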

greedy_sample

greedy_sample(logits: Tensor) -> Tensor
Source code in vllm/v1/sample/sampler.py
def greedy_sample(self, logits: torch.Tensor) -> torch.Tensor:
    return logits.argmax(dim=-1).view(-1)

sample

sample(
    logits: Tensor, sampling_metadata: SamplingMetadata
) -> tuple[Tensor, Optional[Tensor]]

Sample logits based on sampling metadata.

The various logits processing functions called in this method may update the logits tensor in-place.

Source code in vllm/v1/sample/sampler.py
def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Sample logits based on sampling metadata.

    The various logits processing functions called in this method
    may update the logits tensor in-place.
    """

    assert not (sampling_metadata.all_greedy and sampling_metadata.all_random)
    if sampling_metadata.all_random:
        greedy_sampled = None
    else:
        greedy_sampled = self.greedy_sample(logits)
        if sampling_metadata.all_greedy:
            processed_logprobs = None
            if sampling_metadata.max_num_logprobs is not None:
                if self.logprobs_mode == "processed_logits":
                    processed_logprobs = logits
                elif self.logprobs_mode == "processed_logprobs":
                    processed_logprobs = self.compute_logprobs(logits)
            return greedy_sampled, processed_logprobs

    assert sampling_metadata.temperature is not None

    # Apply temperature.
    logits = self.apply_temperature(
        logits, sampling_metadata.temperature, sampling_metadata.all_random
    )

    # Apply logits processors that only apply to random sampling
    # (argmax invariant)
    for processor in sampling_metadata.logitsprocs.argmax_invariant:
        logits = processor.apply(logits)

    # Apply top_k and/or top_p.
    random_sampled, processed_logprobs = self.topk_topp_sampler(
        logits,
        sampling_metadata.generators,
        sampling_metadata.top_k,
        sampling_metadata.top_p,
    )

    if greedy_sampled is None:
        return random_sampled, processed_logprobs

    sampled = torch.where(
        sampling_metadata.temperature < _SAMPLING_EPS,
        greedy_sampled,
        random_sampled,
        out=greedy_sampled,  # Reuse tensor
    )
    return sampled, processed_logprobs
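
When a batch mixes greedy and random requests, both paths are evaluated and torch.where picks per request, keyed on whether that request's temperature is below the 1e-5 epsilon. A standalone sketch of the merge (the random draw here is a plain multinomial, not TopKTopPSampler):

import torch

_SAMPLING_EPS = 1e-5

logits = torch.randn(3, 8)
temperature = torch.tensor([0.0, 0.9, 0.0])  # requests 0 and 2 are greedy

greedy_sampled = logits.argmax(dim=-1)

safe_temp = torch.where(temperature < _SAMPLING_EPS, 1.0, temperature)
probs = (logits / safe_temp.unsqueeze(1)).softmax(dim=-1)
random_sampled = torch.multinomial(probs, num_samples=1).squeeze(-1)

sampled = torch.where(temperature < _SAMPLING_EPS, greedy_sampled, random_sampled)
assert bool(sampled[0] == greedy_sampled[0]) and bool(sampled[2] == greedy_sampled[2])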