"""Hugging Face config, model and pipeline wrappers around ``beam_inference``."""

import torch
import torch.nn as nn
from huggingface_hub import ModelHubMixin
from transformers import Pipeline, PretrainedConfig, PreTrainedModel

from BeamDiffusionModel.beamInference import beam_inference


class BeamDiffusionConfig(PretrainedConfig):
    """Configuration for the BeamDiffusion model.

    Stores the values that ``BeamDiffusionModel.forward`` passes to
    ``beam_inference`` whenever the caller's input dict omits them.

    Args:
        latents_idx: Forwarded as ``latents_idx``; a falsy value (``None`` or
            an empty list) is replaced by ``[0, 1, 2, 3]``.
        n_seeds: Forwarded as ``n_seeds``.
        seeds: Forwarded as ``seeds``; a falsy value becomes ``[]``.
        steps_back: Forwarded as ``steps_back``.
        beam_width: Forwarded as ``beam_width``.
        window_size: Forwarded as ``window_size``.
        use_rand: Forwarded as ``use_rand``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "beam_diffusion"

    def __init__(self, latents_idx=None, n_seeds=4, seeds=None, steps_back=2,
                 beam_width=4, window_size=2, use_rand=True, **kwargs):
        super().__init__(**kwargs)
        self.latents_idx = latents_idx if latents_idx else [0, 1, 2, 3]
        self.n_seeds = n_seeds
        self.seeds = seeds if seeds else []
        self.steps_back = steps_back
        self.beam_width = beam_width
        self.window_size = window_size
        self.use_rand = use_rand


class BeamDiffusionModel(PreTrainedModel, ModelHubMixin):
    """Model wrapper that delegates every forward pass to ``beam_inference``."""

    config_class = BeamDiffusionConfig
    model_type = "beam_diffusion"

    def __init__(self, config):
        """Store ``config`` and register a single placeholder parameter."""
        super().__init__(config)
        self.config = config
        # Ensure the module owns at least one parameter.
        self.dummy_param = nn.Parameter(torch.zeros(1))

    def forward(self, input_data):
        """Run ``beam_inference`` and return its result under ``"images"``.

        Args:
            input_data: Mapping-like object exposing ``.get``. ``"steps"``
                defaults to ``[]``. Each of ``"latents_idx"``, ``"n_seeds"``,
                ``"seeds"``, ``"steps_back"``, ``"beam_width"``,
                ``"window_size"`` and ``"use_rand"`` is taken from
                ``input_data`` when present, otherwise from ``self.config``,
                otherwise from the built-in literal default.

        Returns:
            ``{"images": <value returned by beam_inference>}``.
        """
        # Built per call so the list literals are never shared between calls.
        literal_defaults = {
            "latents_idx": [0, 1, 2, 3],
            "n_seeds": 4,
            "seeds": [],
            "steps_back": 2,
            "beam_width": 4,
            "window_size": 2,
            "use_rand": True,
        }
        unset = object()  # distinguishes "key absent" from an explicit None

        options = {}
        for name, literal in literal_defaults.items():
            value = input_data.get(name, unset)
            if value is unset:
                # Honour the model's configuration before the literal default;
                # getattr's fallback covers configs lacking the attribute.
                value = getattr(self.config, name, literal)
                if isinstance(value, list):
                    # Hand out a copy so the config's own list stays untouched.
                    value = list(value)
            options[name] = value

        images = beam_inference(steps=input_data.get("steps", []), **options)
        return {"images": images}


class BeamDiffusionPipeline(Pipeline, ModelHubMixin):
    """Pipeline that hands its input straight to the wrapped model."""

    def __init__(self, model, tokenizer=None, device="cuda", framework="pt"):
        """Forward all arguments unchanged to ``Pipeline.__init__``."""
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            device=device,
            framework=framework,
        )

    def __call__(self, inputs):
        """Return ``self._forward(inputs)`` directly.

        This override does not call ``preprocess`` or ``postprocess``, so the
        caller receives the model's output dict (``{"images": ...}``).
        """
        return self._forward(inputs)

    def preprocess(self, inputs):
        """Converts raw input data into model-ready format (returned as-is)."""
        return inputs

    def postprocess(self, model_outputs):
        """Processes model output into a user-friendly format.

        Returns the ``"images"`` entry of ``model_outputs``.
        """
        return model_outputs["images"]

    def _sanitize_parameters(self, **kwargs):
        """Ignore every keyword argument and return three empty dicts."""
        return {}, {}, {}

    def _forward(self, model_inputs):
        """Call the wrapped model on ``model_inputs`` and return its output."""
        return self.model(model_inputs)