Update app.py
Browse files
app.py
CHANGED
@@ -1,19 +1,27 @@
|
|
1 |
import gradio as gr
|
2 |
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
|
3 |
from langchain_community.llms import HuggingFacePipeline
|
|
|
4 |
from langchain_core.prompts import PromptTemplate
|
5 |
from langchain.chains import LLMChain
|
6 |
-
from langchain_core.memory import ConversationBufferMemory
|
7 |
|
8 |
# Load model and tokenizer
|
9 |
-
model_name = "microsoft/DialoGPT-medium"
|
10 |
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
11 |
model = AutoModelForCausalLM.from_pretrained(model_name)
|
12 |
|
13 |
# Create text-generation pipeline
|
14 |
-
pipe = pipeline(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
15 |
|
16 |
-
# Wrap with
|
17 |
llm = HuggingFacePipeline(pipeline=pipe)
|
18 |
|
19 |
# Prompt Template
|
@@ -27,10 +35,10 @@ prompt = PromptTemplate(
|
|
27 |
template=template
|
28 |
)
|
29 |
|
30 |
-
#
|
31 |
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
32 |
|
33 |
-
# Chain
|
34 |
llm_chain = LLMChain(
|
35 |
llm=llm,
|
36 |
prompt=prompt,
|
@@ -50,8 +58,9 @@ demo = gr.ChatInterface(
|
|
50 |
title="AI Chatbot",
|
51 |
description="A simple chatbot using LangChain + HuggingFace + Gradio",
|
52 |
theme="default",
|
53 |
-
|
54 |
)
|
55 |
|
|
|
56 |
if __name__ == "__main__":
|
57 |
demo.queue().launch(share=True)
|
|
|
1 |
import gradio as gr
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline

from langchain.chains import LLMChain
# ConversationBufferMemory is exported from langchain.memory —
# langchain_community.memory does not provide it, so the previous
# import raised ImportError at startup.
from langchain.memory import ConversationBufferMemory
from langchain_community.llms import HuggingFacePipeline
from langchain_core.prompts import PromptTemplate
|
|
|
7 |
|
8 |
# Load model and tokenizer
model_name = "microsoft/DialoGPT-medium"  # You can change this to another HF model if needed
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

# Create text-generation pipeline
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # NOTE(review): max_length caps prompt + generated tokens combined, so long
    # chat histories will be cut short — consider max_new_tokens instead; confirm.
    max_length=1000,
    do_sample=True,  # sample instead of greedy decoding (non-deterministic replies)
    truncation=True,  # Explicit truncation to avoid HF warnings
    pad_token_id=tokenizer.eos_token_id  # Prevents warning for open-end generation
)
|
23 |
|
24 |
+
# Wrap with LangChain LLM wrapper so the HF pipeline can be used in an LLMChain
# NOTE(review): langchain_community's HuggingFacePipeline is deprecated in favor
# of the langchain-huggingface package — confirm against installed versions.
llm = HuggingFacePipeline(pipeline=pipe)
|
26 |
|
27 |
# Prompt Template
|
|
|
35 |
template=template
|
36 |
)
|
37 |
|
38 |
+
# Conversation memory (stores past messages)
# memory_key="chat_history" — presumably matches a {chat_history} placeholder in
# the prompt template; verify. return_messages=True yields Message objects
# rather than a single flattened string.
memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
|
40 |
|
41 |
+
# LangChain LLM Chain
|
42 |
llm_chain = LLMChain(
|
43 |
llm=llm,
|
44 |
prompt=prompt,
|
|
|
58 |
title="AI Chatbot",
|
59 |
description="A simple chatbot using LangChain + HuggingFace + Gradio",
|
60 |
theme="default",
|
61 |
+
type="chat" # Uses newer format to avoid Gradio tuple warnings
|
62 |
)
|
63 |
|
64 |
+
# Launch the app only when run as a script (not when imported)
if __name__ == "__main__":
    # queue() serializes concurrent requests; share=True opens a public
    # gradio.live tunnel — disable for private/local-only deployments.
    demo.queue().launch(share=True)
|