parkermoe commited on
Commit
253aeba
·
1 Parent(s): 29d0dff

'fix <pad>'

Browse files
Files changed (1) hide show
  1. app.py +2 -1
app.py CHANGED
@@ -78,7 +78,8 @@ class TextGenerator:
78
  next_token = self.sample_from(next_token_logits, temperature)
79
  tokens = torch.cat([tokens, torch.LongTensor([[next_token]]).to(device)], dim=1)
80
 
81
- generated_text = ' '.join(self.vocab.get_itos()[token] for token in tokens[0])
 
82
  return generated_text
83
 
84
  text_generator = TextGenerator(vocab=vocab, top_k=10)
 
78
  next_token = self.sample_from(next_token_logits, temperature)
79
  tokens = torch.cat([tokens, torch.LongTensor([[next_token]]).to(device)], dim=1)
80
 
81
+ generated_tokens = [token for token in tokens[0] if self.vocab.get_itos()[token] != '<pad>']
82
+ generated_text = ' '.join(self.vocab.get_itos()[token] for token in generated_tokens)
83
  return generated_text
84
 
85
  text_generator = TextGenerator(vocab=vocab, top_k=10)