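"""Interactive generation and online-learning loop for the Microformer.

Loads the trained tokenizer and checkpoint, generates text with combined
top-k / nucleus sampling, optionally runs small online updates on the session
adapters (from a user-supplied answer or the model's own output), and logs
every exchange to a local SQLite database.
"""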
import sys
from pathlib import Path

# Make the project root importable so `models` and `config` resolve when this
# script is run directly.
sys.path.append(str(Path(__file__).resolve().parent.parent))

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch.optim as optim
from models.model import Microformer
from tokenizers import Tokenizer
from config import EMBED_DIM, NUM_HEADS, FF_DIM, NUM_LAYERS, MAX_SEQ_LEN, ADAPTER_DIM
import sqlite3
from datetime import datetime
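
# The vocabulary size is taken from the trained tokenizer rather than from
# config, so the model's embedding and output dimensions always match the
# tokenizer that produced the training data.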
tokenizer = Tokenizer.from_file("data/tokenizer.json")
VOCAB_SIZE = tokenizer.get_vocab_size()

model = Microformer(
    vocab_size=VOCAB_SIZE,
    embed_dim=EMBED_DIM,
    num_heads=NUM_HEADS,
    ff_dim=FF_DIM,
    num_layers=NUM_LAYERS,
    max_seq_len=MAX_SEQ_LEN,
    long_term_adapter_dim=ADAPTER_DIM,
    session_adapter_dim=ADAPTER_DIM
)

# map_location keeps the checkpoint loadable on CPU-only machines.
device = "cuda" if torch.cuda.is_available() else "cpu"
model.load_state_dict(torch.load("microformer.pt", map_location=device))
model.to(device)
model.eval()

# Keep only a small set of parameters trainable for online updates: with
# session_only=True the session adapters (and, with include_output=True, the
# output head) stay unfrozen while the rest of the model is fixed.
model.freeze_except_adapters(session_only=True, include_output=True)

# Ignore <PAD> positions so padding never contributes to the online loss.
criterion = nn.CrossEntropyLoss(ignore_index=tokenizer.token_to_id("<PAD>"))
optimizer = optim.Adam(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-2
)
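
# Lightweight conversation memory: every prompt/response pair is appended to
# a local SQLite table and the most recent rows are echoed after each turn.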
conn = sqlite3.connect("memory.db")
c = conn.cursor()
c.execute("""
    CREATE TABLE IF NOT EXISTS memory (
        timestamp TEXT,
        prompt TEXT,
        response TEXT
    )
""")
conn.commit()


def top_k_top_p_filtering(logits, top_k=50, top_p=0.9):
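    """Sample a single token id from `logits` using top-k plus nucleus (top-p) filtering."""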
    logits = logits.squeeze(0)
    probs = torch.softmax(logits, dim=-1)

    # Sort probabilities so the nucleus cutoff can be applied cumulatively.
    sorted_probs, sorted_indices = torch.sort(probs, descending=True)
    cumulative_probs = torch.cumsum(sorted_probs, dim=-1)

    # Mask tokens outside the nucleus; shift the mask by one position so the
    # token that crosses the threshold is kept, and never drop the top token.
    sorted_mask = cumulative_probs > top_p
    sorted_mask[1:] = sorted_mask[:-1].clone()
    sorted_mask[0] = False

    # Also drop everything beyond the top-k most probable tokens.
    if top_k < sorted_probs.size(0):
        sorted_mask[top_k:] = True

    sorted_probs[sorted_mask] = 0.0

    # Renormalise the surviving probabilities and sample one token.
    sorted_probs /= sorted_probs.sum()
    sampled_relative_index = torch.multinomial(sorted_probs, 1).item()
    sampled_token_id = sorted_indices[sampled_relative_index].item()

    return sampled_token_id


def generate(prompt, length=100, temperature=1.0, top_p=0.9, top_k=50):
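    """Autoregressively generate up to `length` new tokens for `prompt` and return the decoded text."""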
    input_ids = tokenizer.encode(prompt).ids
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)

    eos_token_id = tokenizer.token_to_id("<EOS>")

    for _ in range(length):
        with torch.no_grad():
            logits = model(input_tensor)
            logits = logits[:, -1, :] / temperature

        # Repetition penalty: discourage tokens that already appear in the
        # context. The scaling is sign-aware so negative logits are pushed
        # further down instead of being pulled toward zero.
        for token_id in input_tensor[0].tolist():
            if logits[0, token_id] > 0:
                logits[0, token_id] *= 0.8
            else:
                logits[0, token_id] /= 0.8

        next_token_id = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)

        input_tensor = torch.cat([input_tensor, torch.tensor([[next_token_id]], device=device)], dim=1)

        if next_token_id == eos_token_id:
            break

    output_ids = input_tensor[0].tolist()
    decoded = tokenizer.decode(output_ids)

    # Trim anything produced after the end-of-sequence marker.
    if "<EOS>" in decoded:
        decoded = decoded.split("<EOS>")[0].strip()

    return decoded


def online_unsupervised_update(model, tokenizer, text, optimizer, loss_fn, device, max_len=64):
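    """Run one next-token-prediction gradient step on `text`; return the loss, or None if `text` is too short."""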
    ids = tokenizer.encode(text).ids + [tokenizer.token_to_id("<EOS>")]
    if len(ids) < 2:
        return None

    # Build a shifted input/target pair and pad both to a fixed length; the
    # padded target positions are ignored by the loss via ignore_index.
    pad_id = tokenizer.token_to_id("<PAD>")
    ids = ids[:max_len + 1]
    input_ids = ids[:-1]
    target_ids = ids[1:]
    input_ids += [pad_id] * (max_len - len(input_ids))
    target_ids += [pad_id] * (max_len - len(target_ids))
    input_tensor = torch.tensor([input_ids], dtype=torch.long, device=device)
    target_tensor = torch.tensor([target_ids], dtype=torch.long, device=device)

    # Single optimisation step; only unfrozen (adapter) parameters receive gradients.
    model.train()
    logits = model(input_tensor)
    logits = logits.view(-1, logits.size(-1))
    targets = target_tensor.view(-1)
    loss = loss_fn(logits, targets)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    model.eval()
    return loss.item()


def reset_session_adapters(model):
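    """Zero all session-adapter parameters so a new session starts from a clean slate."""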
    for layer in model.layers:
        if getattr(layer, 'session_adapter', None) is not None:
            for param in layer.session_adapter.parameters():
                nn.init.zeros_(param)


if __name__ == "__main__":
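    # Interactive loop: generate, optionally learn from feedback, then log to memory.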
    while True:
        prompt = input("\nEnter a prompt (or 'exit' to quit): ")
        if prompt.lower() in {"exit", "quit"}:
            break
        temp = float(input("Temperature (e.g. 0.7, 1.0): "))

        output = generate(prompt, length=100, temperature=temp, top_p=0.9, top_k=50)
        print("\nGenerated text:\n")
        print(output)

        # Online update: learn from a user-supplied answer when one is given,
        # otherwise self-train on the model's own output.
        teach = input("\nDo you want to teach the model a better answer? (y/N): ").strip().lower()
        if teach == "y":
            your_answer = input("Type your ideal response for this prompt: ")
            model.freeze_except_adapters(session_only=True, include_output=True)
            online_text = prompt + " " + your_answer
            loss = online_unsupervised_update(
                model, tokenizer, online_text, optimizer, criterion, device, max_len=MAX_SEQ_LEN
            )
            if loss is not None:
                print(f"[Online update loss: {loss:.4f}]")
        else:
            model.freeze_except_adapters(session_only=True, include_output=True)
            online_text = prompt + " " + output
            loss = online_unsupervised_update(
                model, tokenizer, online_text, optimizer, criterion, device, max_len=MAX_SEQ_LEN
            )
            if loss is not None:
                print(f"[Online (self-improve) update loss: {loss:.4f}]")

        # Persist the exchange and show the most recent memory rows.
        c.execute("INSERT INTO memory (timestamp, prompt, response) VALUES (?, ?, ?)",
                  (datetime.now().isoformat(timespec='seconds'), prompt, output))
        conn.commit()

        print("\nRecent memory:")
        for row in c.execute("SELECT * FROM memory ORDER BY timestamp DESC LIMIT 5"):
            print(f"[{row[0]}] {row[1]} → {row[2]}")