# app.py — Hugging Face Space by frankai98 (revision 2c637db, 6.82 kB)
import nest_asyncio
nest_asyncio.apply()

import asyncio
import datetime
import io
import os
import time

import streamlit as st
import torch
from gtts import gTTS
from transformers import pipeline
# Ensure the current thread has a usable asyncio event loop.
# On Python 3.10+, asyncio.get_event_loop() raises RuntimeError when the
# thread has no loop set, so the bare check itself could crash; handle
# that case by creating a fresh loop explicitly.
try:
    _loop = asyncio.get_event_loop()
except RuntimeError:
    asyncio.set_event_loop(asyncio.new_event_loop())
else:
    if not _loop.is_running():
        asyncio.set_event_loop(asyncio.new_event_loop())
# Initialize session state: seed every key we rely on with a default so
# later code can read them unconditionally across Streamlit reruns.
_SESSION_DEFAULTS = {
    'processed_data': {'scenario': None, 'story': None, 'audio': None},
    'image_data': None,
    'timer_start': None,
    'timer_running': False,
    'timer_complete': False,
}
for _key, _default in _SESSION_DEFAULTS.items():
    if _key not in st.session_state:
        st.session_state[_key] = _default
# Page setup
# set_page_config must be the first Streamlit rendering call on the page;
# it sets the browser tab title and icon.
st.set_page_config(page_title="Your Image to Audio Story", page_icon="🦜")
st.header("Turn Your Image to a Short Audio Story for Children")
# Model loading
@st.cache_resource
def load_models():
    """Build and cache the two inference pipelines used by this app.

    Returns a dict with:
        "img_model":   image-to-text captioning pipeline
        "story_model": text-generation (chat) pipeline
    Cached via st.cache_resource so the models load once per process.
    """
    captioner = pipeline("image-to-text", "cnmoro/tiny-image-captioning")
    storyteller = pipeline("text-generation", "Qwen/Qwen2.5-0.5B-Instruct")
    return {"img_model": captioner, "story_model": storyteller}

models = load_models()
# Processing functions
def img2text(url):
    """Return a caption string for the image file at *url*."""
    predictions = models["img_model"](url)
    return predictions[0]["generated_text"]
def text2story(text):
    """Generate a short (~100-word) children's story from a caption.

    Builds a chat-style prompt and runs the cached text-generation
    pipeline with sampling enabled.
    """
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": f"Generate a 100-word story about: {text}"},
    ]
    response = models["story_model"](
        messages,
        max_new_tokens=100,
        do_sample=True,
        temperature=0.7
    )[0]["generated_text"]
    # The pipeline returns the full conversation with the assistant's reply
    # appended; index the LAST message rather than hard-coding position 2,
    # so this survives any change to the number of prompt messages.
    return response[-1]["content"]
def text2audio(story_text):
    """Synthesize *story_text* to MP3 audio in an in-memory buffer.

    Returns a dict with the BytesIO buffer (rewound to position 0) under
    'audio' and a nominal 'sampling_rate'.
    """
    buffer = io.BytesIO()
    speech = gTTS(text=story_text, lang='en', slow=False)
    speech.write_to_fp(buffer)
    buffer.seek(0)
    # NOTE(review): gTTS emits MP3; the 16000 sampling_rate is nominal and
    # not consumed by st.audio below — confirm before relying on it.
    return {'audio': buffer, 'sampling_rate': 16000}
# Create fixed containers for UI elements
# st.empty() placeholders can be overwritten in place on later reruns,
# keeping each element at a stable position on the page.
image_container = st.empty()
timer_container = st.empty()
status_container = st.empty()
progress_container = st.empty()
results_container = st.container()
# Native Streamlit timer display
def display_timer():
    """Show elapsed time; keep triggering reruns while the timer is live.

    While the timer is running and not yet complete, sleeps briefly and
    forces a rerun so the display updates roughly every 0.1 s.
    """
    if st.session_state.timer_start is None or not st.session_state.timer_running:
        return
    elapsed_seconds = int(
        (datetime.datetime.now() - st.session_state.timer_start).total_seconds()
    )
    stamp = f"{elapsed_seconds // 60:02d}:{elapsed_seconds % 60:02d}"
    if st.session_state.timer_complete:
        # Final reading rendered in bold; no further reruns needed.
        timer_container.markdown(f"⏱️ **Elapsed: {stamp}**", unsafe_allow_html=True)
    else:
        timer_container.markdown(f"⏱️ Elapsed: {stamp}")
        time.sleep(0.1)  # small delay to reduce CPU usage
        st.experimental_rerun()
# UI components
uploaded_file = st.file_uploader("Select an Image After the Models are Loaded...")

# Re-render the persisted image on every rerun so it survives button
# clicks and timer-driven reruns.
if st.session_state.image_data is not None:
    image_container.image(
        st.session_state.image_data,
        caption="Uploaded Image",
        use_container_width=True,
    )

# Surface any previously generated caption and story.
with results_container:
    if st.session_state.processed_data.get('scenario'):
        st.write("**Caption:**", st.session_state.processed_data['scenario'])
    if st.session_state.processed_data.get('story'):
        st.write("**Story:**", st.session_state.processed_data['story'])

# Keep the timer ticking while a generation is in flight.
if st.session_state.timer_running:
    display_timer()
# Process new uploaded file
if uploaded_file is not None:
    # Persist the raw bytes so the image survives Streamlit reruns.
    bytes_data = uploaded_file.getvalue()
    st.session_state.image_data = bytes_data
    image_container.image(bytes_data, caption="Uploaded Image", use_container_width=True)

    # Run the pipeline only once per distinct file name.
    if st.session_state.get('current_file') != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name

        # Start timer
        st.session_state.timer_start = datetime.datetime.now()
        st.session_state.timer_running = True
        st.session_state.timer_complete = False

        # Progress indicators
        status_text = status_container.empty()
        progress_bar = progress_container.progress(0)

        try:
            # Strip any client-supplied directory components before writing
            # to disk — the upload name is untrusted input and writing it
            # verbatim would allow path traversal.
            safe_name = os.path.basename(uploaded_file.name)
            with open(safe_name, "wb") as file:
                file.write(bytes_data)

            # Stage 1: Image to Text
            status_text.markdown("**🖼️ Generating caption...**")
            st.session_state.processed_data['scenario'] = img2text(safe_name)
            progress_bar.progress(33)

            # Stage 2: Text to Story
            status_text.markdown("**📖 Generating story...**")
            st.session_state.processed_data['story'] = text2story(
                st.session_state.processed_data['scenario']
            )
            progress_bar.progress(66)

            # Stage 3: Story to Audio
            status_text.markdown("**🔊 Synthesizing audio...**")
            st.session_state.processed_data['audio'] = text2audio(
                st.session_state.processed_data['story']
            )
            progress_bar.progress(100)

            # Final status
            status_text.success("**✅ Generation complete!**")

            # Mark the timer finished so display_timer() stops rerunning.
            st.session_state.timer_complete = True

            # Show results
            with results_container:
                st.write("**Caption:**", st.session_state.processed_data['scenario'])
                st.write("**Story:**", st.session_state.processed_data['story'])

        except Exception as e:
            st.session_state.timer_complete = True
            status_text.error(f"**❌ Error:** {str(e)}")
            progress_bar.empty()
            # Bare raise preserves the original traceback (vs `raise e`).
            raise
# Audio playback
if st.button("Play Audio of the Story Generated"):
    audio_payload = st.session_state.processed_data.get('audio')
    if audio_payload:
        # The button click triggers a rerun; re-render the image so it
        # stays visible alongside the player.
        if st.session_state.image_data is not None:
            image_container.image(
                st.session_state.image_data,
                caption="Uploaded Image",
                use_container_width=True,
            )
        st.audio(audio_payload['audio'].getvalue(), format="audio/mp3")
    else:
        st.warning("Please generate a story first!")