|
import os |
|
import re |
|
import time |
|
import random |
|
import gradio as gr |
|
from huggingface_hub import InferenceClient |
|
|
|
|
|
|
|
|
|
|
|
|
|
# Hugging Face inference client bound to the Zephyr 7B beta chat model; every
# completion request in this module goes through it.
client = InferenceClient(model="HuggingFaceH4/zephyr-7b-beta")

# Feature flag checked once at import time to decide whether SITE_URL is scraped.
ENABLE_SCRAPING = False
SITE_URL = "https://your-agri-future-site.com"

# Text used as retrieval context by respond(); stays empty unless the
# scraping step assigns something to it.
knowledge_base = ""
|
|
|
|
|
if ENABLE_SCRAPING:
    # Selenium is imported here rather than at the top of the file so that it
    # is only required when scraping is actually switched on.
    try:
        from selenium import webdriver
        from selenium.webdriver.chrome.options import Options
        from selenium.webdriver.common.by import By

        def scrape_site(url):
            """Return the text of the element with id ``content`` found at *url*.

            A headless Chrome session is started for the request and is always
            shut down again, even when loading or locating the element fails.
            Any Selenium error propagates to the caller instead of being turned
            into a string, so a failure can never be mistaken for page content.
            """
            options = Options()
            # The ``headless`` property was deprecated and later removed from
            # Selenium's Options; the command-line switch works on old and new
            # releases alike.
            options.add_argument("--headless")
            driver = webdriver.Chrome(options=options)
            try:
                driver.get(url)
                # Fixed pause before reading the page -- presumably to give
                # client-side rendering time to finish.
                time.sleep(5)
                return driver.find_element(By.ID, "content").text
            finally:
                # Release the browser process on success and on failure.
                driver.quit()

        knowledge_base = scrape_site(SITE_URL)
        print("Scraped knowledge base successfully.")
    except Exception as e:
        # Best effort: covers a missing Selenium install, driver start-up
        # problems and scraping errors. knowledge_base keeps its prior value.
        print("Scraping failed or Selenium is not configured:", e)
else:
    print("Scraping is disabled; proceeding without scraped site content.")
|
|
|
|
|
|
|
def is_greeting(query: str, lang: str) -> bool:
    """Tell whether *query* opens with a greeting in the given language.

    Args:
        query: Raw user message. Leading/trailing whitespace is ignored.
        lang: Language code ("en", "fr" or "am"); any other value falls back
            to the English greeting list.

    Returns:
        True when the message starts with a known greeting, False otherwise
        (including for an empty message).
    """
    greetings = {
        "en": ["hello", "hi", "hey", "good morning", "good afternoon", "good evening"],
        "fr": ["bonjour", "salut", "coucou", "bonsoir"],
        "am": ["ሰላም", "ሰላም እንደምን", "እንዴት"],
    }
    greet_list = greetings.get(lang, greetings["en"])

    text = (query or "").strip()
    if not text:
        return False

    if lang == "am":
        # Amharic keeps plain prefix matching and is not case-folded.
        # NOTE(review): greetings may carry attached suffixes, so no word
        # boundary is enforced here -- confirm with a native speaker.
        return text.startswith(tuple(greet_list))

    text = text.lower()
    for greet in greet_list:
        if not text.startswith(greet):
            continue
        remainder = text[len(greet):]
        # Require the greeting to be a whole word so that "history" or
        # "heyday" are not mistaken for "hi" / "hey".
        if not remainder or not remainder[0].isalnum():
            return True
    return False
|
|
|
def generate_dynamic_greeting(language: str) -> str:
    """Ask the chat model for a greeting in the requested language.

    ``language`` selects one of the built-in system prompts ("en", "fr",
    "am"); any other value falls back to the English prompt. The reply text
    is returned with surrounding whitespace removed.
    """
    english_prompt = (
        "You are a friendly chatbot specializing in agriculture and agro-investment. "
        "A user just greeted you. Generate a warm, dynamic greeting message in English that is context-aware and encourages discussion about agriculture or agro-investment."
    )
    french_prompt = (
        "Vous êtes un chatbot chaleureux spécialisé dans l'agriculture et les investissements agroalimentaires. "
        "Un utilisateur vient de vous saluer. Générez un message de salutation dynamique et chaleureux en français, en restant pertinent par rapport à l'agriculture ou aux investissements agroalimentaires."
    )
    amharic_prompt = (
        "እርስዎ በግብርናና በአገልግሎት ስርዓተ-ቢዝነስ ውስጥ ባለሙያ ቻትቦት ናቸው። "
        "ተጠቃሚው በአማርኛ ሰላም መልእክት አስቀድመዋል። "
        "በአማርኛ ተዛማጅ እና ትክክለኛ የሆነ ሰላም መልእክት ፍጥረት ያድርጉ።"
    )
    prompt_by_language = {
        "en": english_prompt,
        "fr": french_prompt,
        "am": amharic_prompt,
    }

    if language in prompt_by_language:
        chosen_prompt = prompt_by_language[language]
    else:
        chosen_prompt = english_prompt

    # The request carries only the system prompt; no user turn is sent.
    completion = client.chat_completion(
        [{"role": "system", "content": chosen_prompt}],
        max_tokens=128,
        stream=False,
        temperature=1,
        top_p=0.95,
    )

    try:
        reply = completion.choices[0].message.content
    except AttributeError:
        # The response lacks the expected attributes: use its string form.
        reply = str(completion)
    return reply.strip()
|
|
|
def generate_dynamic_out_of_scope_message(language: str) -> str:
    """Ask the chat model for a polite "off-topic" reply in the given language.

    ``language`` selects one of the built-in system prompts ("en", "fr",
    "am"); any other value falls back to the English prompt. The reply text
    is returned with surrounding whitespace removed.
    """
    english_prompt = (
        "You are a helpful chatbot specializing in agriculture and agro-investment. "
        "A user just asked a question that is not related to these topics. "
        "Generate a friendly, varied, and intelligent out-of-scope response in English that kindly encourages the user to ask about agriculture or agro-investment."
    )
    french_prompt = (
        "Vous êtes un chatbot utile spécialisé dans l'agriculture et les investissements agroalimentaires. "
        "Un utilisateur vient de poser une question qui ne concerne pas ces sujets. "
        "Générez une réponse élégante, variée et intelligente en français pour indiquer que la question est hors de portée, en invitant l'utilisateur à poser une question sur l'agriculture ou les investissements agroalimentaires."
    )
    amharic_prompt = (
        "እርስዎ በግብርናና በአገልግሎት ስርዓተ-ቢዝነስ ውስጥ በተለይ የተሞሉ ቻትቦት ናቸው። "
        "ተጠቃሚው ለግብርና ወይም ለአገልግሎት ስርዓተ-ቢዝነስ ተያይዞ ያልሆነ ጥያቄ አስቀድመዋል። "
        "በአማርኛ በተለያዩ መልኩ የውጭ ክፍል መልእክት ፍጥረት ያድርጉ፤ እባኮትን ተጠቃሚውን ለግብርና ወይም ለአገልግሎት ጥያቄዎች ለመጠየቅ ያነጋግሩ።"
    )
    prompt_by_language = {
        "en": english_prompt,
        "fr": french_prompt,
        "am": amharic_prompt,
    }

    if language in prompt_by_language:
        chosen_prompt = prompt_by_language[language]
    else:
        chosen_prompt = english_prompt

    # The request carries only the system prompt; no user turn is sent.
    completion = client.chat_completion(
        [{"role": "system", "content": chosen_prompt}],
        max_tokens=128,
        stream=False,
        temperature=1,
        top_p=0.95,
    )

    try:
        reply = completion.choices[0].message.content
    except AttributeError:
        # The response lacks the expected attributes: use its string form.
        reply = str(completion)
    return reply.strip()
|
|
|
# Terms that mark a message as being about agriculture or agro-investment.
_DOMAIN_KEYWORDS = (
    "agriculture", "farming", "crop", "agro", "investment", "soil",
    "irrigation", "harvest", "organic", "sustainable", "agribusiness",
    "livestock", "agroalimentaire", "agriculture durable",
    "greenhouse", "horticulture", "pesticide", "fertilizer",
    "rural development", "food production", "crop yield", "farm equipment",
    "agronomy", "farming techniques", "organic farming", "agro-tech",
    "farm management", "agrifood",
)

# Single alternation compiled once at import time. Keywords are escaped so
# they are always matched literally, must appear as whole words, and may
# carry a plain plural ending ("crops", "fertilizers").
_DOMAIN_PATTERN = re.compile(
    r"\b(?:"
    + "|".join(re.escape(keyword) for keyword in _DOMAIN_KEYWORDS)
    + r")(?:s|es)?\b",
    re.IGNORECASE,
)


def is_domain_query(query: str) -> bool:
    """Check whether *query* mentions agriculture or agro-investment.

    Matching is case-insensitive and whole-word: "crop" and "crops" match,
    while "microprocessor" (which merely contains "crop") does not.

    Args:
        query: Text to inspect; an empty or missing value is never in-domain.

    Returns:
        True when at least one domain keyword occurs in the text.
    """
    if not query:
        return False
    return _DOMAIN_PATTERN.search(query) is not None
|
|
|
def retrieve_relevant_snippet(query: str, text: str, max_length: int = 300) -> str:
    """Return the first sentence of *text* that is relevant to *query*.

    A sentence qualifies when is_domain_query() accepts it and every word of
    the query occurs in it (case-insensitive substring test).

    Args:
        query: User question. Only its word characters are considered, so
            punctuation such as a trailing "?" does not prevent a match.
        text: Source text, split into sentences on ".", "?" and "!".
        max_length: Maximum snippet length before it is cut and suffixed
            with "...".

    Returns:
        The stripped sentence (possibly truncated), or "" when nothing
        matches or when *query* / *text* is empty.
    """
    # Sentences are split on ".?!", so a query token that still carried such
    # punctuation could never be found inside any sentence.
    query_words = re.findall(r"\w+", query.lower()) if query else []
    if not text or not query_words:
        return ""

    for sentence in re.split(r"[.?!]", text):
        if not is_domain_query(sentence):
            continue
        lowered = sentence.lower()  # lower-case once per sentence, not per word
        if all(word in lowered for word in query_words):
            snippet = sentence.strip()
            if len(snippet) > max_length:
                return snippet[:max_length] + "..."
            return snippet
    return ""
|
|
|
|
|
def respond(message, history: list, system_message, max_tokens, temperature, top_p, language):
    """Generate the chatbot reply for one user turn, streaming it as it grows.

    Greetings and off-topic messages are answered with a single generated
    message. Everything else is sent to the chat model together with the
    system message, the earlier conversation and, when the knowledge base
    yields one, a reference snippet.

    Args:
        message: Latest user message.
        history: Earlier turns, either as (user, assistant) pairs or as
            dicts with "role" and "content" keys.
        system_message: System prompt for the model.
        max_tokens: Upper bound on generated tokens for the reply.
        temperature: Sampling temperature forwarded to the model.
        top_p: Nucleus-sampling value forwarded to the model.
        language: Language code used for greeting / off-topic replies.

    Yields:
        The reply text accumulated so far (or the single canned reply).
    """
    if is_greeting(message, language):
        yield generate_dynamic_greeting(language)
        return

    if not is_domain_query(message):
        yield generate_dynamic_out_of_scope_message(language)
        return

    messages_list = [{"role": "system", "content": system_message}]
    for turn in history or []:
        if isinstance(turn, dict):
            # "messages"-style history entry: forward user/assistant turns.
            role = turn.get("role")
            content = turn.get("content")
            if role in ("user", "assistant") and content:
                messages_list.append({"role": role, "content": content})
            continue
        # Pair-style history entry: (user message, assistant message).
        user_msg, assistant_msg = turn
        if user_msg:
            messages_list.append({"role": "user", "content": user_msg})
        if assistant_msg:
            messages_list.append({"role": "assistant", "content": assistant_msg})

    if knowledge_base:
        snippet = retrieve_relevant_snippet(message, knowledge_base)
        if snippet:
            retrieval_context = f"Reference from Agri Future Investment platform: {snippet}"
            # Placed ahead of the main system message, as a separate system entry.
            messages_list.insert(0, {"role": "system", "content": retrieval_context})

    messages_list.append({"role": "user", "content": message})

    response_text = ""
    for partial_response in client.chat_completion(
        messages_list,
        # Honour the caller's limit; it was previously ignored in favour of a
        # hard-coded 1024.
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    ):
        if partial_response.choices and partial_response.choices[0].delta:
            token = partial_response.choices[0].delta.content
            if token:
                response_text += token
                yield response_text
|
|
|
# Chat UI wired to respond(). The additional inputs are listed in the same
# order as respond()'s parameters after (message, history).
demo = gr.ChatInterface(
    fn=respond,
    additional_inputs=[
        gr.Textbox(
            value="You are AgriFutureBot, a specialized assistant for agriculture and agro-investment insights.",
            label="System Message",
        ),
        gr.Slider(
            minimum=1,
            maximum=2048,
            value=512,
            step=1,
            label="Max New Tokens",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=4.0,
            value=0.7,
            step=0.1,
            label="Temperature",
        ),
        gr.Slider(
            minimum=0.1,
            maximum=1.0,
            value=0.95,
            step=0.05,
            label="Top-p (Nucleus Sampling)",
        ),
        gr.Dropdown(
            choices=["en", "fr", "am"],
            value="en",
            label="Language",
        ),
    ],
)


# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()