Ruurd commited on
Commit
55c85d9
·
1 Parent(s): 86a7aaf

Fix richtextstreamer for multiple tokens input

Browse files
Files changed (1) hide show
  1. app.py +17 -10
app.py CHANGED
@@ -16,17 +16,24 @@ class RichTextStreamer(TextIteratorStreamer):
16
  def __init__(self, tokenizer, **kwargs):
17
  super().__init__(tokenizer, **kwargs)
18
  self.token_queue = queue.Queue()
19
-
20
  def put(self, value):
21
- # Instead of just decoding here, we emit full info per token
22
- token_id = value.item() if hasattr(value, "item") else value
23
- token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
24
- is_special = token_id in self.tokenizer.all_special_ids
25
- self.token_queue.put({
26
- "token_id": token_id,
27
- "token": token_str,
28
- "is_special": is_special
29
- })
 
 
 
 
 
 
 
30
 
31
  def __iter__(self):
32
  while True:
 
16
  def __init__(self, tokenizer, **kwargs):
17
  super().__init__(tokenizer, **kwargs)
18
  self.token_queue = queue.Queue()
19
+
20
  def put(self, value):
21
+ # Convert incoming tensor or list to flat list of token IDs
22
+ if isinstance(value, torch.Tensor):
23
+ token_ids = value.view(-1).tolist()
24
+ elif isinstance(value, list):
25
+ token_ids = value
26
+ else:
27
+ token_ids = [value]
28
+
29
+ for token_id in token_ids:
30
+ token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
31
+ is_special = token_id in self.tokenizer.all_special_ids
32
+ self.token_queue.put({
33
+ "token_id": token_id,
34
+ "token": token_str,
35
+ "is_special": is_special
36
+ })
37
 
38
  def __iter__(self):
39
  while True: