Upload KotobaWhisperPipeline

- config.json +1 -1
- generation_config.json +1 -1
- kotoba_whisper.py +4 -23
config.json
CHANGED
@@ -54,7 +54,7 @@
   "pad_token_id": 50256,
   "scale_embedding": false,
   "torch_dtype": "float32",
-  "transformers_version": "4.
+  "transformers_version": "4.41.0.dev0",
   "use_cache": true,
   "use_weighted_layer_sum": false,
   "vocab_size": 51866
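
Note: `transformers_version` in these configs is metadata recording which transformers release serialized the file; it is not a runtime requirement. A minimal sketch to inspect the field after this change (the repo id is an assumption for illustration, not part of the diff):

    from transformers import WhisperConfig

    # Repo id assumed for illustration; substitute the actual model repo.
    config = WhisperConfig.from_pretrained("kotoba-tech/kotoba-whisper-v1.1")
    print(config.transformers_version)  # "4.41.0.dev0" after this commit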
generation_config.json
CHANGED
@@ -261,5 +261,5 @@
     "transcribe": 50360,
     "translate": 50359
   },
-  "transformers_version": "4.
+  "transformers_version": "4.41.0.dev0"
 }
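
The same field is updated in the generation config; a matching sketch (same assumed repo id):

    from transformers import GenerationConfig

    gen_config = GenerationConfig.from_pretrained("kotoba-tech/kotoba-whisper-v1.1")
    print(gen_config.transformers_version)  # "4.41.0.dev0" after this commit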
kotoba_whisper.py
CHANGED
@@ -249,6 +249,8 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
         encoder = self.model.get_encoder()
         # Consume values so we can let extra information flow freely through
         # the pipeline (important for `partial` in microphone)
+        if type(return_timestamps) is not bool:
+            raise ValueError("return_timestamps should be bool")
         if "input_features" in model_inputs:
             inputs = model_inputs.pop("input_features")
         elif "input_values" in model_inputs:
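
Context for this hunk: the stock AutomaticSpeechRecognitionPipeline also accepts return_timestamps="word"; the new guard narrows this custom pipeline to plain booleans. A sketch of the resulting contract (the repo id and the trust_remote_code loading pattern are assumptions, not part of this diff):

    from transformers import pipeline

    # Repo id assumed for illustration; trust_remote_code loads the
    # KotobaWhisperPipeline defined in kotoba_whisper.py from the model repo.
    pipe = pipeline(model="kotoba-tech/kotoba-whisper-v1.1", trust_remote_code=True)

    pipe("audio.wav", return_timestamps=True)    # OK: segment-level timestamps
    pipe("audio.wav", return_timestamps="word")  # ValueError: return_timestamps should be bool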
@@ -260,18 +262,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):
             )

         # custom processing for Whisper timestamps and word-level timestamps
-
-        generate_kwargs["return_timestamps"] = return_timestamps
-        if return_timestamps == "word":
-            generate_kwargs["return_token_timestamps"] = True
-            generate_kwargs["return_segments"] = True
-
-            if stride is not None:
-                if isinstance(stride, tuple):
-                    generate_kwargs["num_frames"] = stride[0] // self.feature_extractor.hop_length
-                else:
-                    generate_kwargs["num_frames"] = [s[0] // self.feature_extractor.hop_length for s in stride]
-
+        generate_kwargs["return_timestamps"] = True
         if inputs.shape[-1] > self.feature_extractor.nb_max_frames:
             generate_kwargs["input_features"] = inputs
         else:
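
With the word-timestamp branch removed, _forward now always requests segment timestamps from generate, whichever boolean the caller passed; the return_token_timestamps/return_segments flags and the num_frames bookkeeping go with it. Continuing the sketch above, timestamps in the output are therefore always segment-level chunks:

    result = pipe("audio.wav", return_timestamps=True)
    for chunk in result["chunks"]:
        print(chunk["timestamp"], chunk["text"])  # (start_sec, end_sec) per segment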
@@ -279,17 +270,7 @@ class KotobaWhisperPipeline(AutomaticSpeechRecognitionPipeline):

         tokens = self.model.generate(attention_mask=attention_mask, **generate_kwargs)
         # whisper longform generation stores timestamps in "segments"
-        if return_timestamps == "word":
-            if "segments" not in tokens:
-                out = {"tokens": tokens["sequences"], "token_timestamps": tokens["token_timestamps"]}
-            else:
-                token_timestamps = [
-                    torch.cat([segment["token_timestamps"] for segment in segment_list])
-                    for segment_list in tokens["segments"]
-                ]
-                out = {"tokens": tokens["sequences"], "token_timestamps": token_timestamps}
-        else:
-            out = {"tokens": tokens}
+        out = {"tokens": tokens}
         if self.type == "seq2seq_whisper":
             if stride is not None:
                 out["stride"] = stride
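
On the output side, the last hunk drops the word-timestamp unpacking entirely: without return_token_timestamps, generate returns plain token ids that _forward can wrap directly. A minimal sketch of the simplified tail (names taken from the diff; the surrounding method context is assumed):

    tokens = self.model.generate(attention_mask=attention_mask, **generate_kwargs)
    out = {"tokens": tokens}  # plain ids; no "segments"/"token_timestamps" handling
    if self.type == "seq2seq_whisper":
        if stride is not None:
            out["stride"] = stride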