import os
import json
import re
import random
import litellm
import tqdm
import concurrent.futures
from typing import List, Dict, Any, Optional
from datetime import datetime

# Reuse components from the existing codebase
from prompts import PROMPTS, format_prompt, load_prompts
from utils import save_results, generate_user_id
from ibfs import generate_strategies, answer_query
from zero_shot import zero_shot_answer

# Load the prompt templates from the prompts module's YAML configuration
load_prompts()
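
# Note: litellm typically reads provider credentials from environment variables
# (e.g. OPENAI_API_KEY for the default "gpt-4o" model used throughout this script).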


class UserAgent:
    """Simulates a user with a preferred answer and decision-making behavior."""

    def __init__(self, llm_model: str = "gpt-4o", epsilon: float = 0.2):
        """
        Initialize the UserAgent with properties.

        Args:
            llm_model: The LLM model to use for agent decisions
            epsilon: Probability of making a random choice instead of optimal
        """
        self.llm_model = llm_model
        self.epsilon = epsilon
        self.preferred_answer = None
        self.query = None
        self.id = generate_user_id()

    def set_preferences(self, query: str):
        """
        Set a query and generate the preferred answer for this user.

        Args:
            query: The question the user wants answered
        """
        self.query = query

        # Generate the user's preferred answer using the LLM
        messages = [
            {"role": "system",
             "content": "You are generating a preferred answer that a user has in mind for their query. This represents what the user is hoping to learn or the perspective they're hoping to see."},
            {"role": "user",
             "content": f"For the query: '{query}', generate a detailed, thoughtful answer that will serve as the user's preferred answer. This is the information or perspective they are hoping to find. Make it 20 words."}
        ]

        response = litellm.completion(
            model=self.llm_model,
            messages=messages,
            max_tokens=1000
        )

        self.preferred_answer = response.choices[0].message.content
        return self.preferred_answer

    def choose_strategy(self, strategies: List[str]) -> int:
        """
        Choose a strategy from the provided options.

        Args:
            strategies: List of strategy descriptions

        Returns:
            Index of the chosen strategy (0-based)
        """
        # With probability epsilon, make a random choice
        if random.random() < self.epsilon:
            return random.randint(0, len(strategies) - 1)

        # Otherwise, evaluate which strategy gets closest to preferred answer
        if not self.preferred_answer or not strategies:
            return 0  # Default to first option if no preference or strategies

        # Prompt the LLM to rank the strategies based on similarity to preferred answer
        strategy_list = "\n".join([f"{i + 1}. {s}" for i, s in enumerate(strategies)])

        messages = [
            {"role": "system",
             "content": "You are helping a user select the strategy that would most likely lead to their preferred answer."},
            {"role": "user", "content": f"""
Query: {self.query}

User's preferred answer: {self.preferred_answer}

Available strategies:
{strategy_list}

Which strategy (provide the number only) would most likely lead to an answer that matches the user's preferred answer? Respond with only a single number representing your choice.
"""}
        ]

        try:
            response = litellm.completion(
                model=self.llm_model,
                messages=messages,
                temperature=0.2,
                max_tokens=10
            )

            # Extract the chosen strategy number
            content = response.choices[0].message.content.strip()
            # Find the first number in the response
            match = re.search(r'\d+', content)
            if match:
                choice = int(match.group()) - 1  # Convert to 0-based index
                # Ensure it's within bounds
                if 0 <= choice < len(strategies):
                    return choice

            # If we couldn't parse the response or it's out of bounds, make a random choice
            return random.randint(0, len(strategies) - 1)

        except Exception as e:
            print(f"Error in choosing strategy: {e}")
            # Fall back to random choice
            return random.randint(0, len(strategies) - 1)
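
# Illustrative usage of UserAgent (not executed here; assumes API credentials are set):
#   user = UserAgent(epsilon=0.2)
#   user.set_preferences("What are the arguments for and against universal basic income?")
#   idx = user.choose_strategy(["I can answer by ...", "I can answer by ..."])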


class IBFSAgent:
    """Implements the Interactive Best-First Search process."""

    def __init__(self,
                 llm_model: str = "gpt-4o",
                 diversity_level: str = "medium",
                 branching_factor: int = 4,
                 max_depth: int = 2):
        """
        Initialize the IBFSAgent with properties.

        Args:
            llm_model: The LLM model to use for generating candidates
            diversity_level: How diverse the generated candidates should be (low, medium, high)
            branching_factor: Number of candidates to generate at each step
            max_depth: Maximum depth/iterations of the IBFS process
        """
        self.llm_model = llm_model
        self.diversity_level = diversity_level
        self.branching_factor = branching_factor
        self.max_depth = max_depth
        self.id = generate_user_id()

        # Set up diversity-specific prompts
        self._setup_prompts()

    def _setup_prompts(self):
        """Set up the candidate generation and refinement prompts based on diversity level."""
        # Load the base prompts from PROMPTS dictionary
        self.base_system_prompt = PROMPTS["ibfs"]["initial_strategies"]["system"]
        self.base_user_prompt = PROMPTS["ibfs"]["initial_strategies"]["user"]
        self.refinement_system_prompt = PROMPTS["ibfs"]["continuation_strategies"]["system"]
        self.refinement_user_prompt = PROMPTS["ibfs"]["continuation_strategies"]["user"]

        # Augment with diversity-specific instructions
        diversity_instructions = {
            "low": """
The strategies you generate can be similar to each other and explore related approaches.
There's no need to make them very different from each other.
""",
            "medium": """
Each strategy should represent a somewhat different approach to answering the question.
Try to include some variety in the approaches.
""",
            "high": """
Each strategy should represent a substantially different approach to answering the question.
Make sure the strategies are maximally diverse from each other - consider entirely different angles, 
methodologies, perspectives, and areas of knowledge.
"""
        }

        # Add diversity instructions to the prompts
        self.diversity_instructions = diversity_instructions[self.diversity_level]

    def generate_strategies(self, query: str, current_path: Optional[List[str]] = None) -> List[str]:
        """
        Generate strategy options for the current step.

        Args:
            query: The user's query
            current_path: List of previously selected strategies

        Returns:
            List of strategy descriptions
        """
        if not current_path:
            # Initial generation
            system_prompt = self.base_system_prompt + "\n" + self.diversity_instructions
            user_prompt = self.base_user_prompt

            # Format the prompts
            format_args = {
                "query": query,
                "k": self.branching_factor
            }
        else:
            # Refinement of previously selected strategy
            system_prompt = self.refinement_system_prompt + "\n" + self.diversity_instructions
            user_prompt = self.refinement_user_prompt

            # Format the prompts
            format_args = {
                "query": query,
                "selected_strategy": current_path[-1],
                "k": self.branching_factor
            }

        # Format the prompts
        system_message = format_prompt(system_prompt, **format_args)
        user_message = format_prompt(user_prompt, **format_args)

        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": user_message}
        ]

        try:
            response = litellm.completion(
                model=self.llm_model,
                messages=messages,
                temperature=0.7,
                max_tokens=1000
            )

            content = response.choices[0].message.content

            # Use the strategy parsing from ibfs.py
            # Parse strategies using regex
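            # The model is expected to return a numbered list, for example (illustrative):
            #   1. I can answer by comparing lifecycle emissions of both vehicle types...
            #   2. I can answer by surveying recent peer-reviewed studies...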
            strategies = re.findall(r'\d+\.\s*(I can answer by[^\n\d]*(?:\n(?!\d+\.)[^\n]*)*)', content, re.IGNORECASE)

            # If we didn't find enough strategies with that format, try alternative parsing
            if len(strategies) < self.branching_factor:
                strategies = re.findall(r'(?:^|\n)(I can answer by[^\n]*(?:\n(?!I can answer by)[^\n]*)*)', content,
                                        re.IGNORECASE)

            # Clean up the strategies
            strategies = [s.strip() for s in strategies]

            # Ensure we have exactly b strategies
            if len(strategies) > self.branching_factor:
                strategies = strategies[:self.branching_factor]

            # If we still don't have enough strategies, create generic ones
            while len(strategies) < self.branching_factor:
                strategies.append(
                    f"I can answer by using approach #{len(strategies) + 1} (Note: Strategy generation incomplete)")

            return strategies

        except Exception as e:
            print(f"Error generating strategies: {e}")
            # Return fallback strategies
            return [f"I can answer by approach #{i + 1} (Error: Could not generate strategies)" for i in
                    range(self.branching_factor)]

    def generate_final_answer(self, query: str, strategy_path: List[str]) -> str:
        """
        Generate the final answer based on the selected strategy path.

        Args:
            query: The original user query
            strategy_path: List of selected strategies

        Returns:
            Final answer to the query
        """
        if not strategy_path:
            return "No strategy was selected to generate an answer."

        final_strategy = strategy_path[-1]

        # Use the answer_query function from ibfs.py
        return answer_query(query, final_strategy)
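
# Illustrative usage of IBFSAgent (not executed here; assumes API credentials are set):
#   ibfs = IBFSAgent(diversity_level="high", branching_factor=4, max_depth=2)
#   options = ibfs.generate_strategies("How has artificial intelligence changed the job market?")
#   answer = ibfs.generate_final_answer("How has artificial intelligence changed the job market?", [options[0]])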


def run_simulation(query: str,
                   user_agent: UserAgent,
                   ibfs_agent: IBFSAgent) -> Dict[str, Any]:
    """
    Run a full simulation of a user interacting with the IBFS system.

    Args:
        query: The question to be answered
        user_agent: The UserAgent instance
        ibfs_agent: The IBFSAgent instance

    Returns:
        Dictionary containing the simulation results
    """
    # Set up the user's preferred answer
    user_agent.set_preferences(query)

    # Initialize the strategy path
    strategy_path = []

    # Record all strategies presented and choices made
    history = []

    # Run through the IBFS process up to max_depth
    for depth in range(ibfs_agent.max_depth):
        # Generate strategies at this step
        strategies = ibfs_agent.generate_strategies(query, strategy_path)

        # Have the user agent choose a strategy
        choice_idx = user_agent.choose_strategy(strategies)
        chosen_strategy = strategies[choice_idx]

        # Record this step
        history.append({
            "depth": depth,
            "strategies": strategies,
            "choice_idx": choice_idx,
            "chosen_strategy": chosen_strategy
        })

        # Update the strategy path
        strategy_path.append(chosen_strategy)

    # Generate the final answer
    final_answer = ibfs_agent.generate_final_answer(query, strategy_path)

    # Create the simulation result
    result = {
        "query": query,
        "user_id": user_agent.id,
        "ibfs_id": ibfs_agent.id,
        "user_preferred_answer": user_agent.preferred_answer,
        "final_answer": final_answer,
        "strategy_path": strategy_path,
        "history": history,
        "ibfs_config": {
            "diversity_level": ibfs_agent.diversity_level,
            "branching_factor": ibfs_agent.branching_factor,
            "max_depth": ibfs_agent.max_depth
        },
        "user_config": {
            "epsilon": user_agent.epsilon
        },
        "timestamp": datetime.now().isoformat()
    }

    return result
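
# Minimal end-to-end sketch (illustrative; assumes API credentials are configured):
#   result = run_simulation(
#       "What are the most effective strategies for reducing stress and anxiety?",
#       UserAgent(epsilon=0.2),
#       IBFSAgent(diversity_level="medium", branching_factor=4, max_depth=2),
#   )
#   print(result["final_answer"], result["strategy_path"])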


def evaluate_answer_similarity(answer1: str, answer2: str) -> float:
    """
    Evaluate the similarity between two answers using the LLM.

    Args:
        answer1: First answer
        answer2: Second answer

    Returns:
        Similarity score (0-1)
    """
    messages = [
        {"role": "system",
         "content": "You are evaluating the similarity between a user's preferred answer and a generated answer. Provide a similarity score from 0 to 1, where 1 means identical in content and perspective, and 0 means completely different."},
        {"role": "user", "content": f"""
Answer 1:
{answer1}

Answer 2:
{answer2}

On a scale from 0 to 1, how similar are these answers in terms of content, perspective, and key information? 
Provide only a single number as your response.
"""}
    ]

    try:
        response = litellm.completion(
            model="gpt-4o",
            messages=messages,
            temperature=0.1,
            max_tokens=10
        )

        content = response.choices[0].message.content.strip()
        # Extract the score from the response
        match = re.search(r'(\d+(\.\d+)?)', content)
        if match:
            score = float(match.group(1))
            # Ensure it's in the range [0, 1]
            return max(0, min(score, 1))
        else:
            # Default score if parsing fails
            return 0.5

    except Exception as e:
        print(f"Error evaluating similarity: {e}")
        return 0.5
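
# Illustrative usage (not executed here):
#   score = evaluate_answer_similarity("EVs cut lifetime emissions.", "Electric vehicles lower emissions overall.")
#   # score is clamped to [0, 1]; the exact value is model-dependent.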


def process_simulation(args):
    """
    Process a single simulation for parallel execution.

    Args:
        args: Tuple containing (query, user_config, ibfs_config, experiment_id, sim_count)

    Returns:
        Simulation result
    """
    query, user_config, ibfs_config, experiment_id, sim_count = args

    try:
        # Create agents with the current configuration
        user_agent = UserAgent(epsilon=user_config["epsilon"])
        ibfs_agent = IBFSAgent(
            diversity_level=ibfs_config["diversity_level"],
            branching_factor=ibfs_config["branching_factor"],
            max_depth=ibfs_config["max_depth"]
        )

        # Run the simulation
        result = run_simulation(query, user_agent, ibfs_agent)

        # Add evaluation of similarity between preferred and final answers
        similarity = evaluate_answer_similarity(
            user_agent.preferred_answer,
            result["final_answer"]
        )
        result["similarity_score"] = similarity

        # Add metadata
        result["experiment_id"] = experiment_id
        result["simulation_id"] = sim_count

        # Save individual result
        os.makedirs("experiment_results", exist_ok=True)
        simulation_id = f"{experiment_id}_sim_{sim_count}"
        with open(f"experiment_results/{simulation_id}.json", "w") as f:
            json.dump(result, f, indent=2)

        return result

    except Exception as e:
        print(f"Error in simulation {sim_count}: {e}")
        return {
            "error": str(e),
            "experiment_id": experiment_id,
            "simulation_id": sim_count,
            "query": query
        }


def run_experiment(queries: List[str],
                   diversity_levels: List[str],
                   branching_factors: List[int],
                   max_depths: List[int],
                   epsilon_values: List[float],
                   repetitions: int = 3,
                   max_workers: int = 4) -> List[Dict[str, Any]]:
    """
    Run a full experiment with different configurations using parallel processing.

    Args:
        queries: List of queries to test
        diversity_levels: List of diversity levels to test
        branching_factors: List of branching factors to test
        max_depths: List of max depths to test
        epsilon_values: List of epsilon values to test
        repetitions: Number of repetitions for each configuration
        max_workers: Maximum number of parallel workers

    Returns:
        List of results for all simulations
    """
    # Create the results directory if it doesn't exist
    os.makedirs("experiment_results", exist_ok=True)

    # Generate a unique experiment ID
    experiment_id = f"exp_{datetime.now().strftime('%Y%m%d_%H%M%S')}"

    # Create a config log for this experiment
    config = {
        "experiment_id": experiment_id,
        "queries": queries,
        "diversity_levels": diversity_levels,
        "branching_factors": branching_factors,
        "max_depths": max_depths,
        "epsilon_values": epsilon_values,
        "repetitions": repetitions,
        "timestamp": datetime.now().isoformat()
    }

    # Save the configuration
    with open(f"experiment_results/{experiment_id}_config.json", "w") as f:
        json.dump(config, f, indent=2)

    # Prepare all simulation configurations
    simulation_args = []
    sim_count = 0

    for query in queries:
        for diversity in diversity_levels:
            for branching in branching_factors:
                for depth in max_depths:
                    for epsilon in epsilon_values:
                        for rep in range(repetitions):
                            # Create configuration for this simulation
                            user_config = {"epsilon": epsilon}
                            ibfs_config = {
                                "diversity_level": diversity,
                                "branching_factor": branching,
                                "max_depth": depth
                            }

                            # Add to arguments list
                            simulation_args.append((query, user_config, ibfs_config, experiment_id, sim_count))
                            sim_count += 1

    # Calculate total simulations
    total_simulations = len(simulation_args)
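    # With a full grid, this equals len(queries) * len(diversity_levels) * len(branching_factors)
    # * len(max_depths) * len(epsilon_values) * repetitions, so it grows quickly.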

    # Run simulations in parallel with progress bar
    results = []
    with tqdm.tqdm(total=total_simulations, desc="Running simulations") as pbar:
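        # Note: ProcessPoolExecutor runs simulations in separate processes; since the work is
        # mostly I/O-bound API calls, a ThreadPoolExecutor would likely also suffice.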
        with concurrent.futures.ProcessPoolExecutor(max_workers=max_workers) as executor:
            # Submit all simulations and process as they complete
            future_to_sim = {executor.submit(process_simulation, args): args for args in simulation_args}

            for future in concurrent.futures.as_completed(future_to_sim):
                result = future.result()
                results.append(result)
                pbar.update(1)

    # Save aggregated results
    with open(f"experiment_results/{experiment_id}_all_results.json", "w") as f:
        json.dump(results, f, indent=2)

    return results


def analyze_results(experiment_id: str) -> Dict[str, Any]:
    """
    Analyze the results of an experiment.

    Args:
        experiment_id: ID of the experiment to analyze

    Returns:
        Dictionary with analysis results
    """
    # Load all simulation results for this experiment
    results = []
    excluded_suffixes = ("config.json", "all_results.json", "analysis.json", "zero_shot_comparison.json")
    for filename in os.listdir("experiment_results"):
        if filename.startswith(experiment_id) and not filename.endswith(excluded_suffixes):
            with open(f"experiment_results/{filename}", "r") as f:
                try:
                    results.append(json.load(f))
                except json.JSONDecodeError:
                    print(f"Error loading {filename}")

    # Aggregate metrics by configuration
    aggregated = {}

    for result in results:
        if "error" in result:
            continue  # Skip failed simulations

        # Create a key for this configuration
        config_key = (
            result["ibfs_config"]["diversity_level"],
            result["ibfs_config"]["branching_factor"],
            result["ibfs_config"]["max_depth"],
            result["user_config"]["epsilon"]
        )

        # Initialize if this is the first result with this config
        if config_key not in aggregated:
            aggregated[config_key] = {
                "similarity_scores": [],
                "config": {
                    "diversity_level": result["ibfs_config"]["diversity_level"],
                    "branching_factor": result["ibfs_config"]["branching_factor"],
                    "max_depth": result["ibfs_config"]["max_depth"],
                    "epsilon": result["user_config"]["epsilon"]
                },
                "queries": []
            }

        # Add the similarity score
        aggregated[config_key]["similarity_scores"].append(result["similarity_score"])

        # Track queries for this configuration
        if result["query"] not in aggregated[config_key]["queries"]:
            aggregated[config_key]["queries"].append(result["query"])

    # Calculate summary statistics
    summary = []
    for config_key, data in aggregated.items():
        scores = data["similarity_scores"]
        mean = sum(scores) / len(scores) if scores else 0
        summary.append({
            "config": data["config"],
            "avg_similarity": mean,
            "min_similarity": min(scores) if scores else 0,
            "max_similarity": max(scores) if scores else 0,
            "std_deviation": (sum((x - mean) ** 2 for x in scores) / len(scores)) ** 0.5 if scores else 0,
            "num_samples": len(scores),
            "queries_tested": len(data["queries"])
        })

    # Sort by average similarity (descending)
    summary.sort(key=lambda x: x["avg_similarity"], reverse=True)

    # Add query-specific analysis
    query_analysis = {}
    for result in results:
        if "error" in result:
            continue

        query = result["query"]
        if query not in query_analysis:
            query_analysis[query] = {
                "best_config": None,
                "best_similarity": -1,
                "configs_tested": 0,
                "avg_similarity": 0,
                "all_scores": []
            }

        # Track all scores for this query
        query_analysis[query]["all_scores"].append(result["similarity_score"])

        # Update best configuration for this query
        if result["similarity_score"] > query_analysis[query]["best_similarity"]:
            query_analysis[query]["best_similarity"] = result["similarity_score"]
            query_analysis[query]["best_config"] = {
                "diversity_level": result["ibfs_config"]["diversity_level"],
                "branching_factor": result["ibfs_config"]["branching_factor"],
                "max_depth": result["ibfs_config"]["max_depth"],
                "epsilon": result["user_config"]["epsilon"]
            }

    # Calculate query statistics
    for query, data in query_analysis.items():
        scores = data["all_scores"]
        data["avg_similarity"] = sum(scores) / len(scores) if scores else 0
        data["configs_tested"] = len(scores)
        # Remove the raw scores to keep the analysis file smaller
        del data["all_scores"]

    # Save the analysis
    analysis = {
        "experiment_id": experiment_id,
        "total_simulations": len(results),
        "summary": summary,
        "query_analysis": query_analysis,
        "timestamp": datetime.now().isoformat()
    }

    with open(f"experiment_results/{experiment_id}_analysis.json", "w") as f:
        json.dump(analysis, f, indent=2)

    return analysis


def compare_to_zero_shot(experiment_id: str, queries: List[str]) -> Dict[str, Any]:
    """
    Compare IBFS results to zero-shot answers.

    Args:
        experiment_id: ID of the experiment to compare
        queries: List of queries to test with zero-shot

    Returns:
        Comparison results
    """
    # First, get zero-shot answers for all queries
    zero_shot_results = []

    print("Generating zero-shot answers...")
    for query in tqdm.tqdm(queries):
        try:
            # Generate zero-shot answer using function from ibfs.py
            answer = zero_shot_answer(query)

            # Save the result
            zero_shot_results.append({
                "query": query,
                "zero_shot_answer": answer
            })
        except Exception as e:
            print(f"Error generating zero-shot answer for '{query}': {e}")
            zero_shot_results.append({
                "query": query,
                "zero_shot_answer": f"Error: {str(e)}",
                "error": True
            })

    # Load the IBFS experiment results
    analysis = analyze_results(experiment_id)

    # Load the best configurations from the analysis
    best_config = analysis["summary"][0]["config"] if analysis["summary"] else None

    # For each query, compare the best IBFS result to the zero-shot answer
    comparison = []

    print("Comparing zero-shot to IBFS results...")
    for zero_shot_result in tqdm.tqdm(zero_shot_results):
        query = zero_shot_result["query"]
        # Named zs_answer to avoid shadowing the imported zero_shot_answer function,
        # which would otherwise raise UnboundLocalError at the call earlier in this function.
        zs_answer = zero_shot_result["zero_shot_answer"]

        # Find the best IBFS result for this query
        query_data = analysis.get("query_analysis", {}).get(query, {})
        best_config_for_query = query_data.get("best_config", best_config)

        if best_config_for_query:
            # Find a simulation with this configuration and query
            matching_results = []
            for filename in os.listdir("experiment_results"):
                if (filename.startswith(experiment_id)
                        and not filename.endswith(("config.json", "all_results.json", "analysis.json"))):
                    try:
                        with open(f"experiment_results/{filename}", "r") as f:
                            result = json.load(f)
                            if (result.get("query") == query and
                                    result.get("ibfs_config", {}).get("diversity_level") == best_config_for_query.get(
                                        "diversity_level") and
                                    result.get("ibfs_config", {}).get("branching_factor") == best_config_for_query.get(
                                        "branching_factor") and
                                    result.get("ibfs_config", {}).get("max_depth") == best_config_for_query.get(
                                        "max_depth") and
                                    result.get("user_config", {}).get("epsilon") == best_config_for_query.get(
                                        "epsilon")):
                                matching_results.append(result)
                    except (OSError, json.JSONDecodeError):
                        continue

            # Use the best matching result (if any)
            if matching_results:
                # Sort by similarity score (descending)
                matching_results.sort(key=lambda x: x.get("similarity_score", 0), reverse=True)
                best_ibfs_result = matching_results[0]

                # Compare zero-shot to user's preferred answer
                preferred_answer = best_ibfs_result.get("user_preferred_answer", "")
                zero_shot_similarity = evaluate_answer_similarity(preferred_answer, zs_answer)

                # Get the IBFS similarity (already calculated)
                ibfs_similarity = best_ibfs_result.get("similarity_score", 0)

                comparison.append({
                    "query": query,
                    "zero_shot_similarity": zero_shot_similarity,
                    "ibfs_similarity": ibfs_similarity,
                    "difference": ibfs_similarity - zero_shot_similarity,
                    "ibfs_config": best_config_for_query
                })
        else:
            print(f"No valid configuration found for query: {query}")

    # Calculate overall metrics
    zero_shot_avg = sum(item["zero_shot_similarity"] for item in comparison) / len(comparison) if comparison else 0
    ibfs_avg = sum(item["ibfs_similarity"] for item in comparison) / len(comparison) if comparison else 0
    avg_difference = sum(item["difference"] for item in comparison) / len(comparison) if comparison else 0

    # Save the comparison results
    comparison_results = {
        "experiment_id": experiment_id,
        "zero_shot_avg_similarity": zero_shot_avg,
        "ibfs_avg_similarity": ibfs_avg,
        "avg_difference": avg_difference,
        "query_comparisons": comparison,
        "timestamp": datetime.now().isoformat()
    }

    with open(f"experiment_results/{experiment_id}_zero_shot_comparison.json", "w") as f:
        json.dump(comparison_results, f, indent=2)

    return comparison_results


if __name__ == "__main__":
    # Ensure the prompts from YAML file are loaded
    if not PROMPTS:
        load_prompts()

    # Sample queries to test
    queries = [
        "What are the environmental impacts of electric vehicles compared to traditional gasoline vehicles?",
        "How has artificial intelligence changed the job market in the past decade?",
        "What are the most effective strategies for reducing stress and anxiety?",
        "What are the arguments for and against universal basic income?",
    ]

    # Experiment parameters
    diversity_levels = ["low", "medium", "high"]
    branching_factors = [2, 4, 8]
    max_depths = [1, 2, 4]
    epsilon_values = [0.1, 0.3]

    # Run a smaller test experiment
    print("Running test experiment...")
    test_results = run_experiment(
        queries=queries[:1],  # Just use the first query for testing
        diversity_levels=diversity_levels[:2],  # Test low and medium diversity
        branching_factors=[2, 4],  # Test b=2 and b=4
        max_depths=[1, 2],  # Test m=1 and m=2
        epsilon_values=[0.2],  # Test epsilon=0.2
        repetitions=10,  # 10 repetitions per configuration
        max_workers=7  # Run up to 7 simulations in parallel
    )
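
    # Follow-up analysis (illustrative sketch):
    #   experiment_id = test_results[0]["experiment_id"]
    #   analysis = analyze_results(experiment_id)
    #   comparison = compare_to_zero_shot(experiment_id, queries[:1])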