Update app.py
app.py CHANGED
@@ -24,14 +24,9 @@ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
 from langchain_community.vectorstores import Qdrant
 from langchain_community.embeddings import HuggingFaceEmbeddings
 from langchain_community.embeddings import SentenceTransformerEmbeddings
-
+import nest_asyncio

 torch.cuda.empty_cache()
-
-
-
-
-import nest_asyncio
 nest_asyncio.apply()
 co = cohere.Client(st.secrets["COHERE_API_KEY"])

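Note: this hunk moves `import nest_asyncio` up with the other imports, so the loop patch is applied once before the Cohere client is created. A standalone sketch of what `nest_asyncio.apply()` buys (only the nest_asyncio package is assumed; none of these names are in the commit):

import asyncio
import nest_asyncio

nest_asyncio.apply()  # patch the current loop so it can be re-entered

async def inner():
    return "ok"

async def outer():
    # Without the patch, this nested call raises
    # "RuntimeError: asyncio.run() cannot be called from a running event loop".
    return asyncio.run(inner())

print(asyncio.run(outer()))  # -> ok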
@@ -102,17 +97,40 @@ retriever = vector_store.as_retriever()



-#
-
-selected_model = st.session_state["selected_model"]
+#selected_model = st.session_state["selected_model"]
+
 if "OpenAI" in selected_model:
-
+    from langchain_openai import ChatOpenAI
+    llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=st.secrets["OPENAI_API_KEY"])
+
 elif "LLaMA" in selected_model:
-
-
+    from groq import Groq
+
+    client = Groq(api_key=st.secrets["GROQ_API_KEY"])  # Store in `.streamlit/secrets.toml`
+    def get_llama_response(prompt):
+        completion = client.chat.completions.create(
+            model="meta-llama/llama-4-maverick-17b-128e-instruct",
+            messages=[{"role": "user", "content": prompt}],
+            temperature=1,
+            max_completion_tokens=1024,
+            top_p=1,
+            stream=False
+        )
+        return completion.choices[0].message.content
+
+    llm = get_llama_response  # use this in place of llm.invoke()
+
 elif "Gemini" in selected_model:
-
-st.
+    import google.generativeai as genai
+    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])  # Store in `.streamlit/secrets.toml`
+
+    gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")
+    def get_gemini_response(prompt):
+        response = gemini_model.generate_content(prompt)
+        return response.text
+
+    llm = get_gemini_response
+
 else:
     st.error("Unsupported model selected.")
     st.stop()
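Note: after this hunk, `llm` is either a LangChain `ChatOpenAI` instance or a plain Python function (`get_llama_response` / `get_gemini_response`). A hedged sketch of a single call helper that tolerates both shapes; the name `complete` and the `hasattr` check are assumptions, not part of the commit:

def complete(llm, prompt: str) -> str:
    # LangChain chat models expose .invoke(); the bare wrapper functions do not.
    # Checking for .invoke is safer than callable(), since LangChain chat
    # models can themselves be callable.
    if hasattr(llm, "invoke"):
        return llm.invoke(prompt).content
    return llm(prompt)

The branch reads four keys from st.secrets (OPENAI_API_KEY, GROQ_API_KEY, GEMINI_API_KEY, plus COHERE_API_KEY from the first hunk); each needs a matching entry in `.streamlit/secrets.toml`.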
@@ -290,7 +308,14 @@ def get_reranked_response(query: str):
     reranked_docs = rerank_with_cohere(query, docs)
     context = "\n\n".join([doc.page_content for doc in reranked_docs])
     prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
-
+
+    if callable(llm):
+        # Gemini or LLaMA
+        return type("Obj", (), {"content": llm(prompt)})
+    else:
+        # OpenAI LangChain interface
+        return llm.invoke([{"role": "system", "content": prompt}])
+

 # === App UI ===

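Note: `type("Obj", (), {"content": llm(prompt)})` builds a throwaway class whose class attribute `content` carries the generated text, so the caller can read `.content` exactly as it would on the `AIMessage` that `llm.invoke()` returns. `types.SimpleNamespace` is the more conventional spelling of the same duck-typing (a sketch, not what the commit does):

from types import SimpleNamespace

response = SimpleNamespace(content=llm(prompt))  # same .content access, as an instance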
|
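Either branch of `get_reranked_response` therefore hands back an object with a `.content` attribute, so UI code downstream can stay backend-agnostic (a usage sketch; the query string is invented):

answer = get_reranked_response("What are the key findings?")
st.write(answer.content)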