File size: 5,055 Bytes
5917f0a
 
 
 
 
 
 
 
 
 
 
 
 
fc17b57
 
5917f0a
 
 
 
fc17b57
5917f0a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc17b57
 
 
 
 
 
 
 
 
 
5917f0a
 
 
 
 
fc17b57
5917f0a
 
 
fc17b57
 
5917f0a
 
 
 
 
 
 
fc17b57
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5917f0a
fc17b57
 
 
 
 
 
 
 
5917f0a
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import hydra
from typing import List, Optional
from dataclasses import dataclass, field
import kenlm
from beam_search_utils import (
    SpeakerTaggingBeamSearchDecoder,
    load_input_jsons,
    load_reference_jsons,
    write_seglst_jsons,
    run_mp_beam_search_decoding,
    convert_nemo_json_to_seglst,
)
from hydra.core.config_store import ConfigStore
from hyper_optim import optuna_hyper_optim



@dataclass
class RealigningLanguageModelParameters:
    """Hydra structured config for KenLM speaker-tagging beam-search realignment.

    Groups the beam-search decoding settings with the optional Optuna
    hyper-parameter-search settings. Registered below with ConfigStore under
    the name ``config`` and consumed by ``main`` via ``@hydra.main``.
    """

    # --- Beam search parameters ---
    batch_size: int = 32                              # decoding batch size
    use_mp: bool = True                               # enable multiprocessing decoding
    input_error_src_list_path: Optional[str] = None   # path to list of hypothesis (error-source) JSONs
    groundtruth_ref_list_path: Optional[str] = None   # path to list of reference seglst JSONs
    arpa_language_model: Optional[str] = None         # path to the ARPA n-gram LM loaded by kenlm.Model
    word_window: int = 32                             # word-window size passed to divide_chunks
    port: List[int] = field(default_factory=list)     # ports used by the MP decoding workers
    parallel_chunk_word_len: int = 250                # words per chunk (win_len for divide_chunks)
    use_ngram: bool = True                            # score with the n-gram LM during decoding
    peak_prob: float = 0.95                           # peak probability applied when loading input JSONs
    alpha: float = 0.5                                # beam-search weight (presumably LM weight — confirm in decoder)
    beta: float = 0.05                                # beam-search weight (presumably insertion bonus — confirm in decoder)
    beam_width: int = 16                              # beam search width
    out_dir: Optional[str] = None                     # output directory for written seglst JSONs

    # --- Optuna hyper-parameter optimization ---
    hyper_params_optim: bool = False                  # run Optuna tuning instead of decoding
    optuna_n_trials: int = 200                        # number of Optuna trials
    workspace_dir: Optional[str] = None               # workspace directory for optimization runs
    asrdiar_file_name: Optional[str] = None           # base name of the ASR+diarization file
    storage: Optional[str] = "sqlite:///optuna-speaker-beam-search.db"  # Optuna storage URI
    optuna_study_name: Optional[str] = "speaker_beam_search"            # Optuna study name
    output_log_file: Optional[str] = None             # optional log file for the optimization
    temp_out_dir: Optional[str] = None                # scratch directory for trial outputs

# Register the dataclass as the Hydra structured-config node named "config",
# which @hydra.main(config_name="config") resolves and validates at startup.
cs = ConfigStore.instance()
cs.store(name="config", node=RealigningLanguageModelParameters)

@hydra.main(config_name="config", version_base="1.1")
def main(cfg: RealigningLanguageModelParameters) -> None:
    """Realign speaker tags with KenLM beam search, or tune hyper-parameters.

    Loads the hypothesis (error-source), reference and source seglst JSON
    lists, chunks the transcripts for parallel decoding, then either runs
    Optuna hyper-parameter optimization (``cfg.hyper_params_optim``) or runs
    multiprocess beam-search decoding and writes 'hyp'/'ref'/'src' seglst
    JSONs into ``cfg.out_dir``.
    """
    __INFO_TAG__ = "[INFO]"

    # Load hypothesis transcripts plus reference and source seglst data.
    trans_info_dict = load_input_jsons(
        input_error_src_list_path=cfg.input_error_src_list_path,
        peak_prob=float(cfg.peak_prob),
    )
    reference_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.groundtruth_ref_list_path)
    source_info_dict = load_reference_jsons(reference_seglst_list_path=cfg.input_error_src_list_path)

    # Load the ARPA language model once, up front, and share it with the decoder.
    loaded_kenlm_model = kenlm.Model(cfg.arpa_language_model)
    speaker_beam_search_decoder = SpeakerTaggingBeamSearchDecoder(loaded_kenlm_model=loaded_kenlm_model, cfg=cfg)

    # Split long transcripts into word chunks so decoding can run in parallel.
    div_trans_info_dict = speaker_beam_search_decoder.divide_chunks(
        trans_info_dict=trans_info_dict,
        win_len=cfg.parallel_chunk_word_len,
        word_window=cfg.word_window,
        port=cfg.port,
    )

    if cfg.hyper_params_optim:
        print(f"{__INFO_TAG__} Optimizing hyper-parameters...")
        cfg = optuna_hyper_optim(
            cfg=cfg,
            speaker_beam_search_decoder=speaker_beam_search_decoder,
            loaded_kenlm_model=loaded_kenlm_model,
            div_trans_info_dict=div_trans_info_dict,
            org_trans_info_dict=trans_info_dict,
            source_info_dict=source_info_dict,
            reference_info_dict=reference_info_dict,
        )
        __INFO_TAG__ = f"{__INFO_TAG__} Optimized hyper-parameters - "
    else:
        trans_info_dict = run_mp_beam_search_decoding(
            speaker_beam_search_decoder,
            loaded_kenlm_model=loaded_kenlm_model,
            div_trans_info_dict=div_trans_info_dict,
            org_trans_info_dict=trans_info_dict,
            div_mp=True,
            win_len=cfg.parallel_chunk_word_len,
            word_window=cfg.word_window,
            port=cfg.port,
            use_ngram=cfg.use_ngram,
        )
        hypothesis_sessions_dict = convert_nemo_json_to_seglst(trans_info_dict)

        write_seglst_jsons(hypothesis_sessions_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=cfg.out_dir, ext_str='hyp')
        write_seglst_jsons(reference_info_dict, input_error_src_list_path=cfg.groundtruth_ref_list_path, diar_out_path=cfg.out_dir, ext_str='ref')
        # BUGFIX: 'src' previously passed cfg.groundtruth_ref_list_path (copy-paste
        # of the 'ref' line), but source_info_dict is loaded from
        # cfg.input_error_src_list_path above — use the matching source path.
        write_seglst_jsons(source_info_dict, input_error_src_list_path=cfg.input_error_src_list_path, diar_out_path=cfg.out_dir, ext_str='src')

        print(f"{__INFO_TAG__} Parameters used: \
                \n ALPHA: {cfg.alpha} \
                \n BETA: {cfg.beta} \
                \n BEAM WIDTH: {cfg.beam_width} \
                \n Word Window: {cfg.word_window} \
                \n Use Ngram: {cfg.use_ngram} \
                \n Chunk Word Len: {cfg.parallel_chunk_word_len} \
                \n SpeakerLM Model: {cfg.arpa_language_model}")

# Script entry point; Hydra parses CLI overrides before invoking main(cfg).
if __name__ == '__main__':
    main()