# Hugging Face Space: VIDraft / ACE-Singer (Running on Zero)
# File: ACE-Singer/app.py
# Last commit: 29efb71 "Update app.py" by ginipick (verified)
# Size: 3.55 kB
# app.py
import argparse
import streamlit as st
import os
from pipeline_ace_step import ACEStepPipeline
from data_sampler import DataSampler
# Streamlit page configuration (browser tab title, icon and wide layout).
# Runs at import time, before any other st.* call made by this script.
st.set_page_config(
    page_title="ACE Step Music Generator",
    page_icon="🎵",
    layout="wide",
)
def get_args():
    """Collect runtime settings from environment variables, with defaults.

    Environment variables read:
        CHECKPOINT_PATH -- passed through unchanged; ``None`` when unset.
        DEVICE_ID       -- converted with ``int``; defaults to ``'0'``.
        BF16            -- true only when the value equals ``'true'``
                           case-insensitively; defaults to ``'True'``.
        TORCH_COMPILE   -- same rule as BF16; defaults to ``'False'``.

    Returns:
        dict with keys ``checkpoint_path``, ``device_id``, ``bf16`` and
        ``torch_compile``.

    Raises:
        ValueError: if DEVICE_ID is not an integer literal.
    """
    env = os.environ

    def flag(name, fallback):
        # Only the literal word "true" (any letter case) enables a flag.
        return env.get(name, fallback).lower() == 'true'

    settings = {}
    settings['checkpoint_path'] = env.get('CHECKPOINT_PATH')
    settings['device_id'] = int(env.get('DEVICE_ID', '0'))
    settings['bf16'] = flag('BF16', 'True')
    settings['torch_compile'] = flag('TORCH_COMPILE', 'False')
    return settings
@st.cache_resource
def load_model(args):
    """Build the ACE Step pipeline and the data sampler (cached).

    ``st.cache_resource`` caches the returned pair keyed on ``args``, so
    script reruns reuse the same objects instead of rebuilding them.

    Args:
        args: settings dict as produced by ``get_args()``; reads
            ``device_id``, ``checkpoint_path``, ``bf16`` and
            ``torch_compile``. An optional ``persistent_storage_path`` key
            overrides the default storage directory ``"/data"``.

    Returns:
        ``(model_demo, data_sampler)`` -- an ``ACEStepPipeline`` and a
        ``DataSampler`` instance.
    """
    # NOTE(review): CUDA_VISIBLE_DEVICES only takes effect if CUDA has not
    # been initialised in this process yet; pipeline_ace_step is imported at
    # module level, so confirm it does not touch the GPU on import.
    os.environ["CUDA_VISIBLE_DEVICES"] = str(args['device_id'])
    # Falls back to "/data" when the key is absent or empty.
    persistent_storage_path = args.get('persistent_storage_path') or "/data"
    model_demo = ACEStepPipeline(
        checkpoint_dir=args['checkpoint_path'],
        dtype="bfloat16" if args['bf16'] else "float32",
        persistent_storage_path=persistent_storage_path,
        torch_compile=args['torch_compile'],
    )
    data_sampler = DataSampler()
    return model_demo, data_sampler
def _show_result(result):
    """Render a pipeline result in the page.

    - Objects with an ``audio`` attribute: that attribute is played.
    - Lists/tuples: rendered item by item. Strings naming an existing file
      are played as audio, dicts are shown as JSON, and anything else goes
      through ``st.write``.
    - Anything else: ``st.write``.
    """
    # NOTE(review): the pipeline presumably returns output file paths plus a
    # params dict; verify against ACEStepPipeline.__call__.
    if hasattr(result, 'audio'):
        st.audio(result.audio)
        return
    if isinstance(result, (list, tuple)):
        for item in result:
            if isinstance(item, str) and os.path.isfile(item):
                st.audio(item)
            elif isinstance(item, dict):
                st.json(item)
            else:
                st.write(item)
        return
    st.write(result)


def _render_generate_panel(model_demo):
    """Draw the prompt input and run generation when the button is pressed."""
    st.header("Generate Music")
    prompt = st.text_area(
        "Enter your music description:",
        placeholder="Enter a description of the music you want to generate...",
        height=100,
    )
    if not st.button("Generate Music", type="primary"):
        return
    # A whitespace-only prompt is treated the same as an empty one.
    if not (prompt and prompt.strip()):
        st.warning("Please enter a description first.")
        return
    try:
        with st.spinner("Generating music..."):
            # NOTE(review): ACEStepPipeline.__call__ is assumed to accept a
            # `prompt` keyword. Passing it by keyword avoids binding the text
            # to whatever the first positional parameter happens to be.
            result = model_demo(prompt=prompt)
    except Exception as e:
        st.error(f"Error generating music: {e}")
        return
    st.success("Music generated successfully!")
    _show_result(result)


def _render_sample_panel(data_sampler):
    """Show a random sample from ``data_sampler`` or an uploaded JSON file."""
    import json

    st.header("Sample Data")
    if st.button("Load Sample"):
        try:
            sample_data = data_sampler.sample()
        except Exception as e:
            st.error(f"Error loading sample: {e}")
        else:
            st.json(sample_data)

    uploaded_file = st.file_uploader(
        "Upload JSON data",
        type=['json'],
    )
    if uploaded_file is not None:
        try:
            # The upload is an in-memory file-like object, not a path, so it
            # is parsed directly.
            # NOTE(review): DataSampler.load_json is assumed to expect a
            # filesystem path.
            data = json.loads(uploaded_file.getvalue())
        except ValueError as e:  # JSONDecodeError / UnicodeDecodeError
            st.error(f"Error loading file: {e}")
        else:
            st.json(data)


def main():
    """Render the ACE Step music-generation page.

    Loads the cached model and data sampler, then draws two columns:
    prompt-driven generation on the left, and sample/uploaded JSON viewers
    on the right. If loading fails, the error and its traceback are shown
    and nothing else is rendered.
    """
    import traceback

    st.title("🎵 ACE Step Music Generator")
    args = get_args()

    # Guard only model loading, so that errors raised while drawing the UI
    # are not mislabelled as model-loading failures.
    try:
        model_demo, data_sampler = load_model(args)
    except Exception as e:
        st.error(f"Error loading model: {e}")
        st.code(traceback.format_exc())
        return

    col1, col2 = st.columns([2, 1])
    with col1:
        _render_generate_panel(model_demo)
    with col2:
        _render_sample_panel(data_sampler)
if __name__ == '__main__':
    # `streamlit run app.py` executes this script as __main__, so the page
    # is built on every Streamlit rerun.
    main()