fix
app.py CHANGED
@@ -13,9 +13,11 @@ from transformers import TextIteratorStreamer
 import queue
 
 class RichTextStreamer(TextIteratorStreamer):
-    def __init__(self, tokenizer, **kwargs):
+    def __init__(self, tokenizer, prompt_len=0, **kwargs):
         super().__init__(tokenizer, **kwargs)
         self.token_queue = queue.Queue()
+        self.prompt_len = prompt_len
+        self.count = 0
 
     def put(self, value):
         if isinstance(value, torch.Tensor):
@@ -26,6 +28,9 @@ class RichTextStreamer(TextIteratorStreamer):
             token_ids = [value]
 
         for token_id in token_ids:
+            self.count += 1
+            if self.count <= self.prompt_len:
+                continue  # skip prompt tokens
             token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
             is_special = token_id in self.tokenizer.all_special_ids
             self.token_queue.put({
@@ -60,25 +65,29 @@ def chat_with_model(messages):
     device = torch.device("cuda")
     current_model.to(device).half()
 
-
-
+    # 1. Tokenize prompt
+    prompt = "Your input here"
+    inputs = current_tokenizer(prompt, return_tensors="pt").to(device)
+    prompt_len = inputs["input_ids"].shape[-1]
 
-
-
-
-
-
-
+    # 2. Init streamer with prompt_len
+    streamer = RichTextStreamer(
+        tokenizer=current_tokenizer,
+        prompt_len=prompt_len,
+        skip_special_tokens=False
+    )
 
+    # 3. Build generation kwargs
     generation_kwargs = dict(
         **inputs,
-        max_new_tokens=
+        max_new_tokens=256,
         do_sample=True,
         streamer=streamer,
         eos_token_id=eos_id,
         pad_token_id=pad_id
     )
 
+    # 4. Launch generation in a thread
     thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
     thread.start()
 
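The hunks above end at thread.start(); how the queue is drained is outside the diff. A minimal consumer sketch, assuming the dict passed to token_queue.put({...}) (its literal is truncated in the hunk above) carries the decoded text and the special-token flag under hypothetical "token" and "is_special" keys:

    import queue

    def drain_stream(streamer, gen_thread):
        """Collect generated text while current_model.generate() runs in gen_thread."""
        output = ""
        while gen_thread.is_alive() or not streamer.token_queue.empty():
            try:
                item = streamer.token_queue.get(timeout=0.1)
            except queue.Empty:
                continue  # generation still warming up or between tokens
            if not item["is_special"]:   # assumed key, set from all_special_ids in put()
                output += item["token"]  # assumed key for token_str
        gen_thread.join()
        return output

In chat_with_model() this loop would run on the calling thread while generate() fills the queue from the worker thread started above.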
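Since generate() pushes the prompt ids through the streamer before any newly sampled tokens, counting against prompt_len in put() is what filters the prompt out of the stream (transformers' built-in streamers expose a skip_prompt flag for the same purpose, which instead drops the whole first put() call). The counter is easy to sanity-check in isolation; a toy version of the logic, with the class and token ids invented for illustration:

    # Standalone check of the skip logic from put(); names and ids are illustrative only.
    class Skipper:
        def __init__(self, prompt_len):
            self.prompt_len = prompt_len
            self.count = 0
            self.seen = []

        def put(self, token_ids):
            for token_id in token_ids:
                self.count += 1
                if self.count <= self.prompt_len:
                    continue  # skip prompt tokens
                self.seen.append(token_id)

    s = Skipper(prompt_len=3)
    s.put([11, 12, 13])  # prompt ids: all skipped
    s.put([14, 15])      # generated ids: kept
    assert s.seen == [14, 15]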