manik-hossain committed on
Commit c69ad20 · 1 Parent(s): f4d9678
Files changed (2)
  1. app.py +84 -0
  2. requirements.txt +7 -0
app.py ADDED
@@ -0,0 +1,84 @@
+ import streamlit as st
+ import os
+ import redis
+ from langchain.docstore.document import Document
+ from sklearn.datasets import fetch_20newsgroups
+ from langchain_huggingface import HuggingFaceEmbeddings
+ from langchain_redis import RedisConfig, RedisVectorStore
+ from langchain_openai import ChatOpenAI
+ from langchain_core.output_parsers import StrOutputParser
+ from langchain_core.prompts import ChatPromptTemplate
+ from langchain_core.runnables import RunnablePassthrough
+ from dotenv import load_dotenv
+ load_dotenv()
+ # Build the retriever and LLM once and cache them across Streamlit reruns.
+ @st.cache_resource
+ def load():
+     openai_api_key = os.getenv("OPENAI_API_KEY")
+     REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")
+     redis_client = redis.from_url(REDIS_URL)
+
+
+     # Build a small demo corpus from two 20 Newsgroups categories
+     # fetched via scikit-learn.
+     categories = ["alt.atheism", "sci.space"]
+     newsgroups = fetch_20newsgroups(
+         subset="train", categories=categories, shuffle=True, random_state=42
+     )
+
+     # Use only the first 250 documents
+     texts = newsgroups.data[:250]
+     metadata = [
+         {"category": newsgroups.target_names[target]} for target in newsgroups.target[:250]
+     ]
+
+     embeddings = HuggingFaceEmbeddings(model_name="msmarco-distilbert-base-v4")
+
+
+     # Redis index with a tag field so each document carries its category metadata.
+     config = RedisConfig(
+         index_name="newsgroups",
+         redis_url=REDIS_URL,
+         metadata_schema=[
+             {"name": "category", "type": "tag"},
+         ],
+     )
+     # Embed and index the texts, then expose a top-2 similarity retriever.
+     vector_store = RedisVectorStore(embeddings, config=config)
+     ids = vector_store.add_texts(texts, metadata)
+     retriever = vector_store.as_retriever(search_type="similarity", search_kwargs={"k": 2})
+     llm = ChatOpenAI(model="gpt-4o", temperature=0, base_url="https://models.inference.ai.azure.com", api_key=openai_api_key)
+
+     return retriever, llm
+
+ retriever, llm = load()
+
+ # Prompt
+ prompt = ChatPromptTemplate.from_messages(
+     [
+         (
+             "human",
+             """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
+ Question: {question}
+ Context: {context}
+ Answer:""",
+         ),
+     ]
+ )
+
+ # Concatenate retrieved documents into a single context string.
+ def format_docs(docs):
+     return "\n\n".join(doc.page_content for doc in docs)
+
+ # RAG chain: retrieve context, fill the prompt, call the LLM, parse to text.
+ rag_chain = (
+     {"context": retriever | format_docs, "question": RunnablePassthrough()}
+     | prompt
+     | llm
+     | StrOutputParser()
+ )
+
+ if query := st.chat_input("Ask a question"):
+     response = rag_chain.invoke(query)
+     with st.chat_message("assistant"):
+         st.write(response)
requirements.txt ADDED
@@ -0,0 +1,7 @@
+ langchain
+ langchain_openai
+ streamlit
+ langchain-redis
+ langchain-huggingface
+ sentence-transformers
+ scikit-learn
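
Two notes on this commit, plus a quick way to sanity-check the index outside Streamlit. First, app.py also imports redis and python-dotenv; neither is listed in requirements.txt, so they are either pulled in transitively (redis typically comes with langchain-redis) or need to be installed separately. Second, the snippet below is a sketch only, not part of the commit: it assumes the app has already run once against the same REDIS_URL and populated the "newsgroups" index, and it reuses the embedding model and RedisConfig from app.py to print the two nearest documents for a test query (mirroring the retriever's k=2).

# Sketch only: query the "newsgroups" index that app.py builds, outside the Streamlit app.
# Assumes REDIS_URL points at the same Redis instance and the index already exists.
import os

from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_redis import RedisConfig, RedisVectorStore

load_dotenv()
REDIS_URL = os.getenv("REDIS_URL", "redis://localhost:6379")

# Same embedding model and index configuration as app.py.
embeddings = HuggingFaceEmbeddings(model_name="msmarco-distilbert-base-v4")
config = RedisConfig(
    index_name="newsgroups",
    redis_url=REDIS_URL,
    metadata_schema=[{"name": "category", "type": "tag"}],
)
vector_store = RedisVectorStore(embeddings, config=config)

# Standard LangChain vector-store call; k=2 mirrors the retriever in app.py.
for doc in vector_store.similarity_search("What was discussed about space shuttle missions?", k=2):
    print(doc.metadata.get("category"), "->", doc.page_content[:120].replace("\n", " "))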