abancp committed on
Commit d727d22 · verified · 1 Parent(s): dd1b76c

Update inference_fine_tune.py

Files changed (1)
  1. inference_fine_tune.py +23 -27
inference_fine_tune.py CHANGED
@@ -31,41 +31,37 @@ model.eval()
 state = torch.load(model_path,map_location=torch.device('cpu'))
 model.load_state_dict(state['model_state_dict'])

-def generate_response(prompt:str):
-    print("Prompt : ",prompt)
-
-    word = ""
+def generate_response(prompt: str):
+    print("Prompt:", prompt)
     input_tokens = tokenizer.encode(prompt).ids
-    input_tokens.extend([user_token_id] + input_tokens + [ai_token_id] )
+    input_tokens = [user_token_id] + input_tokens + [ai_token_id]
+
     if len(input_tokens) > config['seq_len']:
-        print(f"exceeding max length of input : {config['seq_len']}")
-        exit()
-    input_tokens = torch.tensor(input_tokens)
-    decoder_input = input_tokens.to(device)
-    if decoder_input.dim() == 1:
-        decoder_input = decoder_input.unsqueeze(0)
+        print(f"Exceeding max length of input: {config['seq_len']}")
+        return
+
+    input_tokens = torch.tensor(input_tokens).unsqueeze(0).to(device) # (1, seq_len)
+
     temperature = 0.7
     top_k = 50
     i = 0
-    print("Output : ",end="")
-    while decoder_input.shape[1] < 2000:
-        # Apply causal mask based on current decoder_input length
-        # decoder_mask = (decoder_input != pad_token_id).unsqueeze(0).int() & causal_mask(decoder_input.size(1)).type_as(input_mask).to(device)
-        # Get model output
-        out = model.decode(decoder_input)
-        logits = model.project(out[:, -1]) # Get logits for last token
+
+    while input_tokens.shape[1] < 2000:
+        out = model.decode(input_tokens)
+        logits = model.project(out[:, -1])
         logits = logits / temperature
         top_k_logits, top_k_indices = torch.topk(logits, top_k)
         probs = torch.softmax(top_k_logits, dim=-1)
         next_token = torch.multinomial(probs, num_samples=1)
         next_token = top_k_indices.gather(-1, next_token)
-        word += tokenizer.decode([next_token.item()])
-        print(word,end="")
-        i+=1
-        decoder_input = torch.cat([decoder_input, next_token], dim=1)
-        if decoder_input.shape[1] > config['seq_len']:
-            decoder_input = decoder_input[:,-config['seq_len']:]
-        if next_token.item() == eos_token_id or i >= 1024:
+
+        decoded_word = tokenizer.decode([next_token.item()])
+        yield decoded_word # Streaming output token-by-token
+
+        input_tokens = torch.cat([input_tokens, next_token], dim=1)
+        if input_tokens.shape[1] > config['seq_len']:
+            input_tokens = input_tokens[:, -config['seq_len']:]
+
+        if next_token.item() == eos_token_id or i >= 1024:
             break
-    print()
-    return word
+        i += 1
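
One of the changes above is to the prompt framing: the old extend() call appended the wrapped copy onto the original ids, so the prompt tokens ended up in the sequence twice, while the new assignment builds the framed sequence directly. A toy illustration with made-up token ids (not from the repository):

# Made-up ids, purely to show the difference between the two lines.
ids = [5, 6, 7]
user_token_id, ai_token_id = 1, 2

old = list(ids)
old.extend([user_token_id] + old + [ai_token_id])   # old behaviour
new = [user_token_id] + ids + [ai_token_id]         # new behaviour

print(old)   # [5, 6, 7, 1, 5, 6, 7, 2]  <- prompt tokens duplicated
print(new)   # [1, 5, 6, 7, 2]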
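The sampling step inside the loop is unchanged: scale the last-position logits by the temperature, keep the top-k candidates, renormalise, draw one, and map it back to a vocabulary id. A stand-alone sketch of that step, with a dummy logits tensor in place of model.project(out[:, -1]) and an arbitrary vocabulary size:

import torch

def sample_top_k(logits: torch.Tensor, temperature: float = 0.7, top_k: int = 50) -> torch.Tensor:
    logits = logits / temperature                              # sharpen/flatten the distribution
    top_k_logits, top_k_indices = torch.topk(logits, top_k)    # keep the k most likely tokens
    probs = torch.softmax(top_k_logits, dim=-1)                # renormalise over those k
    choice = torch.multinomial(probs, num_samples=1)           # sample an index into the top-k set
    return top_k_indices.gather(-1, choice)                    # map back to a vocabulary id

dummy_logits = torch.randn(1, 32000)   # (batch=1, vocab_size); the vocab size here is made up
next_token = sample_top_k(dummy_logits)
print(next_token.shape)                # torch.Size([1, 1])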
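Because generate_response now yields decoded tokens instead of printing them and returning the full string, the caller is responsible for streaming the output. A minimal, hypothetical caller (not part of this commit), assuming model, tokenizer, config, and the special-token ids are already initialised as in inference_fine_tune.py:

# Hypothetical chat loop consuming the generator token by token.
if __name__ == "__main__":
    while True:
        prompt = input("You: ")
        if prompt.strip().lower() in {"exit", "quit"}:
            break
        print("AI: ", end="", flush=True)
        for piece in generate_response(prompt):
            print(piece, end="", flush=True)
        print()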