import gradio as gr
from pydub import AudioSegment
import json
import uuid
import edge_tts
import asyncio
import aiofiles
import os
import time
import mimetypes
from typing import List, Dict

# NEW – Hugging Face Transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# NEW – external model id
MODEL_ID = "tabularisai/german-gemma-3-1b-it"

# Constants
MAX_FILE_SIZE_MB = 20
MAX_FILE_SIZE_BYTES = MAX_FILE_SIZE_MB * 1024 * 1024  # Convert MB to bytes


class PodcastGenerator:
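    """Generate a two-speaker podcast script with a local Hugging Face causal LM."""
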
    def __init__(self):
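        # Load the tokenizer and model once; bfloat16 on GPU, float32 on CPU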
        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
        self.model = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
        ).eval()

    async def generate_script(
        self,
        prompt: str,
        language: str,
        api_key: str,  # not used by the local Hugging Face model
        file_obj=None,
        progress=None,
    ) -> Dict:
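        """Build the prompts, run the local LLM, and return the parsed JSON podcast script."""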
        example = """
{
  "topic": "AGI",
  "podcast": [
    {
      "speaker": 2,
      "line": "So, AGI, huh? Seems like everyone's talking about it these days."
    },
    {
      "speaker": 1,
      "line": "Yeah, it's definitely having a moment, isn't it?"
    }
  ]
}
"""

        if language == "Auto Detect":
            language_instruction = (
                "- The podcast MUST be in the same language as the user input."
            )
        else:
            language_instruction = f"- The podcast MUST be in the {language} language."

        system_prompt = f"""
You are a professional podcast generator. Your task is to generate a professional podcast script based on the user input.
{language_instruction}
- The podcast should have 2 speakers.
- The podcast should be long.
- Do not use names for the speakers.
- The podcast should be interesting, lively, and engaging, and hook the listener from the start.
- The input text might be disorganized or unformatted, originating from sources like PDFs or text files. Ignore any formatting inconsistencies or irrelevant details; your task is to distill the essential points, identify key definitions, and highlight intriguing facts that would be suitable for discussion in a podcast.
- The script must be in JSON format.

Follow this example structure:
{example}
"""

        # Build the user prompt from the free-text input and/or the uploaded file
        if prompt and file_obj:
            user_prompt = (
                f"Please generate a podcast script based on the uploaded file and the following user input:\n{prompt}"
            )
        elif prompt:
            user_prompt = (
                f"Please generate a podcast script based on the following user input:\n{prompt}"
            )
        else:
            user_prompt = "Please generate a podcast script based on the uploaded file."

        # If a file is provided we still read it for completeness (not required for HF generation)
        if file_obj:
            _ = await self._read_file_bytes(file_obj)

        if progress:
            progress(0.3, "Generating podcast script...")

        # Tokenize the combined system and user prompts and move them to the model device
        inputs = self.tokenizer(
            f"{system_prompt}\n\n{user_prompt}", return_tensors="pt"
        ).to(self.model.device)

        try:
            output = self.model.generate(
                **inputs, max_new_tokens=2048, do_sample=True, temperature=1.0
            )
            # Decode only the newly generated tokens; decoding the full sequence would
            # prepend the prompt and break the json.loads below
            new_tokens = output[0][inputs["input_ids"].shape[-1]:]
            response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)
        except Exception as e:
            raise Exception(f"Failed to generate podcast script: {e}") from e

        print(f"Generated podcast script:\n{response_text}")

        if progress:
            progress(0.4, "Script generated successfully!")

        # The model is expected to emit a bare JSON object matching the example above
        return json.loads(response_text)

    async def _read_file_bytes(self, file_obj) -> bytes:
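        """Read the uploaded file's bytes, enforcing the size limit first."""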
        if hasattr(file_obj, "size"):
            file_size = file_obj.size
        else:
            file_size = os.path.getsize(file_obj.name)

        if file_size > MAX_FILE_SIZE_BYTES:
            raise Exception(
                f"File size exceeds the {MAX_FILE_SIZE_MB}MB limit. Please upload a smaller file."
            )

        if hasattr(file_obj, "read"):
            return file_obj.read()
        else:
            async with aiofiles.open(file_obj.name, "rb") as f:
                return await f.read()

    @staticmethod
    def _get_mime_type(filename: str) -> str:
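        """Return the MIME type for a filename, defaulting to application/octet-stream."""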
        ext = os.path.splitext(filename)[1].lower()
        if ext == ".pdf":
            return "application/pdf"
        elif ext == ".txt":
            return "text/plain"
        else:
            mime_type, _ = mimetypes.guess_type(filename)
            return mime_type or "application/octet-stream"
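

# Minimal usage sketch (illustrative only; this class is presumably driven by the
# Gradio app elsewhere, and the prompt below is a made-up example):
#
#     generator = PodcastGenerator()
#     script = asyncio.run(
#         generator.generate_script(
#             prompt="Explain AGI to a general audience",
#             language="English",
#             api_key="",  # unused by the local Hugging Face model
#         )
#     )
#     print(script["topic"], len(script["podcast"]))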