Update processing_qwen2_ts.py to work with the latest vllm patch for ChatTS support. (#16)
Commit: d9c80001adf853e6d9c275b6e981b8d352ee0e5f
Co-authored-by: Alexander Chemeris <alexanderchemeris@users.noreply.huggingface.co>
processing_qwen2_ts.py  (+55, −36)
processing_qwen2_ts.py
CHANGED
@@ -91,45 +91,62 @@ class Qwen2TSProcessor(ProcessorMixin):
         if timeseries is None:
             timeseries = []

-        encoded_ts_arrays = []
         reconstructed_prompts = []
-        total_ts_cnt = 0
-        for idx, prompt in enumerate(text):
-            # Split prompt by <ts><ts/> placeholders
-            last_ts_cnt = total_ts_cnt
-            prompt_segments = prompt.split("<ts><ts/>")
-            total_ts_cnt = total_ts_cnt + len(prompt_segments) - 1
-
-            # Encode each time series and rebuild the prompt
-            reconstructed_prompt = prompt_segments[0]
-
-            for i, ts in enumerate(timeseries[last_ts_cnt:total_ts_cnt]):
-                encoded_ts, ts_prompt, _ = sp_encoding(ts, eots_token=not vllm_flag)
-                reconstructed_prompt += ts_prompt + prompt_segments[i + 1]
-                # Ensure time series shape [1, seq_len, feature_dim] for batch concatenation
-                encoded_ts_arrays.append(encoded_ts[None, ...])
+        concatenated_ts = None
+        ts_tokens = []

-            reconstructed_prompts.append(reconstructed_prompt)
-
-        if len(timeseries) != len(encoded_ts_arrays):
-            raise ValueError(
-                f"Mismatch between <ts><ts/> placeholders ({total_ts_cnt}) "
-                f"and time series ({len(encoded_ts_arrays)})."
-            )
-
-        if len(encoded_ts_arrays) > 0:
-            # Pad time series to the same length
-            max_length = max(ts.shape[1] for ts in encoded_ts_arrays)
-            padded_ts_arrays = [
-                np.pad(ts, ((0, 0), (0, max_length - ts.shape[1]), (0, 0)), mode="constant", constant_values=0.0)
-                for ts in encoded_ts_arrays
-            ]
-            concatenated_ts = np.concatenate(padded_ts_arrays, axis=0)  # Shape: [batch_size, max_length, feature_dim]
+        if vllm_flag:
+            # All prompt modifications have to be done inside of the vLLM
+            # to work correctly with its caching mechanism.
+            reconstructed_prompts = text

-            # Convert to torch
-            concatenated_ts = torch.from_numpy(concatenated_ts).half()
+            # Process timeseries data
+            encoded_ts_arrays = []
+            for ts in timeseries:
+                # Get the normalized data and prompt text
+                encoded_ts, ts_prompt, _ = sp_encoding(ts, eots_token=False)
+                # Tokenize the ts_prompt and add to the tokens list
+                if self.tokenizer is not None:
+                    tokens = self.tokenizer.encode(ts_prompt, add_special_tokens=False)
+                    ts_tokens.append(tokens)
+                encoded_ts_arrays.append(encoded_ts[None, ...])
         else:
-            concatenated_ts = None
+            encoded_ts_arrays = []
+            total_ts_cnt = 0
+            for idx, prompt in enumerate(text):
+                # Split prompt by <ts><ts/> placeholders
+                last_ts_cnt = total_ts_cnt
+                prompt_segments = prompt.split("<ts><ts/>")
+                total_ts_cnt = total_ts_cnt + len(prompt_segments) - 1
+
+                # Encode each time series and rebuild the prompt
+                reconstructed_prompt = prompt_segments[0]
+
+                for i, ts in enumerate(timeseries[last_ts_cnt:total_ts_cnt]):
+                    encoded_ts, ts_prompt, _ = sp_encoding(ts, eots_token=not vllm_flag)
+                    reconstructed_prompt += ts_prompt + prompt_segments[i + 1]
+                    # Ensure time series shape [1, seq_len, feature_dim] for batch concatenation
+                    encoded_ts_arrays.append(encoded_ts[None, ...])
+
+                reconstructed_prompts.append(reconstructed_prompt)
+
+            if len(timeseries) != len(encoded_ts_arrays):
+                raise ValueError(
+                    f"Mismatch between <ts><ts/> placeholders ({total_ts_cnt}) "
+                    f"and time series ({len(encoded_ts_arrays)})."
+                )
+
+            if len(encoded_ts_arrays) > 0:
+                # Pad time series to the same length
+                max_length = max(ts.shape[1] for ts in encoded_ts_arrays)
+                padded_ts_arrays = [
+                    np.pad(ts, ((0, 0), (0, max_length - ts.shape[1]), (0, 0)), mode="constant", constant_values=0.0)
+                    for ts in encoded_ts_arrays
+                ]
+                concatenated_ts = np.concatenate(padded_ts_arrays, axis=0)  # Shape: [batch_size, max_length, feature_dim]
+
+                # Convert to torch
+                concatenated_ts = torch.from_numpy(concatenated_ts).half()

         # Tokenize the processed prompt
         tokenizer_outputs = {}
@@ -138,7 +155,9 @@ class Qwen2TSProcessor(ProcessorMixin):

         # Create the final output
         outputs = tokenizer_outputs
-        if concatenated_ts is not None:
+        if vllm_flag:
+            outputs["timeseries"] = zip(ts_tokens, encoded_ts_arrays)
+        elif concatenated_ts is not None:
             outputs["timeseries"] = concatenated_ts

         return BatchFeature(data=outputs)
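For reference, the padding and batching done by the non-vLLM branch can be reproduced in isolation. The following is a minimal sketch, not part of the patch: the two arrays, their lengths, and the feature dimension of 5 are made-up placeholders standing in for the [1, seq_len, feature_dim] outputs of sp_encoding.

import numpy as np
import torch

# Placeholder stand-ins for sp_encoding outputs, shaped [1, seq_len, feature_dim].
encoded_ts_arrays = [
    np.random.rand(1, 48, 5),
    np.random.rand(1, 64, 5),
]

# Pad every series along seq_len to the longest one, as the processor does.
max_length = max(ts.shape[1] for ts in encoded_ts_arrays)
padded_ts_arrays = [
    np.pad(ts, ((0, 0), (0, max_length - ts.shape[1]), (0, 0)), mode="constant", constant_values=0.0)
    for ts in encoded_ts_arrays
]

# Stack into [batch_size, max_length, feature_dim] and convert to float16,
# which is what the processor attaches as outputs["timeseries"].
concatenated_ts = torch.from_numpy(np.concatenate(padded_ts_arrays, axis=0)).half()
print(concatenated_ts.shape)  # torch.Size([2, 64, 5])

In the vLLM path this step is skipped: the prompt text is returned unchanged and outputs["timeseries"] becomes a zip of (ts_prompt token ids, encoded array) pairs, one per series, so that all prompt rewriting can happen inside the vLLM plugin without invalidating its prefix cache.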