# Source: Hugging Face Space file "app.py" by frankai98
# Commit 89a5c7c (verified), message "Update app.py", size 9.17 kB
# nest_asyncio is applied first, before the remaining imports run.
import nest_asyncio
nest_asyncio.apply()

import asyncio
import io
import os
import tempfile
import time

import streamlit as st
import torch
from gtts import gTTS
from streamlit.components.v1 import html
from transformers import pipeline
# Install a fresh event loop when the current one is not running
_current_loop = asyncio.get_event_loop()
if not _current_loop.is_running():
    asyncio.set_event_loop(asyncio.new_event_loop())

# Initialize session state (each key is seeded only when it is missing)
if 'processed_data' not in st.session_state:
    st.session_state['processed_data'] = dict(scenario=None, story=None, audio=None)
if 'image_data' not in st.session_state:
    st.session_state['image_data'] = None
if 'timer_started' not in st.session_state:
    st.session_state['timer_started'] = False

# Page setup
st.set_page_config(
    page_title="Your Image to Audio Story",
    page_icon="🦜",
)
st.header("Turn Your Image to a Short Audio Story for Children")
# Model loading
@st.cache_resource
def load_models():
    """Create the two Hugging Face pipelines used by the app.

    Returns:
        dict: ``"img_model"`` maps to the image-to-text pipeline and
        ``"story_model"`` to the text-generation pipeline.
    """
    captioner = pipeline("image-to-text", "cnmoro/tiny-image-captioning")
    storyteller = pipeline("text-generation", "Qwen/Qwen2.5-0.5B-Instruct")
    return {"img_model": captioner, "story_model": storyteller}


models = load_models()
# Processing functions
def img2text(url):
    """Return the caption generated for the image at ``url``.

    ``url`` is handed unchanged to the image-to-text pipeline; the text of
    the first result is returned.
    """
    captions = models["img_model"](url)
    first_caption = captions[0]
    return first_caption["generated_text"]
def _trim_to_last_sentence(story):
    """Drop a trailing fragment left when generation stops mid-sentence."""
    story = story.strip()
    last_stop = max(story.rfind(mark) for mark in ".!?")
    if last_stop == -1:
        # No sentence terminator at all: nothing sensible to cut back to.
        return story
    end = last_stop + 1
    # Keep a closing quote that directly follows the final punctuation mark.
    if end < len(story) and story[end] in "\"'”’":
        end += 1
    return story[:end]


def text2story(text):
    """Generate a short story about ``text`` with the chat model.

    Args:
        text: Scenario (for example an image caption) the story is about.

    Returns:
        str: The assistant's reply, cut back to its last complete sentence.
    """
    prompt = f"Generate a 100-word story about: {text}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt}
    ]
    conversation = models["story_model"](
        messages,
        # A 100-word English story needs more than 100 tokens; the previous
        # budget of 100 routinely stopped the story mid-sentence.
        max_new_tokens=200,
        do_sample=True,
        temperature=0.7
    )[0]["generated_text"]
    # The pipeline returns the input messages followed by the new assistant
    # turn, so the reply is the last entry wherever the prompt list ends.
    reply = conversation[-1]["content"]
    return _trim_to_last_sentence(reply)
def text2audio(story_text):
    """Synthesize ``story_text`` with gTTS into an in-memory buffer.

    Returns:
        dict: ``'audio'`` holds a ``BytesIO`` rewound to position 0 and
        ``'sampling_rate'`` holds the constant 16000.
    """
    speech = gTTS(text=story_text, lang='en', slow=False)
    buffer = io.BytesIO()
    speech.write_to_fp(buffer)
    # Rewind so readers start at the beginning of the written audio
    buffer.seek(0)
    return dict(audio=buffer, sampling_rate=16000)
# Create fixed containers for UI elements (created in on-page order)
(
    image_container,
    timer_container,
    status_container,
    progress_container,
) = (st.empty() for _ in range(4))
results_container = st.container()

# Display initial timer placeholder (empty timer)
initial_timer_markup = (
    '<div id="timer-display" style="font-size:16px;color:#666;margin-bottom:10px;">⏱️ Elapsed: 00:00</div>'
)
timer_container.markdown(initial_timer_markup, unsafe_allow_html=True)
# JavaScript timer functions
#
# html() renders each snippet inside its own iframe. The "timer-display"
# element is created by st.markdown in the parent page, so the scripts must
# look it up in the parent document. Each iframe also has its own `window`,
# so one snippet cannot clear an interval created by another; the snippets
# coordinate through the 'timerRunning' flag in localStorage instead.
def start_timer():
    """Inject the script that keeps the "timer-display" element ticking.

    The script does nothing while the timer is frozen (flag 'false'). It
    reuses a stored start time when one exists, so calling this again on a
    rerun continues the same count instead of restarting from zero.
    """
    timer_html = """
    <script>
    // This script runs inside an iframe; the timer element is in the parent page
    var hostDocument = (window.parent || window).document;
    // Only create a timer if this frame has none and the timer is not frozen
    if (!window.timerInterval && localStorage.getItem('timerRunning') !== 'false') {
        console.log("Starting timer");
        // Reuse the stored start time so the count persists across reruns
        var startTime = parseInt(localStorage.getItem('timerStartTime'), 10);
        if (isNaN(startTime)) {
            startTime = new Date().getTime();
            localStorage.setItem('timerStartTime', startTime);
        }
        // Flag that timer is running
        localStorage.setItem('timerRunning', 'true');
        window.timerInterval = setInterval(function() {
            // Freeze/reset scripts live in other iframes; they stop this
            // interval by changing or removing the flag.
            if (localStorage.getItem('timerRunning') !== 'true') {
                clearInterval(window.timerInterval);
                window.timerInterval = null;
                return;
            }
            var now = new Date().getTime();
            var startTimeFromStorage = parseInt(localStorage.getItem('timerStartTime') || startTime, 10);
            var elapsed = now - startTimeFromStorage;
            var minutes = Math.floor(elapsed / (1000 * 60));
            var seconds = Math.floor((elapsed % (1000 * 60)) / 1000);
            var formattedTime =
                (minutes < 10 ? "0" : "") + minutes + ":" +
                (seconds < 10 ? "0" : "") + seconds;
            var timerElement = hostDocument.getElementById("timer-display");
            if (timerElement) {
                timerElement.innerHTML = "⏱️ Elapsed: " + formattedTime;
            }
            // Store current display for potential freezing
            localStorage.setItem('timerCurrentDisplay', formattedTime);
        }, 100);
    }
    </script>
    """
    html(timer_html, height=0)


def freeze_timer():
    """Inject the script that stops the timer and shows the last value in green."""
    freeze_html = """
    <script>
    console.log("Freezing timer");
    // This script runs inside an iframe; the timer element is in the parent page
    var hostDocument = (window.parent || window).document;
    // Clear the interval if this frame owns one
    if (window.timerInterval) {
        clearInterval(window.timerInterval);
        window.timerInterval = null;
    }
    // Set flag that timer is frozen; a ticking frame stops on its next tick
    localStorage.setItem('timerRunning', 'false');
    // Get the last timer display
    var lastTimerDisplay = localStorage.getItem('timerCurrentDisplay') || "00:00";
    // Update the timer display with frozen styling
    var timerElement = hostDocument.getElementById("timer-display");
    if (timerElement) {
        timerElement.style.color = "#00cc00";
        timerElement.style.fontWeight = "bold";
        timerElement.innerHTML = "⏱️ Elapsed: " + lastTimerDisplay + " ✓";
    }
    </script>
    """
    html(freeze_html, height=0)


def reset_timer():
    """Inject the script that clears all timer state and shows 00:00 again."""
    reset_html = """
    <script>
    console.log("Resetting timer");
    // This script runs inside an iframe; the timer element is in the parent page
    var hostDocument = (window.parent || window).document;
    // Clear any interval this frame owns
    if (window.timerInterval) {
        clearInterval(window.timerInterval);
        window.timerInterval = null;
    }
    // Clear localStorage timer data; removing the flag also stops any
    // interval still ticking in another frame
    localStorage.removeItem('timerStartTime');
    localStorage.removeItem('timerCurrentDisplay');
    localStorage.removeItem('timerRunning');
    // Reset the display
    var timerElement = hostDocument.getElementById("timer-display");
    if (timerElement) {
        timerElement.style.color = "#666";
        timerElement.style.fontWeight = "normal";
        timerElement.innerHTML = "⏱️ Elapsed: 00:00";
    }
    </script>
    """
    html(reset_html, height=0)
# Always display the image if we have image data
if st.session_state.image_data is not None:
    image_container.image(
        st.session_state.image_data,
        caption="Uploaded Image",
        use_container_width=True,
    )

# UI components
uploaded_file = st.file_uploader("Select an Image After the Models are Loaded...")

# Process new uploaded file
if uploaded_file is not None:
    # Save the image data to session state so it is still shown on reruns
    bytes_data = uploaded_file.getvalue()
    st.session_state.image_data = bytes_data

    # Display the image
    image_container.image(bytes_data, caption="Uploaded Image", use_container_width=True)

    if st.session_state.get('current_file') != uploaded_file.name:
        st.session_state.current_file = uploaded_file.name
        # Drop the previous file's outputs so a failed run cannot leave an
        # old caption/story/audio paired with the new image.
        st.session_state.processed_data = {
            'scenario': None,
            'story': None,
            'audio': None
        }

        # Reset and start timer ONLY when a new file is uploaded
        reset_timer()
        start_timer()
        st.session_state.timer_started = True

        # Progress indicators
        status_text = status_container.empty()
        progress_bar = progress_container.progress(0)

        temp_path = None
        try:
            # Write the upload to a private temporary file. The uploader's
            # file name is chosen by the client, so writing it into the
            # working directory could overwrite files of the app itself.
            suffix = os.path.splitext(uploaded_file.name)[1]
            with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
                temp_path = temp_file.name
                temp_file.write(bytes_data)

            # Stage 1: Image to Text
            status_text.markdown("**🖼️ Generating caption...**")
            st.session_state.processed_data['scenario'] = img2text(temp_path)
            progress_bar.progress(33)

            # Stage 2: Text to Story
            status_text.markdown("**📖 Generating story...**")
            st.session_state.processed_data['story'] = text2story(
                st.session_state.processed_data['scenario']
            )
            progress_bar.progress(66)

            # Stage 3: Story to Audio
            status_text.markdown("**🔊 Synthesizing audio...**")
            st.session_state.processed_data['audio'] = text2audio(
                st.session_state.processed_data['story']
            )
            progress_bar.progress(100)

            # Final status
            status_text.success("**✅ Generation complete!**")
        except Exception as e:
            status_text.error(f"**❌ Error:** {str(e)}")
            progress_bar.empty()
            # Bare raise keeps the original traceback
            raise
        finally:
            # The temporary copy is only needed while the caption is made
            if temp_path is not None and os.path.exists(temp_path):
                os.remove(temp_path)
    # Same file on a rerun: don't reset, just make sure the timer keeps running
    elif st.session_state.timer_started:
        start_timer()

# Display results if available (the only place they are rendered, so a
# freshly processed file does not show its caption and story twice)
with results_container:
    if st.session_state.processed_data.get('scenario'):
        st.write("**Caption:**", st.session_state.processed_data['scenario'])
    if st.session_state.processed_data.get('story'):
        st.write("**Story:**", st.session_state.processed_data['story'])
# Audio playback
if st.button("Play Audio of the Story Generated"):
    audio_data = st.session_state.processed_data.get('audio')
    if not audio_data:
        # Nothing has been synthesized yet
        st.warning("Please generate a story first!")
    else:
        # Make sure the image is still displayed
        current_image = st.session_state.image_data
        if current_image is not None:
            image_container.image(
                current_image,
                caption="Uploaded Image",
                use_container_width=True,
            )

        # Freeze the timer if it was started
        if st.session_state.timer_started:
            freeze_timer()

        # Play the audio
        mp3_bytes = audio_data['audio'].getvalue()
        st.audio(mp3_bytes, format="audio/mp3")