# MarketMate / app.py
import gradio as gr
from llama_cpp import Llama
import os
from groq import Groq
import numpy as np
import wave
import uuid
from nemoguardrails import LLMRails, RailsConfig
from GoogleTTS import GoogleTTS
# tts: earlier SpeechBrain TTS experiments (FastSpeech2 / Tacotron2 + HiFi-GAN),
# kept commented out for reference; the app currently uses Google TTS below.
# import torchaudio
# from speechbrain.inference.TTS import FastSpeech2
# from speechbrain.inference.TTS import Tacotron2
# from speechbrain.inference.vocoders import HIFIGAN
# fastspeech2 = FastSpeech2.from_hparams(source="speechbrain/tts-fastspeech2-ljspeech", savedir="pretrained_models/tts-fastspeech2-ljspeech")
# tacotron2 = Tacotron2.from_hparams(source="speechbrain/tts-tacotron2-ljspeech", savedir="tmpdir_tts")
# hifi_gan = HIFIGAN.from_hparams(source="speechbrain/tts-hifigan-ljspeech", savedir="pretrained_models/tts-hifigan-ljspeech")
# google tts
tts = GoogleTTS()
def text_to_speech(text):
    """Synthesize `text` with Google TTS and return the path to a uniquely named WAV file."""
    # SpeechBrain alternative, kept for reference:
    # mel_output, mel_length, alignment = tacotron2.encode_text(text)
    # waveforms = hifi_gan.decode_batch(mel_output)  # vocoder: spectrogram -> waveform
    # torchaudio.save(outfile, waveforms.squeeze(1), 22050)  # save the waveform
    outfile = f"{os.path.join(os.getcwd(), str(uuid.uuid4()))}.wav"
    # Truncate very long inputs to stay under 5000 characters.
    if len(text) > 5000:
        text_str = text[0:4999]
    else:
        text_str = text
    tts.tts(text_str, outfile)
    return outfile
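# Hypothetical usage sketch (defined but never called at startup): illustrates the
# contract of text_to_speech above, i.e. it returns the path of a freshly written
# WAV file. The name _example_text_to_speech is illustrative only.
def _example_text_to_speech():
    path = text_to_speech("Welcome to MarketMate!")
    print("TTS output:", path, "exists:", os.path.exists(path))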
def combine_audio_files(audio_files):
    """Concatenate WAV files (assumed to share identical parameters) into sounds.wav."""
    data = []
    outfile = "sounds.wav"
    for infile in audio_files:
        w = wave.open(infile, 'rb')
        data.append([w.getparams(), w.readframes(w.getnframes())])
        w.close()
        #os.remove(infile) # Remove temporary files
    output = wave.open(outfile, 'wb')
    output.setparams(data[0][0])  # reuse the parameters of the first clip
    for params, frames in data:
        output.writeframes(frames)
    output.close()
    return outfile
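# Hypothetical usage sketch (defined but never called at startup): combines two
# synthesized clips; it assumes every clip shares the sample rate and sample width
# written by Google TTS, since combine_audio_files reuses the first clip's parameters.
def _example_combine_audio():
    clips = [text_to_speech("First sentence."), text_to_speech("Second sentence.")]
    print("Combined file:", combine_audio_files(clips))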
# client
client = Groq(
    api_key=os.getenv("GROQ_API_KEY"),
)
llm = Llama.from_pretrained(
    repo_id="amir22010/fine_tuned_product_marketing_email_gemma_2_9b_q4_k_m",  # custom fine-tuned model
    filename="unsloth.Q4_K_M.gguf",  # model file name
    cache_dir=os.path.abspath(os.getcwd()),
    n_ctx=2048,
    n_batch=126,
    verbose=False,
)
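# Hypothetical smoke-test sketch (defined but never called at startup): a single
# non-streaming completion against the local GGUF model, useful for confirming the
# download and the n_ctx/n_batch settings before the model is used in the Gradio handler.
def _example_llm_smoke_test():
    result = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Say hello in one short sentence."}],
        max_tokens=32,
        temperature=0.0,
    )
    print(result["choices"][0]["message"]["content"])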
# guardrail model (the same model name is referenced in the YAML config inside greet)
guard_llm = "llama-3.2-11b-text-preview"
# marketing prompt
marketing_email_prompt = """Below is a product and description, please write a marketing email for this product.
### Product:
{}
### Description:
{}
### Marketing Email:
{}"""
async def greet(product, description):
    user_request = marketing_email_prompt.format(
        product,      # product
        description,  # description
        "",           # output - leave this blank for generation!
    )
    messages = [
        {"role": "user", "content": user_request},
    ]
    # Earlier guardrail prompt variant, kept for reference:
    # messages = [
    #     {
    #         "role": "system",
    #         "content": "Your role is to assess the user prompt.",
    #     },
    #     {"role": "user", "content": guardrail_prompt.format(user_request)},
    # ]
    # nemo guardrails config (previously loaded from a directory):
    # config = RailsConfig.from_path("nemo")
    YAML_CONFIG = """
models:
  - type: main
    engine: groq
    model: llama-3.2-11b-text-preview
  - type: llama_guard
    engine: groq
    model: llama-guard-3-8b

rails:
  input:
    flows:
      - content safety check input $model=main
  output:
    flows:
      - content safety check output $model=llama_guard

prompts:
  - task: content_safety_check_input $model=main
    content: |
      You are a policy expert trying to help determine whether a user prompt is in
      violation of the defined safety policies.

      <start_of_turn>
      Human Question: {{ user_input }}
      <end_of_turn>

      Our safety principles are defined below:

      * "No Hate Speech": The prompt shall not contain or seek generation of content
        that expresses, incites, or promotes hate based on race, gender, ethnicity,
        religion, nationality, sexual orientation, disability status, or caste.
      * "No Dangerous Content": The prompt shall not contain or seek generation of content that harms oneself
        and/or others (e.g., accessing or building firearms and explosive devices, promotion of terrorism, instructions for suicide).
      * "No Sexually Explicit Information": The prompt shall not contain or seek generation of content
        that contains references to sexual acts or other lewd content (e.g., sexually graphic descriptions,
        content aimed at causing arousal). Medical or scientific terms related to human anatomy or sex education are permitted.
      * "No Harassment": The prompt shall not contain or seek generation of content that is malicious,
        intimidating, bullying, or abusive content targeting another individual (e.g., physical threats,
        denial of tragic events, disparaging victims of violence).

      ===

      Does the human prompt/question violate the above principles? Your answer must start with
      'Yes' or 'No'. Then walk through the reasoning step by step to be sure the answer is correct.
    output_parser: is_content_safe

streaming: False
"""
    config = RailsConfig.from_content(yaml_content=YAML_CONFIG)
    app = LLMRails(config=config, llm=client)
    options = {"output_vars": ["triggered_input_rail", "triggered_output_rail"]}
    output = app.generate(messages=messages, options=options)
    warning_message = output.output_data["triggered_input_rail"] or output.output_data["triggered_output_rail"]
    if warning_message:
        # A rail fired: surface the warning in the UI and return the guardrail response.
        gr.Warning(f"Guardrail triggered: {warning_message}")
        yield output.response[0]['content']
    else:
        output = llm.create_chat_completion(
            messages=[
                {
                    "role": "system",
                    "content": "Your go-to Email Marketing Guru - I'm here to help you craft short, concise, and compelling campaigns, boost conversions, and take your business to the next level.",
                },
                {"role": "user", "content": user_request},
            ],
            max_tokens=2048,
            temperature=0.7,
            stream=True,
        )
        partial_message = ""
        audio_list = []
        for chunk in output:
            delta = chunk['choices'][0]['delta']
            if 'content' in delta:
                #audio_list.append([text_to_speech(delta.get('content', ''))])
                #processed_audio = combine_audio_files(audio_list)
                partial_message = partial_message + delta.get('content', '')
                yield partial_message  # stream the partial email back to the UI
audio = gr.Audio()  # currently unused; audio output is not wired into the interface
demo = gr.Interface(fn=greet, inputs=["text", "text"], concurrency_limit=10, outputs=["text"])
demo.launch()
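# Hedged alternative launch (an assumption, not part of the original app): greet() is a
# generator, so partial results stream through Gradio's request queue; configuring the
# queue explicitly is one way to keep the concurrency limit in a single place.
# demo.queue(default_concurrency_limit=10).launch()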