santhoshraghu committed on
Commit
f46044d
·
verified ·
1 Parent(s): b0d4516

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +40 -15
app.py CHANGED
@@ -24,14 +24,9 @@ from langchain_openai import OpenAIEmbeddings, ChatOpenAI
24
  from langchain_community.vectorstores import Qdrant
25
  from langchain_community.embeddings import HuggingFaceEmbeddings
26
  from langchain_community.embeddings import SentenceTransformerEmbeddings
27
-
28
 
29
  torch.cuda.empty_cache()
30
-
31
-
32
-
33
-
34
- import nest_asyncio
35
  nest_asyncio.apply()
36
  co = cohere.Client(st.secrets["COHERE_API_KEY"])
37
 
@@ -102,17 +97,40 @@ retriever = vector_store.as_retriever()
102
 
103
 
104
 
105
- # Dynamically initialize LLM based on selection
106
- OPENAI_API_KEY = st.secrets["OPENAI_API_KEY"]
107
- selected_model = st.session_state["selected_model"]
108
  if "OpenAI" in selected_model:
109
- llm = ChatOpenAI(model="gpt-4o", temperature=0.2, api_key=OPENAI_API_KEY)
 
 
110
  elif "LLaMA" in selected_model:
111
- st.warning("LLaMA integration is not implemented yet.")
112
- st.stop()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
113
  elif "Gemini" in selected_model:
114
- st.warning("Gemini integration is not implemented yet.")
115
- st.stop()
 
 
 
 
 
 
 
 
116
  else:
117
  st.error("Unsupported model selected.")
118
  st.stop()
@@ -290,7 +308,14 @@ def get_reranked_response(query: str):
290
  reranked_docs = rerank_with_cohere(query, docs)
291
  context = "\n\n".join([doc.page_content for doc in reranked_docs])
292
  prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
293
- return llm.invoke([{"role": "system", "content": prompt}])
 
 
 
 
 
 
 
294
 
295
  # === App UI ===
296
 
 
24
# --- Vector-store and embedding imports (third-party) ---
from langchain_community.vectorstores import Qdrant
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.embeddings import SentenceTransformerEmbeddings
import nest_asyncio

# Release any cached GPU memory before model loading.
# NOTE(review): assumes a CUDA-capable torch build — confirm this does not
# raise on CPU-only deployments.
torch.cuda.empty_cache()

# Streamlit already drives an event loop; nest_asyncio lets asyncio-based
# libraries run nested loops inside it.
nest_asyncio.apply()

# Cohere client used for reranking; the key lives in .streamlit/secrets.toml.
co = cohere.Client(st.secrets["COHERE_API_KEY"])
32
 
 
97
 
98
 
99
 
100
# NOTE(review): `selected_model` must already be defined at this point
# (presumably by a Streamlit widget earlier in the file); the session-state
# read below was deliberately commented out in this revision — confirm the
# variable is still bound before this branch runs.
# selected_model = st.session_state["selected_model"]

# Bind `llm` according to the user's model choice. For OpenAI, `llm` is a
# LangChain chat model (use `llm.invoke(...)`); for LLaMA/Gemini it is a
# plain callable taking a prompt string and returning the response text.
if "OpenAI" in selected_model:
    from langchain_openai import ChatOpenAI

    llm = ChatOpenAI(
        model="gpt-4o",
        temperature=0.2,
        api_key=st.secrets["OPENAI_API_KEY"],
    )

elif "LLaMA" in selected_model:
    from groq import Groq

    # Groq-hosted LLaMA. Store in `.streamlit/secrets.toml`
    client = Groq(api_key=st.secrets["GROQ_API_KEY"])

    def get_llama_response(prompt):
        """Send `prompt` to the Groq LLaMA model and return the reply text."""
        chat = client.chat.completions.create(
            model="meta-llama/llama-4-maverick-17b-128e-instruct",
            messages=[{"role": "user", "content": prompt}],
            temperature=1,
            max_completion_tokens=1024,
            top_p=1,
            stream=False,
        )
        return chat.choices[0].message.content

    llm = get_llama_response  # use this in place of llm.invoke()

elif "Gemini" in selected_model:
    import google.generativeai as genai

    # Store in `.streamlit/secrets.toml`
    genai.configure(api_key=st.secrets["GEMINI_API_KEY"])
    gemini_model = genai.GenerativeModel("gemini-2.5-pro-exp-03-25")

    def get_gemini_response(prompt):
        """Send `prompt` to Gemini and return the generated text."""
        reply = gemini_model.generate_content(prompt)
        return reply.text

    llm = get_gemini_response

else:
    # Unknown selection: surface an error and halt the Streamlit script.
    st.error("Unsupported model selected.")
    st.stop()
 
308
  reranked_docs = rerank_with_cohere(query, docs)
309
  context = "\n\n".join([doc.page_content for doc in reranked_docs])
310
  prompt = AI_PROMPT_TEMPLATE.format(question=query, context=context)
311
+
312
+ if callable(llm):
313
+ # Gemini or LLaMA
314
+ return type("Obj", (), {"content": llm(prompt)})
315
+ else:
316
+ # OpenAI LangChain interface
317
+ return llm.invoke([{"role": "system", "content": prompt}])
318
+
319
 
320
  # === App UI ===
321