# Spaces: Running on Zero
# app.py | |
import argparse | |
import streamlit as st | |
import os | |
from pipeline_ace_step import ACEStepPipeline | |
from data_sampler import DataSampler | |
# Streamlit page configuration. It runs at import time, before main() renders
# anything; older Streamlit releases require it to be the first Streamlit
# command executed by the script.
st.set_page_config(
    page_title="ACE Step Music Generator",
    page_icon="🎵",
    layout="wide",
)
def get_args():
    """Collect runtime settings from environment variables.

    Each setting falls back to a default when its variable is unset:

    * ``CHECKPOINT_PATH`` -> ``checkpoint_path`` (default ``None``)
    * ``DEVICE_ID``       -> ``device_id``, parsed with ``int`` (default ``0``)
    * ``BF16``            -> ``bf16`` (default ``True``)
    * ``TORCH_COMPILE``   -> ``torch_compile`` (default ``False``)

    A boolean flag is ``True`` only when its value equals ``"true"``
    case-insensitively; any other value yields ``False``.

    Returns:
        dict with the four keys listed above, in that order.

    Raises:
        ValueError: if ``DEVICE_ID`` is set but is not an integer literal.
    """
    env = os.environ

    def _flag(name, default):
        # Case-insensitive comparison against the literal string "true".
        return env.get(name, default).lower() == 'true'

    settings = {}
    settings['checkpoint_path'] = env.get('CHECKPOINT_PATH')
    settings['device_id'] = int(env.get('DEVICE_ID', '0'))
    settings['bf16'] = _flag('BF16', 'True')
    settings['torch_compile'] = _flag('TORCH_COMPILE', 'False')
    return settings
@st.cache_resource
def load_model(args):
    """Create the ACE-Step pipeline and data sampler, cached across reruns.

    Streamlit re-executes the whole script on every widget interaction.
    ``st.cache_resource`` makes this body run once per distinct ``args``
    value per server process. Later reruns, and every other session, get the
    same objects back instead of rebuilding the model on each button click.

    Args:
        args: Settings dict as produced by ``get_args()``. Reads the keys
            ``device_id``, ``checkpoint_path``, ``bf16`` and ``torch_compile``.

    Returns:
        A ``(model_demo, data_sampler)`` tuple.
    """
    # Restrict the process to the configured GPU before building the pipeline.
    # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not been
    # initialised yet in this process; confirm no earlier import does so.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args['device_id'])
    persistent_storage_path = "/data"
    # NOTE(review): checkpoint_path can be None when CHECKPOINT_PATH is unset;
    # presumably the pipeline then falls back to a default checkpoint. Confirm.
    # NOTE(review): the cached objects are shared by all sessions; confirm the
    # pipeline tolerates concurrent calls.
    model_demo = ACEStepPipeline(
        checkpoint_dir=args['checkpoint_path'],
        dtype="bfloat16" if args['bf16'] else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args['torch_compile'],
    )
    data_sampler = DataSampler()
    return model_demo, data_sampler
def _render_generator(model_demo):
    """Render the prompt box and "Generate Music" button; run generation on click."""
    st.header("Generate Music")
    # Free-text description of the music to generate.
    prompt = st.text_area(
        "Enter your music description:",
        placeholder="Enter a description of the music you want to generate...",
        height=100,
    )
    if not st.button("Generate Music", type="primary"):
        return
    # A whitespace-only description is treated the same as an empty one.
    if not prompt or not prompt.strip():
        st.warning("Please enter a description first.")
        return
    with st.spinner("Generating music..."):
        try:
            # Pass the text by keyword so it binds to the pipeline's `prompt`
            # parameter wherever that parameter sits in the call signature.
            # NOTE(review): assumes ACEStepPipeline.__call__ accepts `prompt=`;
            # confirm against pipeline_ace_step.
            result = model_demo(prompt=prompt)
        except Exception as e:
            st.error(f"Error generating music: {e}")
            return
    st.success("Music generated successfully!")
    # NOTE(review): the pipeline's return type is not visible here. Objects with
    # an `.audio` attribute are played; anything else is dumped via st.write.
    # Adjust once the return format is confirmed.
    if hasattr(result, 'audio'):
        st.audio(result.audio)
    else:
        st.write(result)


def _render_sample_panel(data_sampler):
    """Render the sample-data panel: a random-sample button and a JSON uploader."""
    st.header("Sample Data")
    if st.button("Load Sample"):
        try:
            sample_data = data_sampler.sample()
        except Exception as e:
            st.error(f"Error loading sample: {e}")
        else:
            st.json(sample_data)

    # st.file_uploader returns None until the user uploads a file.
    uploaded_file = st.file_uploader(
        "Upload JSON data",
        type=['json'],
    )
    if uploaded_file is not None:
        # NOTE(review): uploaded_file is a Streamlit UploadedFile (an in-memory
        # file-like object), not a filesystem path. Confirm that
        # DataSampler.load_json accepts a file object.
        try:
            data = data_sampler.load_json(uploaded_file)
        except Exception as e:
            st.error(f"Error loading file: {e}")
        else:
            st.json(data)


def main():
    """Render the ACE Step music generator page.

    Loads the model and data sampler, then lays out two columns. The left
    (2/3 width) holds the generator; the right (1/3 width) holds the
    sample-data tools. If model loading fails, the error and traceback are
    shown and the rest of the UI is not rendered.
    """
    import traceback

    st.title("🎵 ACE Step Music Generator")
    args = get_args()

    # Only model loading is guarded here. Errors raised while building the UI
    # are left to Streamlit instead of being mislabelled as load failures.
    try:
        model_demo, data_sampler = load_model(args)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.code(traceback.format_exc())
        return

    col1, col2 = st.columns([2, 1])
    with col1:
        _render_generator(model_demo)
    with col2:
        _render_sample_panel(data_sampler)
if __name__ == "__main__":
    # Streamlit executes the app script as the __main__ module, so this guard
    # renders the page on every script run while keeping imports side-effect free.
    main()