Rausda6 commited on
Commit
7faf9f3
·
verified ·
1 Parent(s): caafdf1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +154 -151
app.py CHANGED
@@ -1,164 +1,167 @@
1
  import gradio as gr
2
- from pydub import AudioSegment
3
- import json
4
- import uuid
5
- import edge_tts
6
- import asyncio
7
- import aiofiles
8
- import os
9
  import time
10
- import mimetypes
11
- from typing import List, Dict
12
-
13
- # NEW – Hugging Face Transformers
14
  from transformers import AutoTokenizer, AutoModelForCausalLM
15
  import torch
16
 
17
- # NEW external model id
18
- MODEL_ID = "tabularisai/german-gemma-3-1b-it"
19
-
20
- # Constants
21
- MAX_FILE_SIZE_MB = 20
22
- MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024 # Convert MB to bytes
23
-
24
-
25
- class PodcastGenerator:
26
- def __init__(self):
27
- self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
28
- self.model = AutoModelForCausalLM.from_pretrained(
29
- MODEL_ID,
30
- torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
31
- device_map="auto",
32
- ).eval()
33
 
34
- async def generate_script(
35
- self,
36
- prompt: str,
37
- language: str,
38
- api_key: str,
39
- file_obj=None,
40
- progress=None,
41
- ) -> Dict:
42
- example = """
43
- {
44
- "topic": "AGI",
45
- "podcast": [
46
- {
47
- "speaker": 2,
48
- "line": "So, AGI, huh? Seems like everyone's talking about it these days."
49
- },
50
- {
51
- "speaker": 1,
52
- "line": "Yeah, it's definitely having a moment, isn't it?"
53
- }
54
- ]
55
- }
56
- """
57
 
58
- if language == "Auto Detect":
59
- language_instruction = (
60
- "- The podcast MUST be in the same language as the user input."
61
- )
62
  else:
63
- language_instruction = f"- The podcast MUST be in {language} language"
64
-
65
- system_prompt = f"""
66
- You are a professional podcast generator. Your task is to generate a professional podcast script based on the user input.
67
- {language_instruction}
68
- - The podcast should have 2 speakers.
69
- - The podcast should be long.
70
- - Do not use names for the speakers.
71
- - The podcast should be interesting, lively, and engaging, and hook the listener from the start.
72
- - The input text might be disorganized or unformatted, originating from sources like PDFs or text files. Ignore any formatting inconsistencies or irrelevant details; your task is to distill the essential points, identify key definitions, and highlight intriguing facts that would be suitable for discussion in a podcast.
73
- - The script must be in JSON format.
74
-
75
- Follow this example structure:
76
- {example}
77
- """
78
-
79
- if prompt and file_obj:
80
- user_prompt = (
81
- f"Please generate a podcast script based on the uploaded file following user input:\n{prompt}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
  )
83
- elif prompt:
84
- user_prompt = (
85
- f"Please generate a podcast script based on the following user input:\n{prompt}"
 
 
86
  )
87
- else:
88
- user_prompt = "Please generate a podcast script based on the uploaded file."
89
-
90
- # If a file is provided we still read it for completeness (not required for HF generation)
91
- if file_obj:
92
- _ = await self._read_file_bytes(file_obj)
93
-
94
- if progress:
95
- progress(0.3, "Generating podcast script...")
96
-
97
- inputs = self.tokenizer(
98
- f"{system_prompt}\n\n{user_prompt}", return_tensors="pt"
99
- ).to(self.model.device)
100
-
101
- try:
102
- output = self.model.generate(**inputs, max_new_tokens=2048, temperature=1.0)
103
- response_text = self.tokenizer.decode(output[0], skip_special_tokens=True)
104
- except Exception as e:
105
- raise Exception(f"Failed to generate podcast script: {e}")
106
-
107
- print(f"Generated podcast script:\n{response_text}")
108
-
109
- if progress:
110
- progress(0.4, "Script generated successfully!")
111
-
112
- return json.loads(response_text)
113
-
114
- async def _read_file_bytes(self, file_obj) -> bytes:
115
- if hasattr(file_obj, "size"):
116
- file_size = file_obj.size
117
- else:
118
- file_size = os.path.getsize(file_obj.name)
119
-
120
- if file_size > MAX_FILE_SIZE_BYTES:
121
- raise Exception(
122
- f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file."
123
  )
124
-
125
- if hasattr(file_obj, "read"):
126
- return file_obj.read()
127
- else:
128
- async with aiofiles.open(file_obj.name, "rb") as f:
129
- return await f.read()
130
-
131
- @staticmethod
132
- def _get_mime_type(filename: str) -> str:
133
- ext = os.path.splitext(filename)[1].lower()
134
- if ext == ".pdf":
135
- return "application/pdf"
136
- elif ext == ".txt":
137
- return "text/plain"
138
- else:
139
- mime_type, _ = mimetypes.guess_type(filename)
140
- return mime_type or "application/octet-stream"
141
-
142
-
143
- # Re-add UI definition for Gradio
144
- async def generate_interface(prompt, language, api_key, file):
145
- gen = PodcastGenerator()
146
- result = await gen.generate_script(prompt, language, api_key, file)
147
- return json.dumps(result, indent=2)
148
-
149
-
150
- interface = gr.Interface(
151
- fn=generate_interface,
152
- inputs=[
153
- gr.Textbox(label="Prompt"),
154
- gr.Radio(["English", "German", "Auto Detect"], label="Language", value="Auto Detect"),
155
- gr.Textbox(label="API Key", type="password"),
156
- gr.File(label="Upload File (optional)")
157
- ],
158
- outputs=gr.Textbox(label="Generated Podcast JSON"),
159
- title="Podcast Generator using Gemma",
160
- description="Generate a lively podcast script from your input text or uploaded file using the tabularisai/german-gemma-3-1b-it model."
161
- )
162
 
163
  if __name__ == "__main__":
164
- interface.launch()
 
1
  import gradio as gr
2
+ import random
 
 
 
 
 
 
3
  import time
4
+ import os
5
+ from elevenlabs import generate, set_api_key, save
6
+ from pathlib import Path
 
7
  from transformers import AutoTokenizer, AutoModelForCausalLM
8
  import torch
9
 
10
+ # Load model and tokenizer
11
+ tokenizer = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-alpha")
12
+ model = AutoModelForCausalLM.from_pretrained(
13
+ "HuggingFaceH4/zephyr-7b-alpha",
14
+ torch_dtype=torch.float16, # Use float16 for memory efficiency
15
+ device_map="auto" # Automatically determine device placement
16
+ )
 
 
 
 
 
 
 
 
 
17
 
18
+ api_key = os.getenv("ELEVENLABS_API_KEY")
19
+ set_api_key(api_key)
20
+ podcasts_directory = "podcasts"
21
+ os.makedirs(podcasts_directory, exist_ok=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ def progress_callback(progress):
24
+ if progress:
25
+ if isinstance(progress, int):
26
+ return progress
27
  else:
28
+ try:
29
+ return float(progress)
30
+ except (ValueError, TypeError):
31
+ return 0
32
+ return 0
33
+
34
+ def generate_podcast_intro(podcast_topic, structure, perspective, tone, existing_podcast_info):
35
+ with open("prompt_engineered.txt", "r", encoding='utf-8') as file:
36
+ prompt_template = file.read()
37
+
38
+ prompt = prompt_template.format(
39
+ podcast_topic=podcast_topic,
40
+ structure=structure,
41
+ perspective=perspective,
42
+ tone=tone,
43
+ existing_podcast_info=existing_podcast_info
44
+ )
45
+
46
+ return prompt
47
+
48
+ # Function to generate content
49
+ def generate_content(prompt):
50
+ # Format prompt for the Zephyr model (which follows ChatML format)
51
+ messages = [{"role": "user", "content": prompt}]
52
+
53
+ # Convert to model inputs
54
+ encoded_input = tokenizer.apply_chat_template(
55
+ messages,
56
+ return_tensors="pt"
57
+ ).to(model.device)
58
+
59
+ # Generate response
60
+ with torch.no_grad():
61
+ output = model.generate(
62
+ encoded_input,
63
+ max_new_tokens=1500, # Adjust based on desired output length
64
+ do_sample=True,
65
+ temperature=0.7, # Adjust for creativity vs determinism
66
+ top_p=0.95
67
+ )
68
+
69
+ # Decode and return only the new tokens (response)
70
+ response = tokenizer.decode(output[0][encoded_input.shape[1]:], skip_special_tokens=True)
71
+ return response
72
+
73
+ def generate_podcast_audio(podcast_script, voice, progress=gr.Progress()):
74
+ if not api_key:
75
+ return "Error: ElevenLabs API key not set. Please set the ELEVENLABS_API_KEY environment variable."
76
+
77
+ try:
78
+ audio = generate(
79
+ text=podcast_script,
80
+ voice=voice,
81
+ model="eleven_turbo_v2"
82
+ )
83
+
84
+ random_id = ''.join(random.choices('abcdefghijklmnopqrstuvwxyz0123456789', k=6))
85
+ filename = os.path.join(podcasts_directory, f"podcast_{random_id}.mp3")
86
+ save(audio, filename)
87
+ return filename
88
+ except Exception as e:
89
+ return f"Error generating audio: {str(e)}"
90
+
91
+ def create_podcast(podcast_topic, structure, perspective, tone, existing_podcast_info, voice_option, progress=gr.Progress()):
92
+ progress(0, desc="Generating podcast content...")
93
+
94
+ prompt = generate_podcast_intro(podcast_topic, structure, perspective, tone, existing_podcast_info)
95
+
96
+ progress(20, desc="Processing with AI...")
97
+ podcast_content = generate_content(prompt)
98
+
99
+ progress(60, desc="Generating audio...")
100
+ audio_file = generate_podcast_audio(podcast_content, voice_option, progress)
101
+
102
+ progress(100, desc="Complete!")
103
+ return podcast_content, audio_file
104
+
105
+ available_voices = [
106
+ "Adam", "Antoni", "Arnold", "Bella", "Callum", "Charlie", "Christina", "Clyde", "Daniel", "Dorothy",
107
+ "Ella", "Elli", "Emily", "Fin", "Freya", "Gigi", "Giovanni", "Glinda", "Grace", "Harry",
108
+ "James", "Jeremy", "Joseph", "Josh", "Knightley", "Liam", "Matilda", "Matthew", "Michael", "Nicole",
109
+ "Patrick", "Rachel", "Richard", "Sam", "Sarah", "Serena", "Thomas", "Victor", "Wayne", "Charlotte"
110
+ ]
111
+
112
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
113
+ gr.Markdown("# 🎙️ AI Podcast Generator")
114
+ gr.Markdown("Generate a complete podcast with AI, including audio narration.")
115
+
116
+ with gr.Row():
117
+ with gr.Column():
118
+ podcast_topic = gr.Textbox(
119
+ label="Podcast Topic",
120
+ placeholder="Enter the main topic of your podcast",
121
+ lines=2
122
  )
123
+
124
+ structure = gr.Radio(
125
+ ["Interview Style", "Solo Monologue", "Panel Discussion", "Storytelling", "Educational"],
126
+ label="Podcast Structure",
127
+ value="Interview Style"
128
  )
129
+
130
+ perspective = gr.Radio(
131
+ ["Balanced and Objective", "Personal Opinion", "Expert Analysis", "Conversational", "Investigative"],
132
+ label="Perspective",
133
+ value="Balanced and Objective"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
134
  )
135
+
136
+ tone = gr.Radio(
137
+ ["Professional", "Casual & Friendly", "Humorous", "Serious & Formal", "Inspirational"],
138
+ label="Tone",
139
+ value="Professional"
140
+ )
141
+
142
+ existing_podcast_info = gr.Textbox(
143
+ label="Additional Context (Optional)",
144
+ placeholder="Any additional information, context, or specific points you want to include",
145
+ lines=3
146
+ )
147
+
148
+ voice_option = gr.Dropdown(
149
+ choices=available_voices,
150
+ label="Voice for Audio",
151
+ value="Adam"
152
+ )
153
+
154
+ generate_btn = gr.Button("Generate Podcast", variant="primary")
155
+
156
+ with gr.Column():
157
+ podcast_output = gr.Textbox(label="Generated Podcast Script", lines=12)
158
+ audio_output = gr.Audio(label="Podcast Audio")
159
+
160
+ generate_btn.click(
161
+ create_podcast,
162
+ inputs=[podcast_topic, structure, perspective, tone, existing_podcast_info, voice_option],
163
+ outputs=[podcast_output, audio_output]
164
+ )
 
 
 
 
 
 
 
 
165
 
166
  if __name__ == "__main__":
167
+ demo.launch()