import random

from BeamDiffusionModel.models.CoSeD.cross_attention import get_softmax
from BeamDiffusionModel.models.diffusionModel.configs.config_loader import CONFIG
from BeamDiffusionModel.tree.tree import BeamSearchTree
from BeamDiffusionModel.utils.utils import gen_img

# Value recorded as both the source step and the source latent of a node that
# was generated from a fresh random seed instead of an ancestor latent.
_RAND_SEED_LABEL = "Rand Seed"

# Inclusive upper bound for randomly drawn seeds.
_MAX_SEED = 10 ** 6


def set_softmax(nodes, softmax, n_latents, n_max_latents):
    """Forward one softmax value to each node.

    ``nodes`` and ``softmax`` are paired positionally with ``zip``, so any
    surplus entries on the longer side are ignored. ``n_latents`` and
    ``n_max_latents`` are passed through unchanged to ``node.set_softmax``.
    """
    for node, softmax_value in zip(nodes, softmax):
        node.set_softmax(softmax_value, n_latents, n_max_latents)


def _expand_node(tree, parent_node, caption, step, latents_idx, steps_back, use_rand):
    """Generate, register and score all children of ``parent_node`` for one step.

    Children are created in this order: optionally one image from a fresh
    random seed, then one image per (ancestor, latent) pair, where ancestors
    come from ``parent_node.get_ancestors(steps_back - 1)`` and latents from
    ``latents_idx``. Returns the list of newly added child nodes (possibly
    empty).
    """
    children = []
    step_embeddings, image_embeddings = [], []

    def register(source_step, source_latent, img, latents):
        # Add the node to the tree first; its features are read back from the
        # node itself afterwards.
        child = tree.add_node(parent_node, caption, step, source_step, source_latent, img, latents, None)
        children.append(child)
        step_embedding, image_embedding = child.get_features()
        step_embeddings.append(step_embedding)
        image_embeddings.append(image_embedding)

    if use_rand:
        seed = random.randint(0, _MAX_SEED)
        latents, img = gen_img(caption, seed=seed)
        register(_RAND_SEED_LABEL, _RAND_SEED_LABEL, img, latents)

    for ancestor in parent_node.get_ancestors(steps_back - 1):
        for latent in latents_idx:
            ancestor_latent = ancestor.get_latent(latent)
            latents, img = gen_img(caption, latent=ancestor_latent)
            register(ancestor.step, latent, img, latents)

    # Nothing to score when neither the random seed nor any ancestor latent
    # produced a child.
    if step_embeddings:
        previous_steps_embeddings, previous_images_embeddings = tree.get_previous_steps_features(children[-1])
        softmax = get_softmax(
            previous_steps_embeddings,
            previous_images_embeddings,
            step_embeddings,
            image_embeddings,
        )
        set_softmax(
            children,
            softmax,
            len(latents_idx),
            CONFIG["stable_diffusion"]["diffusion_settings"]["steps"],
        )

    return children


def _keep_nodes_on_best_paths(tree, candidates, beam_width, depth):
    """Return the candidates that lie on one of the tree's best paths.

    Best paths are obtained from ``tree.get_n_best_paths(beam_width, depth)``.
    Each candidate is kept at most once, even if several best paths contain
    it, so that it is not expanded more than once in the following step.
    Candidate order is preserved.
    """
    best_paths = tree.get_n_best_paths(beam_width, depth)
    return [
        node
        for node in candidates
        if any(node in path for path in best_paths)
    ]


def beam_inference(steps, latents_idx, n_seeds=1, seeds=None, steps_back=2, beam_width=4, window_size=2, use_rand=True):
    """Generate images for a sequence of captions while exploring a search tree.

    The first caption is rendered once per seed. Every later caption is
    rendered, for each node still being explored, from the latents of that
    node's ancestors (and optionally from a fresh random seed); the resulting
    children are scored via ``get_softmax`` and, once the step index reaches
    ``window_size``, only nodes on the tree's best paths are kept.

    Args:
        steps: Sequence of captions, one per generation step; each is passed
            to ``gen_img``.
        latents_idx: Collection of latent identifiers. Each is passed to
            ``ancestor.get_latent``; the collection and its length are also
            handed to the tree and to the scoring step.
        n_seeds: Minimum number of seeds for the first step.
        seeds: Optional list of seeds for the first step. If it holds fewer
            than ``n_seeds`` entries it is padded *in place* with random
            seeds; if it holds more, all of them are used. ``None`` (the
            default) starts from a fresh empty list on every call.
        steps_back: Passed to ``BeamSearchTree``; ``steps_back - 1`` is the
            argument given to ``get_ancestors`` when expanding a node.
        beam_width: Passed to ``BeamSearchTree`` and used as the number of
            best paths requested when pruning.
        window_size: Pruning starts at the first zero-based step index that
            is >= this value (step 0 is never pruned).
        use_rand: When true, each expanded node also gets one child generated
            from a fresh random seed.

    Returns:
        Whatever ``tree.best_path_imgs()`` returns (presumably the images on
        the best path; see ``BeamSearchTree`` to confirm).
    """
    # A ``None`` sentinel replaces the former ``[]`` default, which was shared
    # between calls and leaked seeds from one invocation into the next.
    if seeds is None:
        seeds = []
    while len(seeds) < n_seeds:
        seeds.append(random.randint(0, _MAX_SEED))

    captions = steps
    tree = BeamSearchTree(steps_back, beam_width, latents_idx, len(captions))
    nodes_to_explore = []

    for i, caption in enumerate(captions):
        step = i + 1  # Tree steps are 1-based.

        if i == 0:
            # Root level: one image per seed, attached directly to the root.
            for seed in seeds:
                latents, img = gen_img(caption, seed=seed)
                new_node = tree.add_node(
                    tree.root, caption, step, _RAND_SEED_LABEL, _RAND_SEED_LABEL, img, latents, None
                )
                nodes_to_explore.append(new_node)
            continue

        next_nodes = []
        for parent_node in nodes_to_explore:
            next_nodes += _expand_node(
                tree, parent_node, caption, step, latents_idx, steps_back, use_rand
            )

        if i >= window_size:
            print("-----------------------------------Cleaning some nodes-----------------------------------")
            next_nodes = _keep_nodes_on_best_paths(tree, next_nodes, beam_width, step)

        nodes_to_explore = next_nodes

    return tree.best_path_imgs()