Ruurd committed · Commit e7c4f38 · 1 Parent(s): 2bf0c40
Files changed (1)
  1. app.py +19 -10
app.py CHANGED
@@ -13,9 +13,11 @@ from transformers import TextIteratorStreamer
 import queue
 
 class RichTextStreamer(TextIteratorStreamer):
-    def __init__(self, tokenizer, **kwargs):
+    def __init__(self, tokenizer, prompt_len=0, **kwargs):
         super().__init__(tokenizer, **kwargs)
         self.token_queue = queue.Queue()
+        self.prompt_len = prompt_len
+        self.count = 0
 
     def put(self, value):
         if isinstance(value, torch.Tensor):
@@ -26,6 +28,9 @@ class RichTextStreamer(TextIteratorStreamer):
             token_ids = [value]
 
         for token_id in token_ids:
+            self.count += 1
+            if self.count <= self.prompt_len:
+                continue  # skip prompt tokens
             token_str = self.tokenizer.decode([token_id], **self.decode_kwargs)
             is_special = token_id in self.tokenizer.all_special_ids
             self.token_queue.put({
@@ -60,25 +65,29 @@ def chat_with_model(messages):
     device = torch.device("cuda")
     current_model.to(device).half()
 
-    inputs = current_tokenizer(prompt, return_tensors="pt")
-    inputs = {k: v.to(device) for k, v in inputs.items()}
+    # 1. Tokenize prompt
+    prompt = "Your input here"
+    inputs = current_tokenizer(prompt, return_tensors="pt").to(device)
+    prompt_len = inputs["input_ids"].shape[-1]
 
-    streamer = RichTextStreamer(current_tokenizer, skip_prompt=True, skip_special_tokens=False)
-
-    max_new_tokens = 256
-    generated_tokens = 0
-    output_text = ""
-    in_think = False
+    # 2. Init streamer with prompt_len
+    streamer = RichTextStreamer(
+        tokenizer=current_tokenizer,
+        prompt_len=prompt_len,
+        skip_special_tokens=False
+    )
 
+    # 3. Build generation kwargs
     generation_kwargs = dict(
         **inputs,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=256,
         do_sample=True,
         streamer=streamer,
         eos_token_id=eos_id,
         pad_token_id=pad_id
     )
 
+    # 4. Launch generation in a thread
     thread = threading.Thread(target=current_model.generate, kwargs=generation_kwargs)
     thread.start()
 
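Note on the change: the old code passed skip_prompt=True to the streamer, but the overridden put() never calls the base TextStreamer.put(), which is where that flag is checked, so prompt tokens were leaking into the queue. The new prompt_len counter skips them explicitly at token granularity. A minimal sketch of the skip arithmetic, using made-up token ids:

# Illustration only (hypothetical ids): generate() hands put() the full
# prompt first, then each newly sampled token, so the first prompt_len
# tokens are dropped and only generated tokens reach the queue.
prompt_len = 4
count = 0
queued = []
for token_id in [101, 2009, 318, 257, 1049]:  # 4 prompt ids + 1 generated id
    count += 1
    if count <= prompt_len:
        continue  # mirrors the patched put()
    queued.append(token_id)
assert queued == [1049]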
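For context, a hypothetical consumer for token_queue (not part of this commit): the main thread drains the queue while the generation thread fills it. The "token_str"/"is_special" keys follow the variables built in put() above, but the enqueued dict is truncated in this diff, so treat its exact shape as an assumption.

# Sketch of a main-thread consumer, assumed to run inside chat_with_model
# right after thread.start(); streamer, thread, and queue come from above.
output_text = ""
while thread.is_alive() or not streamer.token_queue.empty():
    try:
        item = streamer.token_queue.get(timeout=0.05)
    except queue.Empty:
        continue  # no token yet; re-check whether generation finished
    if not item["is_special"]:
        output_text += item["token_str"]
        yield output_text  # stream the partial reply to the caller
thread.join()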