File size: 3,798 Bytes
faed9d7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
from diart import SpeakerDiarization
from diart.inference import StreamingInference
from diart.sources import AudioSource
from rx.subject import Subject
import threading
import numpy as np
import asyncio

class WebSocketAudioSource(AudioSource):
    """Audio source fed by PCM chunks pushed from a websocket handler.

    ``read()`` simply blocks until ``close()`` is called; audio is
    injected asynchronously via ``push_audio()``, which forwards each
    chunk to the underlying reactive stream.
    """

    def __init__(self, uri: str = "websocket", sample_rate: int = 16000):
        super().__init__(uri, sample_rate)
        self._closed = False
        self._close_event = threading.Event()

    def read(self):
        # No polling loop: just park until close() signals completion.
        self._close_event.wait()

    def close(self):
        # Idempotent — only the first call completes the stream.
        if self._closed:
            return
        self._closed = True
        self.stream.on_completed()
        self._close_event.set()

    def push_audio(self, chunk: np.ndarray):
        # Drop chunks arriving after close; otherwise frame the mono
        # samples as a (1, n_samples) array, as diart expects.
        if self._closed:
            return
        framed = np.expand_dims(chunk, axis=0)
        self.stream.on_next(framed)


def create_pipeline(SAMPLE_RATE):
    """Build a diart streaming-inference pipeline backed by a websocket source.

    Returns an ``(inference, source)`` pair: the caller pushes PCM into
    *source* and runs *inference* to drive diarization.
    """
    source = WebSocketAudioSource(uri="websocket_source", sample_rate=SAMPLE_RATE)
    streaming = StreamingInference(
        pipeline=SpeakerDiarization(),
        source=source,
        do_plot=False,
        show_progress=False,
    )
    return streaming, source


def init_diart(SAMPLE_RATE):
    """Create the diart pipeline and start it on a background thread.

    Must be called from the thread that owns the asyncio event loop.
    Returns ``(inference, l_speakers_queue, ws_source)``; diarization
    results arrive on *l_speakers_queue* as dicts with keys
    ``"speaker"``, ``"beg"`` and ``"end"``.
    """
    inference, ws_source = create_pipeline(SAMPLE_RATE)

    # Capture the loop and create the queue up front: the hook below runs
    # on the executor worker thread started by run_in_executor, not on
    # the loop thread.
    loop = asyncio.get_event_loop()
    l_speakers_queue = asyncio.Queue()

    def diar_hook(result):
        """Hook called by diart each time it processes a chunk.

        *result* is an (annotation, audio) pair; one queue entry is
        emitted per detected speaker, spanning that speaker's first and
        last segment boundaries.
        """
        annotation, audio = result
        # NOTE(review): relies on pyannote private internals (_labels,
        # segments_boundaries_) — confirm against the pinned version.
        for speaker in annotation._labels:
            boundaries = annotation._labels[speaker].segments_boundaries_
            payload = {
                "speaker": speaker,
                "beg": boundaries[0],
                "end": boundaries[-1],
            }
            # BUG FIX: asyncio.create_task() raises RuntimeError here —
            # this hook runs on the worker thread, which has no running
            # event loop. Hand the item to the loop thread-safely instead.
            loop.call_soon_threadsafe(l_speakers_queue.put_nowait, payload)

    inference.attach_hooks(diar_hook)

    # Launch the blocking diart inference loop in the default executor.
    loop.run_in_executor(None, inference)
    return inference, l_speakers_queue, ws_source


class DiartDiarization():
    """Thin async wrapper around a diart streaming diarization pipeline."""

    def __init__(self, SAMPLE_RATE):
        self.inference, self.l_speakers_queue, self.ws_source = init_diart(SAMPLE_RATE)
        # Speaker segments collected by the most recent diarize() call.
        self.segment_speakers = []

    async def diarize(self, pcm_array):
        """Push *pcm_array* into the pipeline and drain pending results.

        Replaces ``segment_speakers`` with whatever the diarization hook
        has queued up so far (possibly nothing).
        """
        self.ws_source.push_audio(pcm_array)
        self.segment_speakers = []
        while not self.l_speakers_queue.empty():
            item = await self.l_speakers_queue.get()
            self.segment_speakers.append(item)

    def close(self):
        """Close the audio source, completing the diart stream."""
        self.ws_source.close()

    def assign_speakers_to_chunks(self, chunks):
        """Label every chunk overlapped by a known speaker segment.

        A chunk gets a ``"speaker"`` key for each segment whose time
        range intersects the chunk's ``[beg, end)`` range; when several
        segments overlap one chunk, the last one processed wins.
        Mutates *chunks* in place and returns the same list.
        """
        if not self.segment_speakers:
            return chunks

        for segment in self.segment_speakers:
            spk = segment["speaker"]
            beg, end = segment["beg"], segment["end"]
            for chunk in chunks:
                # Overlap test, De Morgan of "ends before or starts after".
                if end > chunk["beg"] and beg < chunk["end"]:
                    chunk["speaker"] = spk

        return chunks