DEV Community

M Sea Bass

Entropix: Sampling Techniques for Maximizing Inference Performance

According to the Entropix README, Entropix uses an entropy-based sampling method. This article explains the specific sampling techniques based on entropy and varentropy.

Entropy and Varentropy

Let's start by explaining entropy and varentropy, as these are key factors in determining the sampling strategy.

Entropy

In information theory, entropy is a measure of the uncertainty of a random variable. The entropy of a random variable X is defined by the following equation:

$$H(X) = -\sum_{i} p(x_i) \log p(x_i)$$

  • X: A discrete random variable.
  • x_i: The i-th possible state of X.
  • p(x_i): The probability of state x_i.

Entropy is maximized when the probability distribution is uniform. Conversely, when a specific state is much more likely than others, entropy decreases.
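As a quick check of these two claims, here is a minimal pure-Python sketch of the formula above. Measuring entropy in bits (base-2 logarithms) is an assumption of this sketch. The base only rescales the values, but that matters whenever they are compared with fixed thresholds like the ones Entropix uses.

```python
import math

def entropy(probs, base=2.0):
    """H(X) = -sum_i p(x_i) log p(x_i); states with p(x_i) = 0 contribute nothing."""
    return -sum(p * math.log(p, base) for p in probs if p > 0)

uniform = [0.25, 0.25, 0.25, 0.25]
peaked = [0.97, 0.01, 0.01, 0.01]
print(entropy(uniform))  # about 2.0 bits, the maximum for 4 states
print(entropy(peaked))   # about 0.24 bits
```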

Varentropy

Varentropy, closely related to entropy, measures how much the information content varies. The information content of state x_i is I(x_i) = -log p(x_i), and entropy H(X) is its expected value, so varentropy VE(X) is defined as the variance of I(X):

$$VE(X) = \mathrm{Var}[I(X)] = \sum_{i} p(x_i)\left(I(x_i) - H(X)\right)^2, \qquad I(x_i) = -\log p(x_i)$$

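Varentropy is large when the information content -log p(x_i) varies strongly across the probability mass, for example when one or a few tokens are quite likely while a substantial share of the probability is spread over much less likely tokens. It is zero in two extreme cases: when the distribution is uniform (maximum entropy, so every state carries the same information) and when one state has probability 1 and all others have probability 0 (zero entropy).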
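The same reasoning can be checked numerically. This sketch (again assuming base-2 logarithms, with a hypothetical helper name) computes entropy as the mean and varentropy as the variance of the information content, and reproduces both zero-varentropy extremes along with a mixed case.

```python
import math

def entropy_and_varentropy(probs, base=2.0):
    """Return (H, VE): the mean and the variance of the information content."""
    support = [p for p in probs if p > 0]                 # p = 0 states carry no weight
    info = [math.log(1.0 / p, base) for p in support]     # I(x_i) = -log p(x_i)
    h = sum(p * i for p, i in zip(support, info))         # H(X) = E[I(X)]
    ve = sum(p * (i - h) ** 2 for p, i in zip(support, info))  # VE(X) = Var[I(X)]
    return h, ve

print(entropy_and_varentropy([0.25, 0.25, 0.25, 0.25]))        # uniform: H about 2, VE about 0
print(entropy_and_varentropy([1.0, 0.0, 0.0, 0.0]))            # deterministic: H = VE = 0
print(entropy_and_varentropy([0.5, 0.125, 0.125, 0.125, 0.125]))  # mixed: H about 2, VE about 1
```

In the mixed case the likely token carries 1 bit of information and each of the others 3 bits, so the information content varies around its mean of 2 bits and the varentropy is positive.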

Sampling Methods

Next, let's explore how sampling strategies change based on entropy and varentropy values.

Quadrants

1. Low Entropy, Low Varentropy → Argmax

In this scenario, a particular token has a much higher prediction probability than the others. Since the next token is almost certain, Argmax is used.

# Low entropy, low varentropy: the next token is nearly certain, so decode greedily.
if ent < 0.1 and vent < 0.1:
    return torch.argmax(logits[:, -1], dim=-1, keepdim=True).to(torch.int32)

Code link
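To get a feel for how peaked a distribution must be to land in this quadrant, here is a pure-Python sketch that computes both metrics from raw logits and applies the same 0.1 thresholds. Measuring in bits is an assumption (the log base rescales the metrics and therefore shifts where the thresholds bite), and both helper names are hypothetical.

```python
import math

def logit_metrics(logits):
    """Entropy and varentropy of softmax(logits), in bits (an assumption of this sketch)."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(x - m) for x in logits))  # stable log-sum-exp
    log_probs = [x - log_z for x in logits]                      # natural-log softmax
    probs = [math.exp(lp) for lp in log_probs]
    info = [-lp / math.log(2) for lp in log_probs]               # -log2 p(x_i)
    h = sum(p * i for p, i in zip(probs, info))
    ve = sum(p * (i - h) ** 2 for p, i in zip(probs, info))
    return h, ve

def greedy_if_confident(logits, ent_thresh=0.1, vent_thresh=0.1):
    """Return the argmax index when both metrics fall below the thresholds, else None."""
    h, ve = logit_metrics(logits)
    if h < ent_thresh and ve < vent_thresh:
        return max(range(len(logits)), key=lambda i: logits[i])
    return None

print(greedy_if_confident([10.0, 0.0, 0.0, 0.0]))  # 0
print(greedy_if_confident([8.0, 0.0, 0.0, 0.0]))   # None
```

Under this assumption, a top token with roughly 99.99% probability (logits [10, 0, 0, 0]) qualifies, while one with roughly 99.9% probability ([8, 0, 0, 0]) does not: its varentropy is about 0.13, because the rare tokens carry very high information content.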

2. Low Entropy, High Varentropy → Branch

This occurs when there is some confidence, but multiple viable options exist. In this case, the Branch strategy is used to sample from multiple choices and select the best outcome.

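# Low entropy, high varentropy ("Branch"): sample more broadly.
elif ent < 5.0 and vent > 5.0:
    # Raise the temperature with the attention interaction strength.
    temp_adj = 1.2 + 0.3 * interaction_strength
    # Widen top-k when attention agreement is low, keeping at least 5 candidates.
    top_k_adj = max(5, int(top_k * (1 + 0.5 * (1 - agreement))))
    # min() caps the adjusted temperature at 1.5.
    return _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k_adj, min_p=min_p, generator=generator)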

Code link
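To see what these adjustments do in numbers, the sketch below reproduces the arithmetic of this branch in plain Python. interaction_strength and agreement are attention-based metrics that Entropix computes elsewhere; the input values here are made up for illustration, and the function name is hypothetical.

```python
def branch_params(temperature, top_k, interaction_strength, agreement):
    """Reproduce the Branch-quadrant temperature and top-k adjustments."""
    temp_adj = 1.2 + 0.3 * interaction_strength                    # hotter with stronger interaction
    top_k_adj = max(5, int(top_k * (1 + 0.5 * (1 - agreement))))  # wider with lower agreement
    return min(1.5, temperature * temp_adj), top_k_adj             # temperature capped at 1.5

print(branch_params(0.7, 40, interaction_strength=0.5, agreement=0.5))  # roughly (0.945, 50)
print(branch_params(0.7, 40, interaction_strength=0.5, agreement=1.0))  # roughly (0.945, 40)
```

With these inputs the temperature rises from 0.7 to about 0.945 and top-k grows from 40 to 50; with full agreement (agreement=1.0), top-k stays at 40.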

Although this strategy is called "Branch," the current code appears to sample a single token after adjusting the sampling range: it raises the temperature (capped at 1.5) and widens top-k when attention agreement is low. (If anyone has more insight, further clarification would be appreciated.)
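Every branch ultimately calls _sample with an adjusted temperature, top_k, top_p, and min_p, so adjusting the sampling range means changing which tokens survive these filters. The sketch below is a generic pure-Python version of those standard filters, not Entropix's _sample: the function names are hypothetical, and the filter order (min-p, then top-k, then top-p) is an assumption that may differ from the real implementation.

```python
import math
import random

def filter_candidates(logits, temperature=1.0, top_k=0, top_p=1.0, min_p=0.0):
    """Return (token_ids, renormalised_probs) after temperature, min-p, top-k and top-p.

    temperature must be > 0; top_k=0 disables top-k.
    """
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    weights = [math.exp(x - m) for x in scaled]  # subtract the max for numerical stability
    total = sum(weights)
    probs = [w / total for w in weights]
    # Candidate ids from most to least likely.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    # min-p: drop tokens less likely than min_p times the top probability.
    p_max = probs[order[0]]
    order = [i for i in order if probs[i] >= min_p * p_max]
    # top-k: keep at most k candidates.
    if top_k > 0:
        order = order[:top_k]
    # top-p: keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cum = [], 0.0
    for i in order:
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    return kept, [probs[i] / mass for i in kept]

def sample_token(logits, rng, **kwargs):
    """Draw one token id from the filtered, renormalised distribution."""
    ids, probs = filter_candidates(logits, **kwargs)
    return rng.choices(ids, weights=probs, k=1)[0]

logits = [2.0, 1.0, 0.5, -1.0]
print(filter_candidates(logits, temperature=1.0, top_k=3, top_p=0.8)[0])  # [0, 1]
print(filter_candidates(logits, temperature=1.5, top_k=3, top_p=0.8)[0])  # [0, 1, 2]
print(sample_token(logits, random.Random(0), temperature=1.5, top_k=3, top_p=0.8))
```

With these made-up logits, raising the temperature from 1.0 to 1.5 flattens the distribution enough that a top-p of 0.8 keeps three candidates instead of two, which is the kind of widening the Branch quadrant relies on before drawing a single token.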

3. High Entropy, Low Varentropy → CoT or Insert Pause Token

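When the prediction probabilities for the next token are fairly uniform, the model is uncertain how to continue, so a clarification token (id 2564 in the code) is inserted to resolve the ambiguity. If that token was just generated, the code instead samples normally, with a temperature raised according to attention entropy.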

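# High entropy, low varentropy: insert the clarification/pause token (id 2564) once.
elif ent > 3.0 and vent < 0.1:
    # Only insert it if the last generated token is not already the pause token.
    if not torch.isin(gen_tokens[:,-1], torch.tensor([2564], device=device)).any():
        return torch.tensor([[2564]], dtype=torch.int32, device=device)
    else:
        # Otherwise sample, with the temperature raised by attention entropy and capped at 1.5.
        temp_adj = 1.3 + 0.2 * attn_ent
        return _sample(logits, temperature=min(1.5, temperature * temp_adj), top_p=top_p, top_k=top_k, min_p=min_p, generator=generator)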

Code link
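A minimal sketch of this branch's control flow for a single sequence (the original checks the last token of every row in the batch). It shows that the pause token is never inserted twice in a row by this branch; right after it, the sampler draws a normal token with a temperature raised by attention entropy and capped at 1.5. The function name and the example token id 42 are hypothetical.

```python
PAUSE_TOKEN_ID = 2564  # token id used in the snippet above

def pause_or_sample(last_token_id, temperature, attn_ent):
    """Return ("insert", token_id) or ("sample", adjusted_temperature)."""
    if last_token_id != PAUSE_TOKEN_ID:
        return ("insert", PAUSE_TOKEN_ID)  # emit the clarification/pause token
    temp_adj = 1.3 + 0.2 * attn_ent        # pause token just emitted: sample hotter
    return ("sample", min(1.5, temperature * temp_adj))  # capped at 1.5

print(pause_or_sample(last_token_id=42, temperature=0.7, attn_ent=1.0))    # ('insert', 2564)
print(pause_or_sample(last_token_id=2564, temperature=0.7, attn_ent=1.0))  # ('sample', roughly 1.05)
```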

4. High Entropy, High Varentropy → Resample

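In this case, the model is uncertain and the probabilities of the candidate tokens are both low and uneven, so no single continuation stands out. A resampling strategy is used with a higher temperature (at least 2.0) and a top-p lowered according to attention entropy, with a floor of 0.5.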

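# High entropy, high varentropy: resample with a high temperature and a reduced top-p.
elif ent > 5.0 and vent > 5.0:
    # Raise the temperature with attention varentropy.
    temp_adj = 2.0 + 0.5 * attn_vent
    # Lower top-p as attention entropy grows, but not below 0.5.
    top_p_adj = max(0.5, top_p - 0.2 * attn_ent)
    # max() makes 2.0 a floor (not a cap) on the temperature.
    return _sample(logits, temperature=max(2.0, temperature * temp_adj), top_p=top_p_adj, top_k=top_k, min_p=min_p, generator=generator)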

Code link
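The arithmetic of this branch can be reproduced in plain Python to see how the two adjustments interact. The attention metrics and other inputs below are made-up values, and the function name is hypothetical.

```python
def resample_params(temperature, top_p, attn_ent, attn_vent):
    """Reproduce the High-Entropy/High-Varentropy temperature and top-p adjustments."""
    temp_adj = 2.0 + 0.5 * attn_vent
    top_p_adj = max(0.5, top_p - 0.2 * attn_ent)       # lower top-p, but never below 0.5
    return max(2.0, temperature * temp_adj), top_p_adj  # temperature is at least 2.0

print(resample_params(0.7, 0.9, attn_ent=1.0, attn_vent=1.0))  # roughly (2.0, 0.7)
print(resample_params(0.7, 0.9, attn_ent=3.0, attn_vent=2.0))  # roughly (2.1, 0.5)
```

Note that max(2.0, ...) acts as a floor, whereas the Branch and pause branches use min(1.5, ...) as a cap: in this quadrant the temperature never drops below 2.0, and it only exceeds 2.0 when temperature * temp_adj does.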

Intermediate Cases

If none of the above conditions are met, adaptive sampling is performed: several candidate tokens are sampled (n_samples=5 in the call below), each is scored using entropy, varentropy, and attention information, and the highest-scoring candidate is returned.

else:
    # No quadrant matched: sample several candidates and keep the best-scoring one.
    return adaptive_sample(
        logits,
        metrics,
        gen_tokens,
        n_samples=5,
        base_temp=temperature,
        base_top_p=top_p,
        base_top_k=top_k,
        generator=generator
    )

Code link
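Putting the four conditions together, here is a pure-Python sketch of the dispatch using the thresholds from the snippets above. It only names the strategy rather than sampling, and the function name is hypothetical.

```python
def choose_strategy(ent, vent):
    """Return which branch of the sampler fires for the given entropy and varentropy."""
    if ent < 0.1 and vent < 0.1:
        return "argmax"
    elif ent < 5.0 and vent > 5.0:
        return "branch"
    elif ent > 3.0 and vent < 0.1:
        return "pause"
    elif ent > 5.0 and vent > 5.0:
        return "resample"
    else:
        return "adaptive"

for ent, vent in [(0.05, 0.05), (2.0, 6.0), (4.0, 0.05), (6.0, 6.0), (2.0, 1.0), (5.0, 6.0)]:
    print(ent, vent, choose_strategy(ent, vent))
```

The four quadrants do not overlap, but they leave wide gaps: moderate values such as entropy 2.0 with varentropy 1.0, and boundary values such as entropy exactly 5.0, all fall through to adaptive sampling.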


References
