Spaces:
Sleeping
Sleeping
import nest_asyncio | |
nest_asyncio.apply() | |
import streamlit as st | |
from transformers import pipeline | |
import torch | |
from gtts import gTTS | |
import io | |
import time | |
import asyncio | |
from streamlit.components.v1 import html | |
# Make sure the thread executing this script has an asyncio event loop.
# asyncio.get_event_loop() raises RuntimeError in a thread that has no loop
# set (and, on newer Python versions, even in the main thread), so it cannot
# safely be used to *test* for a loop. get_running_loop() only reports a loop
# that is actually running and raises RuntimeError otherwise; in that case a
# fresh loop is installed for this thread, as the original guard intended.
try:
    asyncio.get_running_loop()
except RuntimeError:
    asyncio.set_event_loop(asyncio.new_event_loop())
# Initialize session state: each key is created only on the first run of a
# session; later reruns keep whatever value is already stored.
st.session_state.setdefault(
    'processed_data',
    {
        'scenario': None,
        'story': None,
        'audio': None,
    },
)
st.session_state.setdefault('image_data', None)
st.session_state.setdefault('timer_started', False)
# Page setup: browser tab title and icon, then the visible page heading.
st.set_page_config(
    page_title="Your Image to Audio Story",
    page_icon="🦜",
)
st.header("Turn Your Image to a Short Audio Story for Children")
# Model loading
@st.cache_resource
def load_models():
    """Build the two Hugging Face pipelines used by the app.

    Streamlit re-executes this whole script on every widget interaction.
    Without caching, both pipelines (and their weights) would be rebuilt on
    each rerun. ``st.cache_resource`` creates them once and hands the same
    objects to every later rerun and session.

    Returns:
        dict with ``"img_model"`` (an ``image-to-text`` pipeline) and
        ``"story_model"`` (a ``text-generation`` pipeline).
    """
    return {
        "img_model": pipeline("image-to-text", "cnmoro/tiny-image-captioning"),
        "story_model": pipeline("text-generation", "Qwen/Qwen2.5-0.5B-Instruct"),
    }


models = load_models()
# Processing functions
def img2text(url):
    """Caption an image with the cached image-to-text pipeline.

    Args:
        url: Any image reference the pipeline accepts; this script passes
            a path to a file on disk.

    Returns:
        The ``generated_text`` of the pipeline's first result.
    """
    captions = models["img_model"](url)
    first_caption = captions[0]
    return first_caption["generated_text"]
def text2story(text, *, max_new_tokens=160, temperature=0.7):
    """Generate a short story about ``text`` with the chat text-generation model.

    Args:
        text: Seed description to write about (in this app, the image caption).
        max_new_tokens: Generation budget. The prompt asks for ~100 words, and
            100 tokens typically covers fewer than 100 English words. The former
            hard-coded limit of 100 therefore tended to cut the story off
            mid-sentence.
        temperature: Sampling temperature (sampling is always enabled).

    Returns:
        The assistant's reply, stripped of surrounding whitespace.
    """
    prompt = f"Generate a 100-word story about: {text}"
    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": prompt},
    ]
    conversation = models["story_model"](
        messages,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
    )[0]["generated_text"]
    # For chat-style input the pipeline returns the conversation with the
    # model's reply appended. Take the last message rather than a hard-coded
    # index, so the lookup stays correct if the prompt gains or loses messages.
    return conversation[-1]["content"].strip()
def text2audio(story_text):
    """Synthesize ``story_text`` as English speech with gTTS.

    Args:
        story_text: The text to read aloud.

    Returns:
        dict with:
          - ``'audio'``: an ``io.BytesIO`` holding the gTTS MP3 output, rewound
            to position 0 so it can be read immediately;
          - ``'sampling_rate'``: kept for backward compatibility. NOTE(review):
            this is a fixed placeholder (16000), not a value read from the
            generated audio; the playback code shown only uses ``'audio'``.

    Raises:
        ValueError: if ``story_text`` is None, empty, or whitespace-only.
            Failing here gives a clear message instead of gTTS's own error
            for empty input.
    """
    if not story_text or not story_text.strip():
        raise ValueError("Cannot synthesize audio: the story text is empty.")
    audio_io = io.BytesIO()
    tts = gTTS(text=story_text, lang='en', slow=False)
    tts.write_to_fp(audio_io)
    audio_io.seek(0)
    return {'audio': audio_io, 'sampling_rate': 16000}
# Reserve fixed slots for the UI, top to bottom: image, timer, status line and
# progress bar (single-element placeholders), then a multi-element container
# for the generated results.
image_container, timer_container, status_container, progress_container = (
    st.empty(),
    st.empty(),
    st.empty(),
    st.empty(),
)
results_container = st.container()
# Draw the timer in its initial "00:00" state. The timer scripts look up this
# element by its id ("timer-display") and rewrite its text and styling.
timer_container.markdown(
    '<div id="timer-display" '
    'style="font-size:16px;color:#666;margin-bottom:10px;">'
    '⏱️ Elapsed: 00:00</div>',
    unsafe_allow_html=True,
)
# JavaScript timer functions | |
def start_timer(): | |
timer_html = """ | |
<script> | |
// Only create a timer if one doesn't exist already | |
if (!window.timerInterval) { | |
console.log("Starting timer"); | |
var startTime = new Date().getTime(); | |
// Store the start time in localStorage to persist across reruns | |
localStorage.setItem('timerStartTime', startTime); | |
window.timerInterval = setInterval(function() { | |
var now = new Date().getTime(); | |
var startTimeFromStorage = parseInt(localStorage.getItem('timerStartTime') || startTime); | |
var elapsed = now - startTimeFromStorage; | |
var minutes = Math.floor(elapsed / (1000 * 60)); | |
var seconds = Math.floor((elapsed % (1000 * 60)) / 1000); | |
var formattedTime = | |
(minutes < 10 ? "0" : "") + minutes + ":" + | |
(seconds < 10 ? "0" : "") + seconds; | |
var timerElement = document.getElementById("timer-display"); | |
if (timerElement) { | |
timerElement.innerHTML = "⏱️ Elapsed: " + formattedTime; | |
} | |
// Store current display for potential freezing | |
localStorage.setItem('timerCurrentDisplay', formattedTime); | |
}, 100); | |
// Flag that timer is running | |
localStorage.setItem('timerRunning', 'true'); | |
} | |
</script> | |
""" | |
html(timer_html, height=0) | |
def freeze_timer(): | |
freeze_html = """ | |
<script> | |
console.log("Freezing timer"); | |
// Clear the interval if it exists | |
if (window.timerInterval) { | |
clearInterval(window.timerInterval); | |
window.timerInterval = null; | |
} | |
// Get the last timer display | |
var lastTimerDisplay = localStorage.getItem('timerCurrentDisplay') || "00:00"; | |
// Update the timer display with frozen styling | |
var timerElement = document.getElementById("timer-display"); | |
if (timerElement) { | |
timerElement.style.color = "#00cc00"; | |
timerElement.style.fontWeight = "bold"; | |
timerElement.innerHTML = "⏱️ Elapsed: " + lastTimerDisplay + " ✓"; | |
} | |
// Set flag that timer is frozen | |
localStorage.setItem('timerRunning', 'false'); | |
</script> | |
""" | |
html(freeze_html, height=0) | |
def reset_timer(): | |
reset_html = """ | |
<script> | |
console.log("Resetting timer"); | |
// Clear any existing interval | |
if (window.timerInterval) { | |
clearInterval(window.timerInterval); | |
window.timerInterval = null; | |
} | |
// Clear localStorage timer data | |
localStorage.removeItem('timerStartTime'); | |
localStorage.removeItem('timerCurrentDisplay'); | |
localStorage.removeItem('timerRunning'); | |
// Reset the display | |
var timerElement = document.getElementById("timer-display"); | |
if (timerElement) { | |
timerElement.style.color = "#666"; | |
timerElement.style.fontWeight = "normal"; | |
timerElement.innerHTML = "⏱️ Elapsed: 00:00"; | |
} | |
</script> | |
""" | |
html(reset_html, height=0) | |
# Re-show the last uploaded image (kept in session state) on every rerun.
if st.session_state.image_data is not None:
    image_container.image(
        st.session_state.image_data,
        use_container_width=True,
        caption="Uploaded Image",
    )
# UI components
# Restrict the picker to common raster image formats. The upload is passed
# straight to an image-captioning model, so other file types would most
# likely just fail later in the pipeline with a less helpful error.
uploaded_file = st.file_uploader(
    "Select an Image After the Models are Loaded...",
    type=["png", "jpg", "jpeg", "webp", "bmp"],
)
# Process new uploaded file
if uploaded_file is not None:
    # Keep the raw bytes in session state so the image survives reruns.
    bytes_data = uploaded_file.getvalue()
    st.session_state.image_data = bytes_data
    # Display the image
    image_container.image(bytes_data, caption="Uploaded Image", use_container_width=True)
    # Run the pipeline only once per distinct upload (keyed by file name).
    if st.session_state.get('current_file') != uploaded_file.name:
        import os
        import tempfile

        st.session_state.current_file = uploaded_file.name
        # Drop results belonging to the previous image, so a failure below
        # cannot leave a stale caption/story/audio next to the new image.
        st.session_state.processed_data = {
            'scenario': None,
            'story': None,
            'audio': None,
        }
        # Reset and start timer
        reset_timer()
        start_timer()  # Start the timer ONLY when a new file is uploaded
        st.session_state.timer_started = True
        # Progress indicators
        status_text = status_container.empty()
        progress_bar = progress_container.progress(0)
        tmp_path = None
        try:
            # img2text is given a file path. Write the upload to a private
            # temporary file instead of the client-supplied name in the working
            # directory (which could collide between sessions or point outside
            # it), and remove it afterwards.
            with tempfile.NamedTemporaryFile(delete=False) as tmp:
                tmp_path = tmp.name
                tmp.write(bytes_data)
            # Stage 1: Image to Text
            status_text.markdown("**🖼️ Generating caption...**")
            st.session_state.processed_data['scenario'] = img2text(tmp_path)
            progress_bar.progress(33)
            # Stage 2: Text to Story
            status_text.markdown("**📖 Generating story...**")
            st.session_state.processed_data['story'] = text2story(
                st.session_state.processed_data['scenario']
            )
            progress_bar.progress(66)
            # Stage 3: Story to Audio
            status_text.markdown("**🔊 Synthesizing audio...**")
            st.session_state.processed_data['audio'] = text2audio(
                st.session_state.processed_data['story']
            )
            progress_bar.progress(100)
            # Final status. Caption and story are rendered once by the
            # "Display results" section further down; writing them here as
            # well showed them twice on this run.
            status_text.success("**✅ Generation complete!**")
        except Exception as e:
            status_text.error(f"**❌ Error:** {str(e)}")
            progress_bar.empty()
            raise
        finally:
            if tmp_path is not None:
                try:
                    os.remove(tmp_path)
                except OSError:
                    # Best effort: a leftover temp file must not mask the
                    # real result or error of this run.
                    pass
# If we have a previously started timer but a new page load, restart it
elif st.session_state.timer_started:
    # Don't reset, just restart the timer to continue from where it was
    start_timer()
# Show whatever results are stored: caption first, then story, each only
# when it has a truthy value.
with results_container:
    if st.session_state.processed_data.get('scenario'):
        st.write("**Caption:**", st.session_state.processed_data['scenario'])
    if st.session_state.processed_data.get('story'):
        st.write("**Story:**", st.session_state.processed_data['story'])
# Audio playback: when the button is pressed, play the stored story audio,
# or ask the user to generate a story first if there is none.
if st.button("Play Audio of the Story Generated"):
    if not st.session_state.processed_data.get('audio'):
        st.warning("Please generate a story first!")
    else:
        # Keep the uploaded image visible on this rerun.
        if st.session_state.image_data is not None:
            image_container.image(
                st.session_state.image_data,
                caption="Uploaded Image",
                use_container_width=True,
            )
        # Stop the elapsed-time display once playback is requested.
        if st.session_state.timer_started:
            freeze_timer()
        audio_data = st.session_state.processed_data['audio']
        st.audio(audio_data['audio'].getvalue(), format="audio/mp3")