Spaces:
Running
Running
import gradio as gr | |
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM | |
import torch | |
# ---------------- Agent 1: Intent Classifier ----------------
# NLI checkpoint driven through the zero-shot-classification pipeline; the
# intent labels are supplied at call time by detect_intent.
_INTENT_CHECKPOINT = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
intent_classifier = pipeline(
    task="zero-shot-classification",
    model=_INTENT_CHECKPOINT,
)
def detect_intent(text, min_score=0.5):
    """Classify a user utterance as "weather", "faq" or "smalltalk".

    Each intent is described by a natural-language hypothesis. The zero-shot
    pipeline scores how strongly *text* entails each hypothesis, and the
    best-scoring intent wins.

    Args:
        text: Raw user utterance.
        min_score: Minimum entailment probability the best hypothesis must
            reach to be accepted. When no hypothesis reaches it (or *text*
            is empty/blank), "smalltalk" is returned.

    Returns:
        One of "weather", "faq" or "smalltalk".
    """
    hypotheses = {
        "weather": "The user wants to know the weather.",
        "faq": "The user is asking for help.",
        "smalltalk": "The user is making casual conversation.",
    }
    # Nothing to classify: fall back to small talk without invoking the model.
    if not text or not text.strip():
        return "smalltalk"

    # The zero-shot pipeline takes ``candidate_labels`` (not ``text_pair``)
    # and returns {"labels": [...], "scores": [...]} sorted by score, rather
    # than a list of {"label": "ENTAILMENT", ...} dicts.
    #   - hypothesis_template="{}" uses each candidate verbatim as the NLI
    #     hypothesis instead of wrapping it in "This example is {}.".
    #   - multi_label=True scores every hypothesis independently (entailment
    #     vs. contradiction), so a low score means "not entailed" rather than
    #     merely "less likely than the other labels".
    result = intent_classifier(
        text,
        candidate_labels=list(hypotheses.values()),
        hypothesis_template="{}",
        multi_label=True,
    )
    best_hypothesis = result["labels"][0]
    best_score = result["scores"][0]
    if best_score < min_score:
        return "smalltalk"

    intent_by_hypothesis = {h: intent for intent, h in hypotheses.items()}
    return intent_by_hypothesis[best_hypothesis]
# ---------------- Agent 2: Domain Logic ----------------
# Ordered (intent, canned answer) pairs; the first matching intent wins.
_CANNED_ANSWERS = (
    ("weather", "It's sunny and 26°C today."),
    ("faq", "To reset your password, use the 'Forgot Password' option."),
)
# Answer used for any intent not listed above.
_DEFAULT_ANSWER = "That's great! Anything else you'd like to talk about?"


def handle_logic(intent):
    """Return the hard-coded domain answer for *intent*.

    "weather" and "faq" map to fixed answers; anything else gets the generic
    small-talk answer.
    """
    for known_intent, answer in _CANNED_ANSWERS:
        if intent == known_intent:
            return answer
    return _DEFAULT_ANSWER
# ---------------- Agent 3: Natural Language Generation ----------------
# Tokenizer and causal-LM weights are loaded from the same checkpoint so
# their vocabularies match.
_NLG_CHECKPOINT = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(_NLG_CHECKPOINT)
model = AutoModelForCausalLM.from_pretrained(_NLG_CHECKPOINT)
def generate_reply(prompt, max_new_tokens=100):
    """Generate a DialoGPT reply to a single conversational turn.

    The prompt is terminated with the tokenizer's EOS token (DialoGPT's turn
    separator), so the model continues with the next speaker's turn.

    Args:
        prompt: The message to respond to.
        max_new_tokens: Upper bound on the number of generated tokens. This
            counts only new tokens, not the prompt, so a long prompt cannot
            exhaust the budget and leave an empty reply.

    Returns:
        The decoded reply with special tokens removed and surrounding
        whitespace stripped (may be an empty string).
    """
    encoded = tokenizer(prompt + tokenizer.eos_token, return_tensors="pt")
    prompt_length = encoded["input_ids"].shape[-1]
    output_ids = model.generate(
        input_ids=encoded["input_ids"],
        # Pad and EOS share an id here (see pad_token_id below), so generate()
        # cannot infer which positions are real tokens; pass the mask.
        attention_mask=encoded["attention_mask"],
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    # generate() echoes the prompt; keep only the newly generated tokens.
    reply_ids = output_ids[0, prompt_length:]
    return tokenizer.decode(reply_ids, skip_special_tokens=True).strip()
# ---------------- Chatbot Pipeline ----------------
def chatbot(user_input):
    """Answer one user message by chaining the three agents.

    1. detect_intent classifies the message.
    2. handle_logic produces the domain answer for that intent.
    3. For small talk, generate_reply (DialoGPT) writes a reply to the user's
       own message; factual intents return the domain answer unchanged.

    Args:
        user_input: Text typed by the user in the Gradio textbox.

    Returns:
        The reply string shown to the user.
    """
    if not user_input or not user_input.strip():
        return "Please type a message."

    intent = detect_intent(user_input)
    answer = handle_logic(intent)

    if intent in ("weather", "faq"):
        # DialoGPT generates the *next* turn after its prompt. Passing the
        # domain answer as the prompt made it reply to the bot's own answer,
        # so the weather/FAQ information never reached the user.
        return answer

    reply = generate_reply(user_input)
    # Generation may come back empty; show the canned small-talk answer
    # rather than a blank response.
    return reply if reply.strip() else answer
# ---------------- Gradio UI ----------------
# Single-textbox-in, single-textbox-out interface around the chatbot pipeline.
user_box = gr.Textbox(label="User Input")
reply_box = gr.Textbox(label="Chatbot Response")
demo = gr.Interface(
    fn=chatbot,
    inputs=user_box,
    outputs=reply_box,
    title="3-Agent Chatbot",
    description="Intent Detection → Domain Logic → Natural Language Generation",
)
# Keep a module-level handle to the interface, then start the server.
demo.launch()