Source code for slangmod.ml.generators.top_p

from typing import Any
import torch.distributions as ptd
import torch.nn.functional as ptnf
from swak.pt.types import Module, Tensor
from ..tokenizers import Algo
from .abc import NextToken


class TopP(NextToken):
    """Randomly draw the next token from the top fraction of probability.

    If the model is very sure about what the next token should be, then
    most of the probability mass will be concentrated on that token and the
    random draw will be from among very few tokens. If, in contrast, the
    model is not so sure, and the probability mass is more widely
    distributed, then the random draw will be from among many more
    candidate tokens.

    Parameters
    ----------
    tokenizer: Algo
        Fully configured ``Algo`` wrapper around a trained tokenizer.
    model: Module
        The trained PyTorch model to use for text generation.
    max_tokens: int, optional
        The maximum number of tokens to generate in case the end-of-sequence
        token is not predicted by the model first. Defaults to 256.
    p: float, optional
        Candidate tokens to draw from are chosen by ranking all in order of
        descending probability and taking as many as possible before the
        sum of their individual probabilities exceeds `p`. The single most
        probable token is always a candidate. Defaults to 1.0, which results
        in a draw from a categorical distribution over *all* eligible tokens.
    temperature: float, optional
        The logits are divided by this value before the softmax. Higher
        temperatures therefore spread the probability mass out among more
        tokens, while lower temperatures concentrate it onto the most likely
        tokens. Must not be zero. Defaults to 1.0, which results in
        unmodified logits.
    **_: Any
        Additional keyword arguments are accepted and ignored.

    """

    def __init__(
            self,
            tokenizer: Algo,
            model: Module,
            max_tokens: int = 256,
            p: float = 1.0,
            temperature: float = 1.0,
            **_: Any
    ) -> None:
        super().__init__(tokenizer, model, max_tokens)
        self.p = p
        self.temperature = temperature

    def __repr__(self) -> str:
        # Strip the closing parenthesis of the parent's representation and
        # append the extra parameters of this class before closing it again.
        extras = f', p={self.p}, temperature={self.temperature})'
        return super().__repr__()[:-1] + extras

    def next_token_from_logits(self, logits: Tensor) -> Tensor:
        """Randomly draw the next token from among the most probable ones.

        Parameters
        ----------
        logits: Tensor
            1-D PyTorch tensor with un-normalized probabilities over all
            permissible tokens in the vocabulary.

        Returns
        -------
        Tensor
            Int64 scalar with the ID of the next token randomly chosen from
            the top candidates that together have a probability of `p`.

        """
        scaled = logits / self.temperature
        probas = ptnf.softmax(scaled, dim=-1).sort(dim=-1, descending=True)
        if self.p >= 1.0:
            # Floating-point rounding can push the cumulative sum slightly
            # above 1.0, which would wrongly exclude low-probability tokens.
            # Skip the mask so that all tokens remain eligible.
            candidates = probas.values
        else:
            # Boolean mask starting with True until cumsum reaches p
            top_p = probas.values.cumsum(dim=-1) <= self.p
            # We need at least one element to draw
            top_p[0] = True
            candidates = probas.values[top_p]
        # Sample only from the top-p candidates (re-normalized by Categorical)
        sample = ptd.Categorical(candidates).sample()
        # Return their actual index among the logits
        return probas.indices[sample]