# Source: LLAMA-3-From-Scratch / inference.py
# Uploaded by DrNerd ("Upload 7 files", commit 48b48f8, verified)
# ==============================================================================
# Inference Script
# ==============================================================================
# --- Necessary Imports ---
import torch
import torch.nn as nn
from dataclasses import dataclass, field
import math
import torch.nn.functional as F
from transformers import AutoTokenizer
import os
import glob
import time
import datetime
import traceback
import dataclasses # Make sure this is imported
# --- Model Configuration ---
# IMPORTANT: This definition MUST exactly match the one used during training
# when the checkpoint was saved.
@dataclass
class ModelArgs:
    """Hyper-parameters of the Llama-style transformer defined in this file.

    The defaults are the configuration the original author labelled
    "~221M Config used for training step_1200". ``head_dim`` is not accepted
    by ``__init__``; it is derived from ``hidden_size`` and
    ``num_attention_heads`` in ``__post_init__``.

    Raises:
        ValueError: if ``hidden_size`` is not a multiple of
            ``num_attention_heads``, or ``num_attention_heads`` is not a
            multiple of ``num_key_value_heads``.
    """

    # --- ~221M Config used for training step_1200 ---
    hidden_size: int = 768
    num_hidden_layers: int = 12
    num_attention_heads: int = 12
    num_key_value_heads: int = 12
    intermediate_size: int = 2048
    vocab_size: int = 128000
    rms_norm_eps: float = 1e-5
    rope_theta: float = 500000.0
    max_position_embeddings: int = 4096
    # Derived in __post_init__, never passed by the caller.
    head_dim: int = field(init=False)
    # Ensure this matches the value used when saving the checkpoint.
    add_recency_bias: bool = False

    def __post_init__(self):
        """Derive ``head_dim`` and validate the head-count divisibility rules."""
        n_heads = self.num_attention_heads
        self.head_dim = self.hidden_size // n_heads
        if self.hidden_size % n_heads:
            raise ValueError("hidden_size % num_attention_heads != 0")
        if n_heads % self.num_key_value_heads:
            raise ValueError("num_attention_heads % num_key_value_heads != 0")
# --- Model Components (RMSNorm, RoPE funcs, Attention, FeedForward, TransformerBlock, Llama) ---
# The definitions below are copied from the model_architecture.py script.
# Keep them identical to the versions used for training so that the
# checkpoint weights load into the same architecture.
#
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        dim: size of the last dimension; one gain value is learned per feature.
        eps: constant added to the mean square before the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        """Scale ``x`` by the reciprocal RMS of its last dimension."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps)

    def forward(self, x):
        """Normalise in float32, cast back to the input dtype, then apply the gain."""
        input_dtype = x.dtype
        normalised = self._norm(x.float()).to(input_dtype)
        return normalised * self.weight
def precompute_theta_pos_frequencies(head_dim: int, seq_len: int, device: str | torch.device, theta: float = 10000.0):
if head_dim % 2 != 0: raise ValueError("head_dim must be even for RoPE")
theta_indices = torch.arange(0, head_dim, 2).float(); theta_freqs = 1.0 / (theta**(theta_indices / head_dim))
target_device = torch.device(device) if isinstance(device, str) else device; theta_freqs = theta_freqs.to(target_device)
positions = torch.arange(seq_len, device=target_device).float(); freqs = torch.outer(positions, theta_freqs).float(); return freqs, positions
def apply_rotary_embeddings(x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor):
    """Rotate ``x`` by position-dependent angles (rotary embedding, split-half layout).

    The last dimension of ``x`` is split into a first and a second half, and
    each (first[i], second[i]) pair is rotated by the angle looked up for the
    token's position.

    Args:
        x: 4-D tensor unpacked as ``(bsz, seq_len, n_heads, head_dim)``.
        freqs_cis_full: angle table indexed by position along dim 0; its last
            dimension must broadcast against ``head_dim // 2``.
        positions: 1-D tensor with one position index per sequence step.
            Indices at or beyond the table length are clamped to the last
            table row rather than raising.

    Returns:
        Tensor with the same shape and dtype as ``x``.

    Raises:
        ValueError: if ``x`` is not 4-D.
    """
    if x.dim() != 4:
        raise ValueError(f"x must be 4-D (bsz, seq_len, n_heads, head_dim), got {x.dim()}-D")

    positions = positions.long()
    max_pos = freqs_cis_full.shape[0]
    # Clamp unconditionally: it is a no-op when every index is in range, and
    # unlike a `torch.max(positions)` pre-check it neither fails on an empty
    # sequence nor forces a host/device synchronisation on every call.
    positions = torch.clamp(positions, max=max_pos - 1)

    # (seq_len, half) -> (1, seq_len, 1, half) so it broadcasts over batch and heads.
    angles = freqs_cis_full[positions].unsqueeze(0).unsqueeze(2)
    cos_angles = torch.cos(angles).type_as(x)
    sin_angles = torch.sin(angles).type_as(x)

    half = x.shape[-1] // 2
    first_half = x[..., :half]
    second_half = x[..., half:]

    rotated_first = first_half * cos_angles - second_half * sin_angles
    rotated_second = first_half * sin_angles + second_half * cos_angles
    return torch.cat([rotated_first, rotated_second], dim=-1).type_as(x)
class Attention(nn.Module):
    """Multi-head self-attention with grouped key/value heads, RoPE and a KV cache.

    Queries use ``num_attention_heads`` heads; keys and values use
    ``num_key_value_heads`` heads that are repeated to match the query heads.
    All projections are bias-free.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.num_heads = args.num_attention_heads
        self.num_kv_heads = args.num_key_value_heads
        self.head_dim = args.head_dim
        # How many query heads share each key/value head.
        self.repeats = self.num_heads // self.num_kv_heads
        self.wq = nn.Linear(args.hidden_size, args.num_attention_heads * args.head_dim, bias=False)
        self.wk = nn.Linear(args.hidden_size, args.num_key_value_heads * args.head_dim, bias=False)
        self.wv = nn.Linear(args.hidden_size, args.num_key_value_heads * args.head_dim, bias=False)
        self.wo = nn.Linear(args.num_attention_heads * args.head_dim, args.hidden_size, bias=False)

    def _repeat_kv(self, x: torch.Tensor, n_rep: int) -> torch.Tensor:
        """Repeat each KV head ``n_rep`` times along the head dimension.

        ``x`` is ``(bsz, n_kv_heads, seqlen, head_dim)``; the result is
        ``(bsz, n_kv_heads * n_rep, seqlen, head_dim)``.
        """
        bsz, n_kv_heads, seqlen, head_dim = x.shape
        if n_rep == 1:
            return x
        expanded = x[:, :, None, :, :].expand(bsz, n_kv_heads, n_rep, seqlen, head_dim)
        return expanded.reshape(bsz, n_kv_heads * n_rep, seqlen, head_dim)

    def _create_recency_bias(self, seqlen, full_seqlen, device, dtype, bias_strength=0.1, decay_rate=0.9):
        """Build an additive ``(1, 1, seqlen, full_seqlen)`` bias over non-future keys.

        The ``seqlen`` queries are the LAST ``seqlen`` positions of the
        ``full_seqlen`` keys, so their positions are offset by the number of
        cached keys. Without that offset the bias would be computed as if the
        new tokens started at position 0 whenever a KV cache is in use. With
        no cache the offset is 0 and the result equals the un-cached formula.

        NOTE(review): the value is ``bias_strength * decay_rate ** (-distance)``,
        which for ``decay_rate < 1`` grows with distance. The formula is kept
        as-is because it has to match whatever the checkpoint was trained
        with; confirm that this is the intended direction.
        """
        bias = torch.zeros((1, 1, seqlen, full_seqlen), device=device, dtype=dtype)
        key_positions = torch.arange(full_seqlen, device=device)
        query_positions = torch.arange(seqlen, device=device) + (full_seqlen - seqlen)
        rel_pos = query_positions.unsqueeze(1) - key_positions.unsqueeze(0)
        not_future = rel_pos >= 0
        decaying_bias = bias_strength * (decay_rate ** (-rel_pos[not_future]))
        bias[:, :, not_future] = decaying_bias.type_as(bias)
        return bias

    @staticmethod
    def _add_mask(scores: torch.Tensor, mask: torch.Tensor, seqlen: int, full_seqlen: int) -> torch.Tensor:
        """Add an additive attention mask to ``scores``, adapting its shape if needed.

        * A mask whose last two dims are ``(seqlen, full_seqlen)`` is added as is.
        * A ``(seqlen, seqlen)`` mask combined with cached keys is left-padded
          with zeros: every cached key precedes every new query, so it is
          fully visible.
        * Otherwise the trailing ``seqlen`` rows / leading ``full_seqlen``
          columns are tried. If that still cannot be broadcast, a
          ``RuntimeWarning`` is emitted and the scores are returned unmasked
          (best effort, but no longer silent).
        """
        if mask.shape[-2:] == (seqlen, full_seqlen):
            return scores + mask.float()

        past_len = full_seqlen - seqlen
        if past_len > 0 and mask.shape[-2:] == (seqlen, seqlen):
            visible_past = torch.zeros(
                tuple(mask.shape[:-1]) + (past_len,), device=mask.device, dtype=mask.dtype
            )
            return scores + torch.cat((visible_past, mask), dim=-1).float()

        try:
            return scores + mask[:, :, -seqlen:, :full_seqlen].float()
        except (IndexError, RuntimeError) as exc:
            import warnings

            warnings.warn(
                f"Attention mask of shape {tuple(mask.shape)} cannot be applied to scores of "
                f"shape {tuple(scores.shape)}; continuing WITHOUT a mask ({exc})",
                RuntimeWarning,
                stacklevel=3,
            )
            return scores

    def forward(self, x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor, mask: torch.Tensor | None = None, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Attend over the new tokens plus any cached keys/values.

        Args:
            x: input of shape ``(bsz, seqlen, hidden)``.
            freqs_cis_full: RoPE angle table passed to ``apply_rotary_embeddings``.
            positions: position index of each of the ``seqlen`` new tokens.
            mask: optional additive mask (see ``_add_mask`` for accepted shapes).
            cache: optional ``(keys, values)`` pair, each
                ``(bsz, n_kv_heads, past_len, head_dim)``, prepended along dim 2.

        Returns:
            ``(output, updated_cache)`` where ``output`` is
            ``(bsz, seqlen, hidden)`` and ``updated_cache`` holds the detached,
            un-repeated keys and values including the new tokens.
        """
        bsz, seqlen, _ = x.shape
        xq = self.wq(x).view(bsz, seqlen, self.num_heads, self.head_dim)
        xk = self.wk(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim)
        xv = self.wv(x).view(bsz, seqlen, self.num_kv_heads, self.head_dim)

        # RoPE is applied to queries and keys only, before caching.
        xq = apply_rotary_embeddings(xq, freqs_cis_full, positions)
        xk = apply_rotary_embeddings(xk, freqs_cis_full, positions)

        # (bsz, seqlen, heads, head_dim) -> (bsz, heads, seqlen, head_dim)
        xq = xq.transpose(1, 2)
        xk = xk.transpose(1, 2)
        xv = xv.transpose(1, 2)

        if cache is not None:
            cache_k, cache_v = cache
            keys = torch.cat((cache_k.to(xk.device), xk), dim=2)
            values = torch.cat((cache_v.to(xv.device), xv), dim=2)
        else:
            keys, values = xk, xv
        updated_cache = (keys.detach(), values.detach())

        keys_repeated = self._repeat_kv(keys, self.repeats)
        values_repeated = self._repeat_kv(values, self.repeats)
        full_seqlen = keys_repeated.shape[-2]

        # Scores are computed in float32 regardless of the model dtype.
        scores = torch.matmul(xq.float(), keys_repeated.transpose(-2, -1).float()) / math.sqrt(self.head_dim)

        if self.args.add_recency_bias:
            scores = scores + self._create_recency_bias(
                seqlen, full_seqlen, device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            scores = self._add_mask(scores, mask, seqlen, full_seqlen)

        scores = nn.functional.softmax(scores, dim=-1).type_as(xq)
        output = torch.matmul(scores, values_repeated)
        output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
        return self.wo(output), updated_cache
class FeedForward(nn.Module):
    """Gated feed-forward block: ``w2(silu(w1(x)) * w3(x))`` with bias-free projections."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        model_dim = args.hidden_size
        inner_dim = args.intermediate_size
        # Creation order (w1, w3, w2) is kept so parameter registration is unchanged.
        self.w1 = nn.Linear(model_dim, inner_dim, bias=False)
        self.w3 = nn.Linear(model_dim, inner_dim, bias=False)
        self.w2 = nn.Linear(inner_dim, model_dim, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project up through the SiLU gate and the linear branch, multiply, project down."""
        gate = nn.functional.silu(self.w1(x))
        linear_branch = self.w3(x)
        return self.w2(gate * linear_branch)
class TransformerBlock(nn.Module):
    """One pre-norm decoder layer: attention and feed-forward, each with a residual add."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        eps = args.rms_norm_eps
        self.attention_norm = RMSNorm(args.hidden_size, eps=eps)
        self.attention = Attention(args)
        self.ffn_norm = RMSNorm(args.hidden_size, eps=eps)
        self.feed_forward = FeedForward(args)

    def forward(self, x: torch.Tensor, freqs_cis_full: torch.Tensor, positions: torch.Tensor, mask: torch.Tensor | None = None, cache: tuple[torch.Tensor, torch.Tensor] | None = None) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """Run the layer and return ``(hidden_states, updated_kv_cache)``.

        ``freqs_cis_full``, ``positions``, ``mask`` and ``cache`` are passed
        straight through to the attention sub-module.
        """
        attn_out, cache = self.attention(self.attention_norm(x), freqs_cis_full, positions, mask, cache)
        h = x + attn_out
        ffn_out = self.feed_forward(self.ffn_norm(h))
        return h + ffn_out, cache
class Llama(nn.Module):
    """Decoder-only transformer whose output projection reuses the embedding matrix."""

    def __init__(self, args: ModelArgs):
        super().__init__()
        self.args = args
        self.tok_embeddings = nn.Embedding(args.vocab_size, args.hidden_size)
        self.layers = nn.ModuleList()
        for _ in range(args.num_hidden_layers):
            self.layers.append(TransformerBlock(args))
        self.norm = RMSNorm(args.hidden_size, eps=args.rms_norm_eps)
        self.tok_embeddings.weight.requires_grad = True

        # RoPE angle table, built on CPU. Registered as non-persistent, so it
        # is not part of the state dict.
        rope_table, _ = precompute_theta_pos_frequencies(
            args.head_dim,
            args.max_position_embeddings,
            device='cpu',
            theta=args.rope_theta,
        )
        self.register_buffer("freqs_cis", rope_table, persistent=False)

    def forward(self, tokens: torch.Tensor, positions: torch.Tensor):
        """Un-cached forward pass.

        Args:
            tokens: ``(bsz, seqlen)`` token ids.
            positions: position index for each of the ``seqlen`` tokens.

        Returns:
            Logits of shape ``(bsz, seqlen, vocab_size)``.
        """
        _, seqlen = tokens.shape
        h = self.tok_embeddings(tokens)
        freqs_cis_full = self.freqs_cis.to(h.device)

        # A causal mask is only needed when there is more than one token.
        mask = None
        if seqlen > 1:
            neg_inf = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=h.device)
            mask = torch.triu(neg_inf, diagonal=1).type_as(h)

        positions = positions.to(h.device)
        for block in self.layers:
            # Pass cache=None for non-cached forward
            h, _ = block(h, freqs_cis_full, positions, mask, cache=None)

        # Use tied weights: the embedding matrix doubles as the output projection.
        return F.linear(self.norm(h), self.tok_embeddings.weight)
def _sample_next_token(logits: torch.Tensor, temperature: float, top_k: int | None, top_p: float | None) -> torch.Tensor:
    """Choose the next token id from ``(batch, vocab)`` logits.

    ``temperature == 0`` selects the arg-max. Otherwise the logits are divided
    by ``temperature``, optionally restricted to the ``top_k`` largest, then
    optionally restricted to the smallest high-probability set whose
    cumulative probability exceeds ``top_p``, and one id is sampled.

    Returns:
        ``(batch, 1)`` tensor of token ids.
    """
    if temperature == 0:
        return torch.argmax(logits, dim=-1, keepdim=True)

    # Division creates a new tensor, so the caller's logits are never mutated.
    logits = logits / temperature

    if top_k is not None and top_k > 0:
        top_values, _ = torch.topk(logits, min(top_k, logits.size(-1)))
        logits[logits < top_values[:, [-1]]] = float('-inf')

    if top_p is not None and 0.0 < top_p < 1.0:
        sorted_probs, sorted_idx = torch.sort(F.softmax(logits, dim=-1), descending=True)
        drop_sorted = torch.cumsum(sorted_probs, dim=-1) > top_p
        # Shift right FIRST so the token that crosses the threshold is kept,
        # THEN force-keep the most likely token. Doing it the other way round
        # always kept the top two tokens, even when the first one alone
        # already exceeded top_p.
        drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
        drop_sorted[..., 0] = False
        # Map the decision from sorted order back to vocabulary order.
        drop = drop_sorted.scatter(1, sorted_idx, drop_sorted)
        logits[drop] = float('-inf')

    probs = F.softmax(logits, dim=-1)
    return torch.multinomial(probs, num_samples=1)


@torch.no_grad()
def generate(model: Llama, tokenizer: AutoTokenizer, prompt: str, max_new_tokens: int, temperature: float = 1.0, top_k: int | None = None, top_p: float | None = None):
    """Autoregressively generate text from ``prompt`` using a per-layer KV cache.

    The prompt is processed in one pass with a causal mask; every later step
    feeds only the newly sampled token and reuses the cached keys/values.
    Tokens are printed as they are produced. Generation stops after
    ``max_new_tokens`` tokens or when the tokenizer's EOS id is sampled (the
    EOS token itself is not added to the output).

    Args:
        model: model to sample from; it is switched to eval mode.
        tokenizer: tokenizer providing ``encode``, ``decode`` and ``eos_token_id``.
        prompt: text to condition on.
        max_new_tokens: upper bound on the number of generated tokens.
        temperature: ``0`` for greedy decoding, otherwise the softmax temperature.
        top_k: keep only the ``top_k`` most likely tokens (ignored if ``None`` or <= 0).
        top_p: nucleus threshold (ignored unless ``0 < top_p < 1``).

    Returns:
        Prompt plus generated tokens decoded with ``skip_special_tokens=False``.

    Raises:
        ValueError: if the prompt encodes to no tokens while tokens are requested.

    NOTE(review): the total length is not checked against the size of
    ``model.freqs_cis`` here; confirm how positions past that table should be
    handled for long generations.
    """
    model.eval()
    try:
        first_param = next(model.parameters())
        model_device = first_param.device
        model_dtype = first_param.dtype
    except StopIteration:
        print("Warning: Model has no parameters. Assuming CPU and float32.")
        model_device = torch.device("cpu")
        model_dtype = torch.float32

    prompt_ids = tokenizer.encode(prompt, add_special_tokens=True)
    if not prompt_ids and max_new_tokens > 0:
        # Without at least one token there is no last position to read logits from.
        raise ValueError("prompt encoded to zero tokens; nothing to condition generation on")
    tokens = torch.tensor([prompt_ids], dtype=torch.long, device=model_device)

    # One (keys, values) pair per layer, starting with zero cached positions.
    kv_shape = (1, model.args.num_key_value_heads, 0, model.args.head_dim)
    cache = [
        (
            torch.zeros(kv_shape, device=model_device, dtype=model_dtype),
            torch.zeros(kv_shape, device=model_device, dtype=model_dtype),
        )
        for _ in range(model.args.num_hidden_layers)
    ]

    generated_token_ids = []
    current_tokens = tokens
    print(f"Generating {max_new_tokens} tokens from prompt: '{prompt}'")
    print("Output: ", end='')
    full_freqs_cis = model.freqs_cis.to(model_device)

    for step in range(max_new_tokens):
        current_seq_len = current_tokens.shape[1]
        # Number of positions already in the cache = position of the first new token.
        start_pos = cache[0][0].shape[2]
        positions = torch.arange(start_pos, start_pos + current_seq_len, device=model_device)

        # Only the multi-token prompt pass needs a causal mask; a single new
        # token may attend to every cached position.
        current_mask = None
        if step == 0 and current_seq_len > 1:
            current_mask = torch.full(
                (1, 1, current_seq_len, current_seq_len), float("-inf"), device=model_device
            )
            current_mask = torch.triu(current_mask, diagonal=1).type(model_dtype)

        h = model.tok_embeddings(current_tokens)
        updated_cache_list = []
        for layer_idx, layer in enumerate(model.layers):
            h, updated_layer_cache = layer(h, full_freqs_cis, positions, current_mask, cache[layer_idx])
            updated_cache_list.append(updated_layer_cache)
        cache = updated_cache_list

        h = model.norm(h)
        logits = F.linear(h, model.tok_embeddings.weight)
        next_token_id = _sample_next_token(logits[:, -1, :], temperature, top_k, top_p)

        next_token_id_item = next_token_id.item()
        if tokenizer.eos_token_id is not None and next_token_id_item == tokenizer.eos_token_id:
            print("\n[EOS token reached]")
            break

        generated_token_ids.append(next_token_id_item)
        current_tokens = next_token_id.clone()
        print(tokenizer.decode([next_token_id_item]), end='', flush=True)

    print("\n--- Generation Complete ---")
    final_token_ids = prompt_ids + generated_token_ids
    full_generated_text = tokenizer.decode(final_token_ids, skip_special_tokens=False)
    print(f"\nFull generated text:\n{full_generated_text}")
    return full_generated_text
# --- End Placeholders ---
# --- Main Inference Execution ---
if __name__ == "__main__":
    # --- Configuration for Inference ---
    # Checkpoint "step_800.pt" in the current working directory. The path is
    # built with os.path.join instead of the literal r".\step_800.pt" because
    # a backslash is a path separator only on Windows; elsewhere it would be
    # treated as part of the file name and the checkpoint would not be found.
    raw_checkpoint_path = os.path.join(os.curdir, "step_800.pt")
    # --- Normalize the path ---
    checkpoint_path = os.path.normpath(raw_checkpoint_path)

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("\n--- Inference Setup ---")
    print(f"Using device: {device}")
    print(f"Attempting to load checkpoint: {checkpoint_path}")

    # --- Load Checkpoint and Model Args ---
    # No fallback search: an exact path is specified, so a missing file is fatal.
    # SystemExit is raised directly because the `exit` helper is injected by
    # the `site` module and is not guaranteed to exist in every interpreter.
    if not os.path.exists(checkpoint_path):
        raise SystemExit(f"Error: Checkpoint file not found at the specified path: {checkpoint_path}")

    try:
        # Load checkpoint to CPU first.
        # SECURITY: weights_only=False unpickles arbitrary Python objects
        # (needed here because the checkpoint may store a ModelArgs
        # instance). Only load checkpoints from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)

        # The config may be stored under either key, as a dict or as a dataclass.
        saved_args_data = checkpoint.get('model_args', checkpoint.get('model_args_dict'))
        if not saved_args_data:
            raise SystemExit("Error: model_args not found in checkpoint.")
        if isinstance(saved_args_data, dict):
            saved_args_dict = saved_args_data
        else:
            saved_args_dict = dataclasses.asdict(saved_args_data)

        # Drop anything ModelArgs.__init__ does not accept (e.g. the derived head_dim).
        init_field_names = {f.name for f in dataclasses.fields(ModelArgs) if f.init}
        filtered_args_dict = {k: v for k, v in saved_args_dict.items() if k in init_field_names}
        config_inf = ModelArgs(**filtered_args_dict)
        print(f"Loaded model config from checkpoint: {config_inf}")

        # --- Instantiate Model ---
        model_inf = Llama(config_inf)  # Instantiate on CPU
        print("Model instantiated on CPU.")

        # --- Load Weights ---
        model_inf.load_state_dict(checkpoint['model_state_dict'])
        print("Model weights loaded.")
        model_inf.to(device)  # Move model to target device
        print(f"Model moved to {device}.")

        # --- Prepare for Inference ---
        model_inf.eval()
        if device.type == 'cuda':
            try:
                model_inf = model_inf.half()
                print("Converted loaded model to float16 for inference.")
            except Exception as e:
                print(f"Could not convert model to float16: {e}")
    except Exception as e:
        # SystemExit is not an Exception subclass, so the explicit exits above
        # pass through this handler untouched.
        raise SystemExit(f"Error loading checkpoint or instantiating model: {e}")

    # --- Load Tokenizer ---
    tokenizer_name_inf = "deepseek-ai/DeepSeek-R1"
    print(f"Loading tokenizer: {tokenizer_name_inf}")
    try:
        # SECURITY: trust_remote_code=True allows code shipped with the
        # tokenizer repository to run locally; keep it only for trusted repos.
        tokenizer_inf = AutoTokenizer.from_pretrained(tokenizer_name_inf, trust_remote_code=True)
        if tokenizer_inf.pad_token is None:
            if tokenizer_inf.eos_token:
                tokenizer_inf.pad_token = tokenizer_inf.eos_token
            else:
                tokenizer_inf.add_special_tokens({'pad_token': '[PAD]'})
        print("Tokenizer loaded.")
    except Exception as e:
        raise SystemExit(f"Error loading tokenizer: {e}")

    # --- Run Generation ---
    print(f"\n--- Running Generation with Loaded Checkpoint ({os.path.basename(checkpoint_path)}) ---")
    prompt_inf = "Valkyria Chronicles is a tactical role-playing game developed and published by"
    max_gen_len = 100
    gen_temperature = 0.7
    gen_top_k = 50
    gen_top_p = 0.9
    try:
        # perf_counter is monotonic, so the measured duration cannot go
        # negative if the wall clock is adjusted mid-run.
        start_time_inf = time.perf_counter()
        _ = generate(
            model=model_inf,
            tokenizer=tokenizer_inf,
            prompt=prompt_inf,
            max_new_tokens=max_gen_len,
            temperature=gen_temperature,
            top_k=gen_top_k,
            top_p=gen_top_p,
        )
        end_time_inf = time.perf_counter()
        print(f"\nInference duration: {datetime.timedelta(seconds=int(end_time_inf - start_time_inf))}")
        print("\n(Output quality depends heavily on limited training. Expect limited coherence.)")
    except Exception as e:
        print(f"\nAn error occurred during generation: {e}")
        traceback.print_exc()
    print("\n--- Inference Script Section Finished ---")