import gc
import time
from collections import namedtuple
from dataclasses import dataclass, field
from functools import partial
from typing import Callable, Optional, Sequence, Union

import torch
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import Tensor
from torch.profiler import ProfilerActivity, profile, record_function
from transformers.generation import GreedySearchDecoderOnlyOutput, SampleDecoderOnlyOutput, TextStreamer


@dataclass
class InferenceParams:
    """Inference parameters that are passed to the main model in order
    to efficiently calculate and store the context during inference."""

    max_seqlen: int
    max_batch_size: int
    seqlen_offset: int = 0
    batch_size_offset: int = 0
    key_value_memory_dict: dict = field(default_factory=dict)
    lengths_per_sample: Optional[Tensor] = None

    def reset(self, max_seqlen, max_batch_size):
        self.max_seqlen = max_seqlen
        self.max_batch_size = max_batch_size
        self.seqlen_offset = 0
        if self.lengths_per_sample is not None:
            self.lengths_per_sample.zero_()
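

# Illustrative sketch (comments only, values hypothetical): one InferenceParams object is
# created per generation call and mutated in place as decoding advances, e.g.
#
#     params = InferenceParams(max_seqlen=512, max_batch_size=1)
#     # run the prompt through the model, then after each generated token:
#     params.seqlen_offset += 1
#     # reuse the same buffers for a new prompt of the same shape:
#     params.reset(max_seqlen=512, max_batch_size=1)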


def modify_logits_for_min_p_filtering(logits, min_p):
    """Keep only tokens whose probability is at least min_p times the probability of the
    most likely token; set the rest to -inf. Done in-place."""
    if min_p <= 0.0 or min_p >= 1.0:
        return
    # Threshold each row by a fraction of its highest-probability token.
    probs = torch.softmax(logits, dim=-1)
    threshold = min_p * probs.amax(dim=-1, keepdim=True)
    indices_to_remove = probs < threshold
    logits.masked_fill_(indices_to_remove, float("-inf"))
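

# Illustrative sketch: with min_p=0.1, any token whose probability is below 10% of the
# most likely token's probability is masked out, e.g.
#
#     logits = torch.tensor([[2.0, 1.0, -3.0]])
#     modify_logits_for_min_p_filtering(logits, 0.1)   # only the -3.0 entry becomes -inf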


def modify_logits_for_top_k_filtering(logits, top_k):
    """Set all logits outside the top-k to -inf. Done in-place."""
    indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
    logits.masked_fill_(indices_to_remove, float("-inf"))
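

# Illustrative sketch: keep only the k highest logits per row, e.g.
#
#     logits = torch.tensor([[0.5, 2.0, 1.0, -1.0]])
#     modify_logits_for_top_k_filtering(logits, 2)
#     # logits is now [[-inf, 2.0, 1.0, -inf]] (only the top-2 entries survive)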


def modify_logits_for_top_p_filtering(logits, top_p):
    """Set all logits outside the top-p (nucleus) set to -inf. Done in-place."""
    if top_p <= 0.0 or top_p >= 1.0:
        return
    # Sort ascending and compute the cumulative probability mass.
    sorted_logits, sorted_indices = torch.sort(logits, descending=False)
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove tokens whose cumulative probability does not exceed 1 - top_p,
    # i.e. everything outside the smallest set with total mass >= top_p.
    sorted_indices_to_remove = cumulative_probs <= (1 - top_p)
    # Scatter the mask back to the original (unsorted) token order.
    indices_to_remove = sorted_indices_to_remove.scatter(
        1, sorted_indices, sorted_indices_to_remove
    )
    logits.masked_fill_(indices_to_remove, float("-inf"))
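

# Illustrative sketch: with top_p=0.9 the smallest set of tokens whose probabilities sum to
# at least 0.9 is kept. For a row whose probabilities are roughly [0.6, 0.3, 0.07, 0.03],
# the last two tokens are removed (their cumulative mass from the bottom does not exceed
# 1 - 0.9), leaving the {0.6, 0.3} nucleus.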


def modify_logit_for_repetition_penalty(logits, prev_output_tokens, repetition_penalty=1.0):
    """Apply repetition penalty. See https://arxiv.org/abs/1909.05858
    logits: (batch_size, vocab_size)
    prev_output_tokens: (batch_size, seq_len)
    """
    if repetition_penalty == 1.0:
        return logits
    score = torch.gather(logits, 1, prev_output_tokens)
    # Penalize tokens that were already generated: divide positive scores, multiply negative ones.
    score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
    logits.scatter_(1, prev_output_tokens, score)
    return logits
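

# Illustrative sketch: with repetition_penalty > 1, logits of tokens that already appear in
# prev_output_tokens are pushed down, e.g.
#
#     logits = torch.tensor([[2.0, -1.0, 0.5]])
#     prev = torch.tensor([[0, 1]])
#     modify_logit_for_repetition_penalty(logits, prev, repetition_penalty=2.0)
#     # logits is now [[1.0, -2.0, 0.5]]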


def sample(logits, top_k=1, top_p=0.0, min_p=0.0, temperature=1.0):
    """Sample a token id from logits, optionally restricted by top-k, top-p, or min-p.
    Arguments:
        logits: Tensor of shape (batch_size, vocab_size)
    """
    if top_k == 1:  # Short-circuit for greedy decoding
        return logits.argmax(dim=-1)
    else:
        if top_p > 0.0:
            assert top_p <= 1.0, "top-p should be in (0, 1]."
        if top_k > 0:
            top_k = min(top_k, logits.size(-1))  # Safety check
            logits_top, indices = torch.topk(logits, top_k, dim=-1)
            if temperature != 1.0:
                logits_top /= temperature
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return indices[
                torch.arange(indices.shape[0], device=indices.device),
                torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1),
            ]
        else:
            if min_p > 0.0:
                # Min-p filtering: keep tokens whose probability is at least min_p times
                # the probability of the most likely token.
                logits_top = logits.clone()
                modify_logits_for_min_p_filtering(logits_top, min_p)
                if temperature != 1.0:
                    logits_top /= temperature
                return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(dim=-1)
            # Clone so that modifying for top-p does not change the original logits.
            logits_top = logits / temperature if temperature != 1.0 else logits.clone()
            modify_logits_for_top_p_filtering(logits_top, top_p)
            return torch.multinomial(torch.softmax(logits_top, dim=-1), num_samples=1).squeeze(
                dim=-1
            )
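

# Illustrative usage sketch (random logits, vocabulary size hypothetical):
#
#     logits = torch.randn(2, 32000)
#     greedy = sample(logits)                                   # top_k=1 -> argmax
#     nucleus = sample(logits, top_k=0, top_p=0.9, temperature=0.8)
#     topk_topp = sample(logits, top_k=50, top_p=0.95)          # top-k first, then top-p
#     # each call returns a (batch_size,) tensor of token ids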


@torch.inference_mode()
def decode(
    input_ids,
    model,
    max_length,
    top_k=1,
    top_p=0.0,
    min_p=0.0,
    temperature=1.0,
    repetition_penalty=1.0,
    eos_token_id=None,
    teacher_outputs=None,
    vocab_size=None,
    cg=False,
    enable_timing=False,
    output_scores=False,
    streamer: Optional[TextStreamer] = None,
):
    """Decoding, either greedy or with top-k or top-p sampling.
    If top-k = 0, don't limit the number of candidates (pure sampling).
    Top-k and top-p can be used together. If top_k > 0 and top_p > 0, then top-k is applied
    first, then top-p.
    We assume that all sequences in the same batch have the same length.

    Arguments:
        input_ids: (batch, seq_len)
        max_length: int
        teacher_outputs (optional): (batch, seq_len). If provided, instead of sampling from the
            logits, the next token is taken from the teacher_outputs. Useful for testing.
    Returns: GreedySearchDecoderOnlyOutput or SampleDecoderOnlyOutput, with the following fields:
        sequences: (batch, max_length)
        scores: tuple of tensors of shape (batch, vocab_size)
    """
    if streamer is not None:
        streamer.put(input_ids.cpu())

    batch_size, seqlen_og = input_ids.shape
    teacher_output_len = teacher_outputs.shape[1] if teacher_outputs is not None else 0
    if cg:
        if not hasattr(model, "_decoding_cache"):
            model._decoding_cache = None
        model._decoding_cache = update_graph_cache(
            model,
            model._decoding_cache,
            batch_size,
            seqlen_og,
            max_length,
        )
        inference_params = model._decoding_cache.inference_params
        inference_params.reset(max_length, batch_size)
    else:
        inference_params = InferenceParams(max_seqlen=max_length, max_batch_size=batch_size)

    def get_logits(input_ids, inference_params):
        decoding = inference_params.seqlen_offset > 0
        if decoding:
            # Each decoding step feeds a single token, so the position id is just the
            # current offset into the sequence.
            position_ids = torch.full(
                (batch_size, 1),
                inference_params.seqlen_offset,
                dtype=torch.long,
                device=input_ids.device,
            )
        else:
            position_ids = None
        if not cg or not decoding:
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=1,
            ).logits.squeeze(dim=1)
        else:
            # Replay the captured CUDA graph for single-token decoding steps.
            logits = model._decoding_cache.run(
                input_ids, position_ids, inference_params.seqlen_offset
            ).squeeze(dim=1)
        return logits[..., :vocab_size] if vocab_size is not None else logits

    def sample_tokens(logits, inference_params):
        if teacher_outputs is None or teacher_output_len <= inference_params.seqlen_offset:
            token = sample(logits, top_k=top_k, top_p=top_p, min_p=min_p, temperature=temperature)
        else:
            token = teacher_outputs[:, inference_params.seqlen_offset]
        # Add back a sequence dimension of 1: (batch,) -> (batch, 1).
        return token.unsqueeze(1)

    def should_stop(current_token, inference_params):
        if inference_params.seqlen_offset == 0:
            return False
        if eos_token_id is not None and (current_token == eos_token_id).all():
            return True
        if inference_params.seqlen_offset >= max_length - 1:
            return True
        return False

    start = torch.cuda.Event(enable_timing=enable_timing)
    end = torch.cuda.Event(enable_timing=enable_timing)

    if enable_timing:
        start.record()
    scores, sequences = [], [input_ids]
    sequences_cat = input_ids
    while not should_stop(sequences[-1], inference_params):
        logits = get_logits(sequences[-1], inference_params)
        if output_scores:
            scores.append(logits.clone())
        inference_params.seqlen_offset += sequences[-1].shape[1]
        if repetition_penalty == 1.0:
            sampled_tokens = sample_tokens(logits, inference_params)
        else:
            logits = modify_logit_for_repetition_penalty(
                logits, sequences_cat, repetition_penalty
            )
            sampled_tokens = sample_tokens(logits, inference_params)
            sequences_cat = torch.cat([sequences_cat, sampled_tokens], dim=1)
        sequences.append(sampled_tokens)
        if streamer is not None:
            streamer.put(sampled_tokens.cpu())
    if streamer is not None:
        streamer.end()
    if enable_timing:
        end.record()
        torch.cuda.synchronize()
        print(f"Prompt processing + decoding time: {(start.elapsed_time(end)):.0f}ms")
    output_cls = GreedySearchDecoderOnlyOutput if top_k == 1 else SampleDecoderOnlyOutput
    return output_cls(sequences=torch.cat(sequences, dim=1), scores=tuple(scores))
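

# Illustrative usage sketch (the tokenizer/model names are hypothetical, not part of this module):
#
#     tokens = tokenizer("Hello", return_tensors="pt").input_ids.to("cuda")
#     out = decode(tokens, model, max_length=128, top_k=0, top_p=0.9, temperature=0.8, cg=True)
#     print(tokenizer.decode(out.sequences[0]))
#
# With cg=True the per-token steps are replayed from a captured CUDA graph (see
# update_graph_cache / capture_graph below); the prompt pass itself always runs eagerly.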


class GenerationMixin:
    def allocate_inference_cache(self, batch_size, max_seqlen, dtype=None, **kwargs):
        raise NotImplementedError

    def generate(
        self,
        input_ids,
        max_length,
        top_k=1,
        top_p=0.0,
        min_p=0.0,
        temperature=1.0,
        return_dict_in_generate=False,
        output_scores=False,
        **kwargs,
    ):
        output = decode(
            input_ids,
            self,
            max_length,
            top_k=top_k,
            top_p=top_p,
            min_p=min_p,
            temperature=temperature,
            output_scores=output_scores,
            **kwargs,
        )
        if not output_scores:
            output.scores = None
        return output if return_dict_in_generate else output.sequences
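

# Illustrative usage sketch (assumes a model class that mixes in GenerationMixin and
# implements allocate_inference_cache; names are hypothetical):
#
#     out = model.generate(
#         input_ids, max_length=256, top_k=0, top_p=0.9, temperature=0.8,
#         return_dict_in_generate=True, output_scores=True, cg=True,
#     )
#     # out.sequences: (batch, <=max_length) token ids; out.scores: per-step logits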


@dataclass
class DecodingCGCache:
    max_batch_size: int = 0
    max_seqlen: int = 0
    device = None
    dtype = None
    callables: dict = field(default_factory=dict)
    mempool = None
    inference_params: Optional[InferenceParams] = None
    run: Optional[Callable] = None


@torch.inference_mode()
def update_graph_cache(
    model,
    cache,
    batch_size,
    seqlen_og,
    max_seqlen,
    decoding_seqlens=(1,),
    dtype=None,
    n_warmups=2,
):
    if cache is None:
        cache = DecodingCGCache()
    param_example = next(iter(model.parameters()))
    device = param_example.device
    if dtype is None:
        dtype = param_example.dtype
    if (
        (device, dtype) != (cache.device, cache.dtype)
        or batch_size > cache.max_batch_size
        or max_seqlen > cache.max_seqlen
    ):  # Invalidate and rebuild the cache
        cache.callables = {}
        cache.mempool = None
        cache.inference_params = None
        gc.collect()
        cache.device, cache.dtype = device, dtype
        cache.max_batch_size, cache.max_seqlen = batch_size, max_seqlen
        assert hasattr(model, "allocate_inference_cache"), (
            "CUDA graph decoding requires that the model has a method allocate_inference_cache"
        )
        inf_cache = model.allocate_inference_cache(batch_size, max_seqlen, dtype)
        lengths_per_sample = torch.full((batch_size,), seqlen_og, dtype=torch.int32, device=device)
        cache.inference_params = InferenceParams(
            max_seqlen=max_seqlen,
            max_batch_size=batch_size,
            seqlen_offset=seqlen_og,
            key_value_memory_dict=inf_cache,
            lengths_per_sample=lengths_per_sample,
        )
        cache.mempool = torch.cuda.graphs.graph_pool_handle()
    for decoding_seqlen in decoding_seqlens:
        if (batch_size, decoding_seqlen) not in cache.callables:
            cache.callables[batch_size, decoding_seqlen] = capture_graph(
                model,
                cache.inference_params,
                batch_size,
                max_seqlen,
                decoding_seqlen=decoding_seqlen,
                mempool=cache.mempool,
                n_warmups=n_warmups,
            )

    def dispatch(input_ids, position_ids, seqlen):
        batch_size, decoding_seqlen = input_ids.shape[:2]
        return cache.callables[batch_size, decoding_seqlen](input_ids, position_ids, seqlen)

    cache.run = dispatch
    cache.inference_params.seqlen_offset = 0  # Reset before the next generation call
    return cache


def capture_graph(
    model, inference_params, batch_size, max_seqlen, decoding_seqlen=1, mempool=None, n_warmups=2
):
    device = next(iter(model.parameters())).device
    input_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    position_ids = torch.full((batch_size, decoding_seqlen), 0, dtype=torch.long, device=device)
    seqlen_offset_og = inference_params.seqlen_offset
    inference_params.seqlen_offset = max_seqlen - decoding_seqlen
    inference_params.lengths_per_sample[:] = inference_params.seqlen_offset

    # Warm up on a side stream before capture so one-time initialization work does not
    # get baked into the graph.
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        for _ in range(n_warmups):
            logits = model(
                input_ids,
                position_ids=position_ids,
                inference_params=inference_params,
                num_last_tokens=decoding_seqlen,
            ).logits
        s.synchronize()
        if torch.distributed.is_initialized():
            torch.distributed.barrier()
    torch.cuda.current_stream().wait_stream(s)

    # Capture one decoding step; replaying the graph reuses these static input/output buffers.
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=mempool):
        logits = model(
            input_ids,
            position_ids=position_ids,
            inference_params=inference_params,
            num_last_tokens=decoding_seqlen,
        ).logits

    def run(new_input_ids, new_position_ids, seqlen):
        # Copy new inputs into the captured buffers, then replay the graph.
        inference_params.lengths_per_sample[:] = seqlen
        input_ids.copy_(new_input_ids)
        position_ids.copy_(new_position_ids)
        graph.replay()
        return logits.clone()

    inference_params.seqlen_offset = seqlen_offset_og
    return run