File size: 2,910 Bytes
c69ad20
 
7375c28
c69ad20
 
 
 
 
 
 
 
c1d6585
 
c69ad20
7375c28
 
 
c69ad20
 
 
7375c28
448555e
 
552eb94
7375c28
 
 
 
 
 
 
c69ad20
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import hashlib
import os
from urllib.parse import quote

import streamlit as st
# import redis
from langchain.docstore.document import Document
from sklearn.datasets import fetch_20newsgroups
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_redis import RedisConfig, RedisVectorStore
from langchain_openai import ChatOpenAI
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
# from dotenv import load_dotenv
# load_dotenv()

# Redis password is read from the REDIS_PASS environment variable.
redis_pass = os.getenv("REDIS_PASS")
# Percent-encode the password before embedding it in the URL: reserved
# characters such as "@", ":" or "/" would otherwise change how the URL is
# split into credentials/host/port.  An unset variable yields an empty
# password rather than the literal text "None".
url = (
    f"redis://default:{quote(redis_pass or '', safe='')}"
    "@redis-14461.c264.ap-south-1-1.ec2.redns.redis-cloud.com:14461"
)
@st.cache_resource
def load():
    """Build the retriever and chat model used by the RAG chain.

    Fetches the "alt.atheism" and "sci.space" training posts of the
    20 Newsgroups dataset, takes the first 250, embeds them with the
    ``msmarco-distilbert-base-v4`` model and writes them to the
    ``newsgroups`` Redis index together with their category as a tag.

    Each document is stored under a key derived from its content, so running
    this function again (for example after an app restart, when the Streamlit
    resource cache is empty) overwrites the same entries instead of adding
    another copy of every document.
    NOTE(review): entries written by earlier runs under auto-generated keys
    are not removed by this change - clean the index once if needed.

    Returns:
        tuple: ``(retriever, llm)`` - a similarity retriever over the Redis
        vector store returning the top 2 matches, and the ``gpt-4o`` chat
        model client configured from ``OPENAI_API_KEY``.
    """
    openai_api_key = os.getenv("OPENAI_API_KEY")

    categories = ["alt.atheism", "sci.space"]
    newsgroups = fetch_20newsgroups(
        subset="train", categories=categories, shuffle=True, random_state=42
    )

    # Use only the first 250 documents
    texts = newsgroups.data[:250]
    metadata = [
        {"category": newsgroups.target_names[target]}
        for target in newsgroups.target[:250]
    ]
    # Stable, content-derived keys make the insert below idempotent.
    keys = [hashlib.sha256(text.encode("utf-8")).hexdigest() for text in texts]

    embeddings = HuggingFaceEmbeddings(model_name="msmarco-distilbert-base-v4")

    config = RedisConfig(
        index_name="newsgroups",
        redis_url=url,
        metadata_schema=[
            {"name": "category", "type": "tag"},
        ],
    )

    vector_store = RedisVectorStore(embeddings, config=config)
    vector_store.add_texts(texts, metadata, keys=keys)
    retriever = vector_store.as_retriever(
        search_type="similarity", search_kwargs={"k": 2}
    )
    llm = ChatOpenAI(
        model="gpt-4o",
        temperature=0,
        base_url="https://models.inference.ai.azure.com",
        api_key=openai_api_key,
    )

    return retriever, llm

# Obtain the retriever and chat model from the cached loader.
retriever, llm = load()

# Prompt: a single human message with {question} and {context} placeholders.
_RAG_TEMPLATE = """You are an assistant for question-answering tasks. Use the following pieces of retrieved context to answer the question. If you don't know the answer, just say that you don't know. Use three sentences maximum and keep the answer concise.
Question: {question} 
Context: {context} 
Answer:"""

prompt = ChatPromptTemplate.from_messages([("human", _RAG_TEMPLATE)])


def format_docs(docs):
    """Return the ``page_content`` of every item in *docs*, joined by blank lines."""
    separator = "\n\n"
    contents = [document.page_content for document in docs]
    return separator.join(contents)


# The incoming question is fed both to the retriever (whose documents are
# flattened by format_docs into the "context" slot) and, unchanged, into the
# "question" slot; the filled prompt then goes to the model and the output is
# parsed to a string.
chain_inputs = {
    "context": retriever | format_docs,
    "question": RunnablePassthrough(),
}
rag_chain = chain_inputs | prompt | llm | StrOutputParser()

# Run the chain only when the user has submitted a non-empty question.
query = st.chat_input("Ask a question")
if query:
    response = rag_chain.invoke(query)
    with st.chat_message("assistant"):
        st.write(response)