from multiprocessing import Process, Manager
from tqdm import tqdm


def worker_main(work_queue, result_queue, process_func, config):
    """Consume items from work_queue until a None sentinel arrives,
    pushing (results, cost) tuples onto result_queue."""
    while True:
        item = work_queue.get()
        if item is None:
            # Forward the sentinel so the collector knows this worker exited.
            result_queue.put(None)
            work_queue.task_done()
            break
        try:
            results, cost = process_func(config, item)
            result_queue.put((results, cost))
        except Exception as e:
            item_info = item.get('idx', item.get('id', 'unknown item'))
            print(f"Error processing item {item_info}: {e}")
            # Report the failure as (None, None) rather than a bare None,
            # which the collector would misread as a worker-exit sentinel
            # and stop listening before all items are processed.
            result_queue.put((None, None))
        finally:
            work_queue.task_done()

def run_parallel_evaluation(dataset, process_func, config, num_workers, description):
    """
    Runs parallel evaluation on the given dataset and returns the results.

    Args:
        dataset (list or datasets.Dataset): Data to evaluate.
        process_func (callable): Function to process each data item.
        config (dict): Configuration for the process_func.
        num_workers (int): Number of worker processes to use.
        description (str): Description to display on the tqdm progress bar.

    Returns:
        tuple: (list of evaluation results, total cost)
    """
    manager = Manager()
    work_queue = manager.Queue()
    result_queue = manager.Queue()

    # Add data to the work queue
    dataset_list = list(dataset) if not isinstance(dataset, list) else dataset
    for data in dataset_list:
        work_queue.put(data)
    
    # Add termination signals for workers
    for _ in range(num_workers):
        work_queue.put(None)

    # Start parallel processing
    processes = []
    for _ in range(num_workers):
        p = Process(target=worker_main, args=(work_queue, result_queue, process_func, config))
        p.start()
        processes.append(p)
    
    # Show progress bar and collect results
    process_results = []
    process_cost = 0
    completed_workers = 0

    with tqdm(total=len(dataset_list), desc=description) as pbar:
        while completed_workers < num_workers:
            result_item = result_queue.get()
            if result_item is None:
                # A bare None is a worker-exit sentinel; failed items arrive
                # as (None, None) tuples and still advance the progress bar.
                completed_workers += 1
            else:
                results, cost = result_item
                if results is not None:
                    process_results.append(results)
                    process_cost += cost if cost is not None else 0
                pbar.update(1)

    # Wait for all processes to finish
    for p in processes:
        p.join()

    # Drain any results still queued after all sentinels were received
    # (normally empty, since each worker's sentinel is its final put)
    while not result_queue.empty():
        result_item = result_queue.get_nowait()
        if result_item is not None:
            results, cost = result_item
            if results is not None:
                process_results.append(results)
                process_cost += cost if cost is not None else 0

    return process_results, process_cost
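

# Minimal usage sketch. `_demo_process` and the inline dataset below are
# hypothetical stand-ins; a real process_func would typically call a model
# or scorer and return a (result, cost) pair. Defined at module level so it
# stays picklable under the "spawn" start method.
def _demo_process(config, item):
    # Illustrative only: scale the item index and report zero cost.
    return {"idx": item["idx"], "scaled": item["idx"] * config["scale"]}, 0.0


if __name__ == "__main__":
    data = [{"idx": i} for i in range(10)]
    results, total_cost = run_parallel_evaluation(
        data, _demo_process, {"scale": 2}, num_workers=4, description="demo"
    )
    print(f"Collected {len(results)} results, total cost: {total_cost}")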