LTH001's picture
Update app.py
63497b6 verified
raw
history blame
4.14 kB
# import part
import streamlit as st
from transformers import pipeline
from PIL import Image
import io
import numpy as np
import soundfile as sf # For handling audio file operations
# function part
@st.cache_resource(show_spinner=False)
def _load_caption_pipeline():
    """Build the BLIP image-to-text pipeline once and reuse it across reruns."""
    return pipeline("image-to-text", model="Salesforce/blip-image-captioning-base")


def generate_image_caption(image):
    """Generate a caption for the given image using a pre-trained model.

    Args:
        image: PIL Image object.

    Returns:
        str: Generated caption text.
    """
    # Streamlit re-executes the whole script on every widget interaction;
    # going through the cached loader avoids re-instantiating the model each time.
    img2caption = _load_caption_pipeline()
    result = img2caption(image)
    return result[0]["generated_text"]
@st.cache_resource(show_spinner=False)
def _load_story_pipeline():
    """Build the story text-generation pipeline once and reuse it across reruns."""
    return pipeline("text-generation", model="pranavpsv/genre-story-generator-v2")


def text2story(text):
    """Generate a children's story from a text prompt using a story-generation model.

    Args:
        text: Input text prompt (in this app, the image caption).

    Returns:
        str: Generated story text, without the prompt that was sent to the model.
    """
    # Craft prompt with specific requirements for children's stories
    story_prompt = f"Create a funny 100-word story for 8-year-olds about: {text}. Include: "
    story_prompt += "1) A silly character 2) Magical object 3) Sound effects 4) Happy ending"
    pipe = _load_story_pipeline()
    outputs = pipe(
        story_prompt,
        max_new_tokens=200,  # limit story length
        do_sample=True,  # temperature/top_k are only honoured when sampling
        temperature=0.9,  # higher = more creative
        top_k=50,  # sample only among the 50 most likely next tokens
        # Return only the continuation, so the prompt never has to be cut out of
        # the text afterwards (splitting on a phrase such as "Happy ending" also
        # drops story text whenever the model repeats that phrase).
        return_full_text=False,
    )
    return outputs[0]["generated_text"].strip()
@st.cache_resource(show_spinner=False)
def _load_tts_pipeline():
    """Build the Bark text-to-speech pipeline once and reuse it across reruns."""
    return pipeline("text-to-speech", model="suno/bark-small")


def story_to_speech(story_text):
    """Convert story text to audio using a text-to-speech model.

    Args:
        story_text: Story text to convert. This function does not truncate it;
            the caller is responsible for keeping it short.

    Returns:
        BytesIO: WAV-encoded audio, rewound to offset 0 for the Streamlit player.
    """
    tts_pipe = _load_tts_pipeline()
    # This pipeline only accepts generation options through `forward_params` /
    # `generate_kwargs`; an unrecognised top-level keyword such as `max_length=`
    # makes the call fail, so the text is passed on its own.
    audio_output = tts_pipe(story_text)

    waveform = np.asarray(audio_output["audio"])
    # soundfile expects (frames, channels). A 2-D waveform whose first axis is
    # the shorter one is treated as channels-first, e.g. (1, n) -> (n, 1).
    # NOTE(review): assumes the pipeline returns channels-first audio — confirm
    # against the installed transformers version.
    if waveform.ndim == 2 and waveform.shape[0] < waveform.shape[1]:
        waveform = waveform.T

    audio_bytes = io.BytesIO()
    sf.write(audio_bytes, waveform, audio_output["sampling_rate"], format="WAV")
    audio_bytes.seek(0)  # rewind so the reader starts at the WAV header
    return audio_bytes
def main():
    """Run the Streamlit workflow: upload an image, caption it, write a story, read it aloud."""
    import hashlib  # only needed here, to identify the current upload

    # Configure page header
    st.title("📖 Image Story Generator with Audio")
    st.write("Upload an image to get a magical story read aloud!")
    # Image upload widget (supports JPG/PNG)
    uploaded_image = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if not uploaded_image:
        return

    image = Image.open(uploaded_image).convert("RGB")  # ensure RGB format
    # NOTE(review): newer Streamlit releases deprecate `use_column_width` in favour
    # of `use_container_width`; kept for compatibility with older versions.
    st.image(image, use_column_width=True)

    # Streamlit re-runs this whole script on every interaction, including the
    # "Read Story Aloud" click below. Remember the caption and story for the
    # current upload so they are not regenerated (and the story replaced) then.
    state = st.session_state
    upload_key = hashlib.sha256(uploaded_image.getvalue()).hexdigest()
    if state.get("story_upload_key") != upload_key:
        state["story_upload_key"] = upload_key
        state.pop("story_caption", None)
        state.pop("story_text", None)

    # Image analysis section
    if "story_caption" not in state:
        with st.spinner("✨ Analyzing image..."):
            state["story_caption"] = generate_image_caption(image)
    caption = state["story_caption"]
    st.subheader("Image Understanding")
    st.write(caption)

    # Story generation section
    if "story_text" not in state:
        with st.spinner("📖 Writing story..."):
            state["story_text"] = text2story(caption)
    story = state["story_text"]
    st.subheader("Magical Story")
    st.write(story)

    # Audio generation section
    if st.button("🎧 Read Story Aloud"):
        with st.spinner("🔊 Generating audio..."):
            try:
                # Trim to 400 characters to keep the TTS input short
                audio_bytes = story_to_speech(story[:400])
                st.audio(audio_bytes, format="audio/wav")
            except Exception as e:
                # UI boundary: show any model/audio failure instead of crashing the app
                st.error(f"Error generating audio: {str(e)}")
if __name__ == "__main__":
    # Launch the app when this file is run as a script (e.g. via `streamlit run`).
    main()