|
""" |
|
ACE-Step: A Step Towards Music Generation Foundation Model |
|
|
|
https://github.com/ace-step/ACE-Step |
|
|
|
Apache 2.0 License |
|
""" |
|
|
|
import os |
|
import click |
|
|
|
@click.command()
@click.option(
    "--checkpoint_path",
    type=str,
    default="",
    help="Path to the checkpoint directory. Downloads automatically if empty.",
)
@click.option(
    "--server_name",
    type=str,
    default="127.0.0.1",
    help="The server name to use for the Gradio app.",
)
@click.option(
    "--port", type=int, default=None, help="The port to use for the Gradio app."
)
@click.option("--device_id", type=int, default=0, help="The CUDA device ID to use.")
@click.option(
    "--share",
    type=click.BOOL,
    default=False,
    help="Whether to create a public, shareable link for the Gradio app.",
)
@click.option(
    "--bf16",
    type=click.BOOL,
    default=True,
    help="Whether to use bfloat16 precision. Turn off if using MPS.",
)
@click.option(
    "--torch_compile", type=click.BOOL, default=False, help="Whether to use torch.compile."
)
@click.option(
    "--cpu_offload",
    type=click.BOOL,
    default=False,
    help="Whether to use CPU offloading (only load current stage's model to GPU)",
)
@click.option(
    "--overlapped_decode",
    type=click.BOOL,
    default=False,
    help="Whether to use overlapped decoding (run dcae and vocoder using sliding windows)",
)
def main(
    checkpoint_path,
    server_name,
    port,
    device_id,
    share,
    bf16,
    torch_compile,
    cpu_offload,
    overlapped_decode,
):
    """Launch the ACE-Step pipeline demo as a Gradio app.

    Args:
        checkpoint_path: Checkpoint directory handed to ``ACEStepPipeline``
            as ``checkpoint_dir``. Defaults to the empty string.
        server_name: Host name forwarded to Gradio's ``launch``.
        port: Port forwarded to Gradio's ``launch`` as ``server_port``.
            ``None`` (the default) leaves the choice of port to Gradio.
        device_id: Value written to the ``CUDA_VISIBLE_DEVICES`` environment
            variable before the ACE-Step modules are imported.
        share: Forwarded to Gradio's ``launch`` as ``share``.
        bf16: Selects the pipeline dtype: ``"bfloat16"`` when true,
            ``"float32"`` otherwise.
        torch_compile: Forwarded to ``ACEStepPipeline`` unchanged.
        cpu_offload: Forwarded to ``ACEStepPipeline`` unchanged.
        overlapped_decode: Forwarded to ``ACEStepPipeline`` unchanged.
    """
    # Set before the acestep imports below; presumably so the device
    # restriction is in place before those modules touch CUDA -- the imports
    # are kept function-local to preserve that ordering.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(device_id)

    from acestep.ui.components import create_main_demo_ui
    from acestep.pipeline_ace_step import ACEStepPipeline
    from acestep.data_sampler import DataSampler

    model_demo = ACEStepPipeline(
        checkpoint_dir=checkpoint_path,
        dtype="bfloat16" if bf16 else "float32",
        torch_compile=torch_compile,
        cpu_offload=cpu_offload,
        overlapped_decode=overlapped_decode,
    )
    data_sampler = DataSampler()

    demo = create_main_demo_ui(
        text2music_process_func=model_demo.__call__,
        sample_data_func=data_sampler.sample,
        load_data_func=data_sampler.load_json,
    )

    # Forward the network-related CLI options; previously they were parsed
    # but silently dropped, so --server_name/--port/--share had no effect.
    demo.queue().launch(
        server_name=server_name,
        server_port=port,
        share=share,
        inbrowser=True,
    )
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: main is a click command, so calling it with no
    # arguments makes click read the options from the command line.
    main()
|
|