Fix RichTextStreamer to handle multi-token input
Browse files
app.py
CHANGED
@@ -16,17 +16,24 @@ class RichTextStreamer(TextIteratorStreamer):
|
|
16 |
def __init__(self, tokenizer, **kwargs):
|
17 |
super().__init__(tokenizer, **kwargs)
|
18 |
self.token_queue = queue.Queue()
|
19 |
-
|
20 |
def put(self, value):
|
21 |
-
#
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
def __iter__(self):
|
32 |
while True:
|
|
|
16 |
def __init__(self, tokenizer, **kwargs):
|
17 |
super().__init__(tokenizer, **kwargs)
|
18 |
self.token_queue = queue.Queue()
|
19 |
+
|
20 |
def put(self, value):
|
21 |
+
# Convert incoming tensor or list to flat list of token IDs
|
22 |
+
if isinstance(value, torch.Tensor):
|
23 |
+
token_ids = value.view(-1).tolist()
|
24 |
+
elif isinstance(value, list):
|
25 |
+
token_ids = value
|
26 |
+
else:
|
27 |
+
token_ids = [value]
|
28 |
+
|
29 |
+
for token_id in token_ids:
|
30 |
+
token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
|
31 |
+
is_special = token_id in self.tokenizer.all_special_ids
|
32 |
+
self.token_queue.put({
|
33 |
+
"token_id": token_id,
|
34 |
+
"token": token_str,
|
35 |
+
"is_special": is_special
|
36 |
+
})
|
37 |
|
38 |
def __iter__(self):
|
39 |
while True:
|