# app.py, author: Sanchit2207
# Last commit: 49c9a25 "Update app.py" (verified); file size 2.22 kB
import gradio as gr
from transformers import pipeline, AutoTokenizer, AutoModelForCausalLM
import torch
# ---------------- Agent 1: Intent Classifier ----------------
# NLI checkpoint backing the zero-shot pipeline that scores intent hypotheses.
INTENT_MODEL_ID = "MoritzLaurer/DeBERTa-v3-base-mnli-fever-anli"
intent_classifier = pipeline(task="zero-shot-classification", model=INTENT_MODEL_ID)
# Each supported intent paired with the natural-language hypothesis the NLI
# model checks for entailment against the user's message.
INTENT_HYPOTHESES = {
    "weather": "The user wants to know the weather.",
    "faq": "The user is asking for help.",
    "smalltalk": "The user is making casual conversation.",
}
# Intent returned when no hypothesis is confidently entailed.
DEFAULT_INTENT = "smalltalk"


def detect_intent(text, *, classifier=None, threshold=0.5):
    """Classify ``text`` into one of the keys of ``INTENT_HYPOTHESES``.

    Uses a zero-shot-classification pipeline: every hypothesis is scored
    against ``text`` and the best-scoring one wins, provided its score is
    strictly above ``threshold``.

    Args:
        text: The user's message. ``None``/blank input yields the default.
        classifier: Optional zero-shot pipeline (or compatible callable);
            defaults to the module-level ``intent_classifier``.
        threshold: Minimum entailment score a hypothesis needs to be chosen.

    Returns:
        The matching intent key, or ``DEFAULT_INTENT`` ("smalltalk") when the
        input is blank or no hypothesis scores above ``threshold``.
    """
    if not text or not text.strip():
        return DEFAULT_INTENT
    if classifier is None:
        classifier = intent_classifier

    result = classifier(
        text,
        candidate_labels=list(INTENT_HYPOTHESES.values()),
        # Hypotheses are already full sentences; use them verbatim instead of
        # the pipeline's default "This example is {}." template.
        hypothesis_template="{}",
        # Score each hypothesis independently (entailment vs. contradiction),
        # so a low score means "not entailed" rather than "lost the softmax".
        multi_label=True,
    )

    # The pipeline reports the hypotheses back as its "labels"; don't rely on
    # their ordering, pick the maximum explicitly.
    best_hypothesis, best_score = max(
        zip(result["labels"], result["scores"]), key=lambda pair: pair[1]
    )
    if best_score <= threshold:
        return DEFAULT_INTENT
    intent_by_hypothesis = {hyp: intent for intent, hyp in INTENT_HYPOTHESES.items()}
    return intent_by_hypothesis[best_hypothesis]
# ---------------- Agent 2: Domain Logic ----------------
def handle_logic(intent):
    """Return the canned domain answer for an intent label.

    "weather" and "faq" have dedicated answers; any other value (including
    "smalltalk") receives the generic conversational reply.
    """
    canned_answers = (
        ("weather", "It's sunny and 26°C today."),
        ("faq", "To reset your password, use the 'Forgot Password' option."),
    )
    for key, answer in canned_answers:
        if intent == key:
            return answer
    return "That's great! Anything else you'd like to talk about?"
# ---------------- Agent 3: Natural Language Generation ----------------
# Tokenizer and causal LM are loaded from the same checkpoint so their
# vocabularies match.
NLG_MODEL_ID = "microsoft/DialoGPT-small"
tokenizer = AutoTokenizer.from_pretrained(NLG_MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(NLG_MODEL_ID)
def generate_reply(prompt, *, max_new_tokens=100):
    """Generate the DialoGPT turn that follows ``prompt``.

    ``prompt`` is encoded as one dialogue turn terminated by the EOS token;
    only the tokens generated after it are decoded and returned.

    Args:
        prompt: Text treated as the preceding dialogue turn.
        max_new_tokens: Cap on generated tokens. It is counted separately
            from the prompt, so long prompts still get a full reply budget
            (a total-length cap would leave nothing for them).

    Returns:
        The decoded reply with special tokens stripped; may be empty.
    """
    # Calling the tokenizer (rather than .encode) also yields attention_mask,
    # which generate() needs because pad and EOS share the same token id here.
    encoded = tokenizer(prompt + tokenizer.eos_token, return_tensors="pt")
    prompt_len = encoded["input_ids"].shape[-1]
    output_ids = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Single-sequence batch: take row 0 and drop the echoed prompt tokens.
    return tokenizer.decode(output_ids[0, prompt_len:], skip_special_tokens=True)
# ---------------- Chatbot Pipeline ----------------
def chatbot(user_input):
    """Run one user message through the three agents and return display text.

    Flow: detect_intent -> handle_logic -> generate_reply. generate_reply
    produces the dialogue turn that *follows* its input rather than a
    rephrasing of it, so the domain answer is kept as the primary content and
    the generated text is appended as a conversational follow-up. If the
    follow-up is empty, the domain answer alone is returned.
    """
    intent = detect_intent(user_input)
    answer = handle_logic(intent)
    follow_up = generate_reply(answer).strip()
    if not follow_up:
        return answer
    return f"{answer} {follow_up}"
# ---------------- Gradio UI ----------------
# Components are built first (inputs before outputs, as before), then wired
# into a single-function Interface that is launched immediately.
user_box = gr.Textbox(label="User Input")
reply_box = gr.Textbox(label="Chatbot Response")
demo = gr.Interface(
    fn=chatbot,
    inputs=user_box,
    outputs=reply_box,
    title="3-Agent Chatbot",
    description="Intent Detection → Domain Logic → Natural Language Generation",
)
demo.launch()