File size: 6,712 Bytes
9fed6cd
4f6e6aa
7ea5fab
aed3f20
66d3d20
abd2b7a
224fa84
092b591
 
aed3f20
494ef37
092b591
d5c8eb9
092b591
 
494ef37
d5c8eb9
092b591
 
 
 
 
78badb3
f0dd428
092b591
224fa84
092b591
224fa84
d3c7ce9
092b591
 
 
 
 
 
224fa84
f0dd428
 
98cb312
 
 
 
 
 
5f59609
98cb312
 
 
 
 
 
f0dd428
494ef37
aed3f20
 
 
9fed6cd
4f6e6aa
494ef37
 
7ea5fab
016850e
 
4f6e6aa
 
9fed6cd
aed3f20
092b591
aed3f20
4f6e6aa
 
 
 
 
 
 
 
 
 
 
 
5db4b55
ed41d7d
 
 
 
 
092b591
 
 
 
 
 
 
 
 
 
 
632facd
 
 
 
 
 
 
d351d70
 
 
 
632facd
 
 
d351d70
632facd
 
 
d351d70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
632facd
 
 
 
092b591
 
 
 
 
 
 
 
ed41d7d
 
 
 
 
494ef37
ed41d7d
12d60b3
ed41d7d
494ef37
ed41d7d
 
aed3f20
ed41d7d
f0dd428
ed41d7d
 
 
d27a194
 
f0dd428
092b591
a3a5713
d85b3ab
092b591
9fed6cd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
import gradio as gr
from llama_cpp import Llama
import os
from groq import Groq
import numpy as np
import wave
import uuid
from nemoguardrails import LLMRails, RailsConfig
from GoogleTTS import GoogleTTS

#tts
#import torchaudio
#from speechbrain.inference.TTS import FastSpeech2
# from speechbrain.inference.TTS import Tacotron2
# from speechbrain.inference.vocoders import HIFIGAN

#fastspeech2 = FastSpeech2.from_hparams(source="speechbrain/tts-fastspeech2-ljspeech", savedir="pretrained_models/tts-fastspeech2-ljspeech")
# tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir="tmpdir_tts")
# hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir="pretrained_models/tts-hifigan-ljspeech")

#google tts
tts = GoogleTTS()

def text_to_speech(text):
    # mel_output, mel_length, alignment = tacotron2.encode_text(text)
    # Running Vocoder (spectrogram-to-waveform)
    # waveforms = hifi_gan.decode_batch(mel_output)
    # Save the waverform
    outfile = f"{os.path.join(os.getcwd(), str(uuid.uuid4()))}.wav"
    # torchaudio.save(outfile, waveforms.squeeze(1), 22050)
    if len(text) > 5000:
        text_str = text[0:4999]
    else:
        text_str = text
    ret = tts.tts(text_str, outfile)
    return outfile

def combine_audio_files(audio_files):
    data= []
    outfile = "sounds.wav"
    for infile in audio_files:
        w = wave.open(infile, 'rb')
        data.append([w.getparams(), w.readframes(w.getnframes())] )
        w.close()
        #os.remove(infile)  # Remove temporary files
    output = wave.open(outfile, 'wb')
    output.setparams(data[0][0])
    for i in range(len(data)):
        output.writeframes(data[i][1])
    output.close()
    return outfile
    
#client
client = Groq(
    api_key=os.getenv("GROQ_API_KEY"),
)

llm = Llama.from_pretrained(
    repo_id="amir22010/fine_tuned_product_marketing_email_gemma_2_9b_q4_k_m", #custom fine tuned model
    filename="unsloth.Q4_K_M.gguf", #model file name
    cache_dir=os.path.abspath(os.getcwd()),
    n_ctx=2048, 
    n_batch=126,
    verbose=False
)

#guardrail model
guard_llm = "llama-3.2-11b-text-preview"

#marketing prompt
marketing_email_prompt = """Below is a product and description, please write a marketing email for this product.

### Product:
{}

### Description:
{}

### Marketing Email:
{}"""

async def greet(product,description):
    user_reques = marketing_email_prompt.format(
    product, # product
    description, # description
    "", # output - leave this blank for generation!
    )
    messages=[
            {"role": "user", "content":  user_reques},
        ]
    # messages = [
    #     {
    #         "role": "system",
    #         "content": "Your role is to assess the user prompt.",
    #     },
    #     {"role": "user", "content": guardrail_prompt.format(user_reques)},
    # ]
    #nemo guard
    # config = RailsConfig.from_path("nemo")
    YAML_CONFIG = """
    models:
      - type: main
        engine: groq
        model: llama-3.2-11b-text-preview
        
      - type: llama_guard
        engine: groq
        model: llama-guard-3-8b
        
    rails:
      input:
        flows:
          - content safety check input $model=main

    output:
        flows:
          - content safety check output $model=llama_guard

    prompts:
      - task: content_safety_check_input $model=main
        content: |
          You are a policy expert trying to help determine whether a user prompt is in
          violation of the defined safety policies.
    
          <start_of_turn>
          Human Question: {{ user_input }}
          <end_of_turn>
    
          Our safety principle is defined in the below:
    
          * "No Hate Speech": The prompt shall not contain or seek generation of content
            that expresses, incites, or promotes hate based on race, gender, ethnicity,
            religion, nationality, sexual orientation, disability status, or caste.
          * "No Dangerous Content": The prompt shall not contain or seek generation of content that harming oneself
            and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).
          * "No Sexually Explicit Information": The prompt shall not contain or seek generation of content
            that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions,
            content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted.
          * "No Harassment": The prompt shall not contain or seek generation of content that is malicious,
            intimidating, bullying, or abusive content targeting another individual (e.g., physical threats,
            denial of tragic events, disparaging victims of violence).
          ===
    
          Does the human prompt/question violate the above principle? Your answer must start with
          'Yes' or 'No'. And then walk through step by step to be sure we answer
          correctly.
          
        output_parser: is_content_safe
    
    streaming: False
    """
    config = RailsConfig.from_content(yaml_content=YAML_CONFIG)
    app = LLMRails(config=config, llm=client)
    options = {"output_vars": ["triggered_input_rail", "triggered_output_rail"]}
    output = app.generate(messages=messages, options=options)
    warning_message = output.output_data["triggered_input_rail"] or output.output_data["triggered_output_rail"]
    if warning_message:
        gr.Warning(f"Guardrail triggered: {warning_message}")
        chat = [output.response[0]['content']]
        yield chat[0]
    else:
        output = llm.create_chat_completion(
        messages=[
            {
                "role": "system",
                "content": "Your go-to Email Marketing Guru - I'm here to help you craft short and concise compelling campaigns, boost conversions, and take your business to the next level.",
            },
            {"role": "user", "content":  user_reques},
        ],
        max_tokens=2048,
        temperature=0.7,
        stream=True
        )
        partial_message = ""
        audio_list = []
        for chunk in output:
            delta = chunk['choices'][0]['delta']
            if 'content' in delta:
                #audio_list.append([text_to_speech(delta.get('content', ''))])
                #processed_audio = combine_audio_files(audio_list)
                partial_message = partial_message + delta.get('content', '')
                yield partial_message

audio = gr.Audio()
demo = gr.Interface(fn=greet, inputs=["text","text"], concurrency_limit=10, outputs=["text"])
demo.launch()