kurakula-Prashanth2004 committed
Commit a4680e8 · verified · 1 parent: 39d7947

Create app.py

Files changed (1)
  1. app.py +529 -0
app.py ADDED
@@ -0,0 +1,529 @@
import os
import pathlib
from tokenizers.normalizers import BertNormalizer
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_core.prompts import PromptTemplate
from langchain_core.documents import Document
import numpy as np
import pandas as pd
from sklearn.metrics.pairwise import cosine_similarity
from transformers import AutoTokenizer, AutoModel
import ast
import torch
from typing import Dict, Any, List
from bert_score import score as bert_score
from rouge_score import rouge_scorer
import warnings
import streamlit as st
import plotly.graph_objects as go
import plotly.express as px

# Set page config to wide layout at the start
st.set_page_config(
    layout="wide",
    page_title="Alloy Based Chatbot",
    page_icon="🔍"
)

warnings.filterwarnings('ignore')

# Set up Google API key
os.environ["GOOGLE_API_KEY"] = st.secrets["google"]["GOOGLE_API_KEY"]

# Initialize session state
if 'page' not in st.session_state:
    st.session_state.page = 'home'
if 'question' not in st.session_state:
    st.session_state.question = ''
if 'results' not in st.session_state:
    st.session_state.results = None
if 'selected_context' not in st.session_state:
    st.session_state.selected_context = None

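# vocab_mappings.txt is assumed to hold one mapping per line: a single source character,
# a separator, then its replacement string (parsed below as m[0] -> m[2:]). The resulting
# dict is used by normalize() to remap characters before MatSciBERT tokenization.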
file_path = "vocab_mappings.txt"
with open(file_path, 'r', encoding='utf-8') as f:
    mappings = f.read().strip().split('\n')

mappings = {m[0]: m[2:] for m in mappings}

norm = BertNormalizer(lowercase=False, strip_accents=True, clean_text=True, handle_chinese_chars=True)

def normalize(text):
    text = [norm.normalize_str(s) for s in text.split('\n')]
    out = []
    for s in text:
        norm_s = ''
        for c in s:
            norm_s += mappings.get(c, ' ')
        out.append(norm_s)
    return '\n'.join(out)

# Define the prompt template
template = """
You are an intelligent assistant designed to provide accurate and helpful answers based on the context provided. Follow these guidelines:
1. Use only the information from the context to answer the question.
2. If the context does not contain enough information to answer the question, say "I don't know" and do not make up an answer.
3. Be concise and specific in your response.
4. Always end your answer with "Thanks for asking!" to maintain a friendly tone.

Context: {context}

Question: {question}

Answer:
"""
custom_rag_prompt = PromptTemplate.from_template(template)

# Initialize model
model = ChatGoogleGenerativeAI(model="gemini-2.0-flash", temperature=0.5)

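# Minimal state container for one question/answer pass: the user question, the retrieved
# context documents, and the generated answer.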
class State:
    def __init__(self, question: str):
        self.question = question
        self.context: List[Document] = []
        self.answer: str = ""

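# The embeddings CSVs are expected to provide at least a 'document' column (chunk text)
# and an 'embedding' column storing each vector as a stringified Python list, e.g.
# "[0.0123, -0.0456, ...]" (illustrative values only); it is parsed back into a NumPy array here.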
def load_embeddings_from_csv(file_path: str):
    print(f"Loading embeddings from CSV file: {file_path}")
    df = pd.read_csv(file_path)
    df['embedding'] = df['embedding'].apply(lambda x: np.array(ast.literal_eval(x)))
    print("Embeddings loaded successfully.")
    return df

def generate_query_embedding(query_text: str, model_name: str):
    print(f"Generating query embedding using {model_name}...")
    if model_name == "matscibert":
        return generate_matscibert_embedding(query_text)
    elif model_name == "bert":
        return generate_bert_embedding(query_text)
    else:
        raise ValueError(f"Unknown model: {model_name}")

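# Both embedding functions mean-pool the model's last hidden state over tokens to obtain a
# single sentence vector. Note that the tokenizer and model are loaded on every call; they
# could be cached (e.g. with st.cache_resource) to avoid repeated loading.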
def generate_matscibert_embedding(query_text: str):
    print("Generating Matscibert embedding...")
    tokenizer = AutoTokenizer.from_pretrained('m3rg-iitd/matscibert')
    model = AutoModel.from_pretrained('m3rg-iitd/matscibert')

    norm_sents = [normalize(query_text)]
    tokenized_sents = tokenizer(norm_sents, padding=True, truncation=True, return_tensors='pt')

    with torch.no_grad():
        last_hidden_state = model(**tokenized_sents).last_hidden_state

    sentence_embedding = last_hidden_state.mean(dim=1).squeeze().numpy()
    print("Matscibert embedding generated.")
    return sentence_embedding

def generate_bert_embedding(query_text: str):
    print("Generating BERT embedding...")
    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained("bert-base-uncased")

    encoded_input = tokenizer(query_text, return_tensors='pt', truncation=True, padding=True)
    with torch.no_grad():
        output = model(**encoded_input)

    sentence_embedding = output.last_hidden_state.mean(dim=1).squeeze().numpy()
    print("BERT embedding generated.")
    return sentence_embedding

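# Retrieval: embed the query with the selected model, rank every document embedding by
# cosine similarity to the query, and keep the top 3 documents as context.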
def retrieve(state: State, embeddings_df: pd.DataFrame, model_name: str):
    print("Retrieving relevant documents...")
    query_embedding = generate_query_embedding(state.question, model_name)
    document_embeddings = np.array(embeddings_df['embedding'].tolist())
    similarities = cosine_similarity([query_embedding], document_embeddings)
    top_indices = similarities.argsort()[0][::-1]
    state.context = [Document(page_content=embeddings_df.iloc[i]['document']) for i in top_indices[:3]]
    print("Documents retrieved.")
    return state

def generate(state: State):
    print("Generating answer based on context and question...")
    docs_content = "\n\n".join(doc.page_content for doc in state.context)
    messages = custom_rag_prompt.invoke({"question": state.question, "context": docs_content})
    response = model.invoke(messages)
    state.answer = response.content
    print("Answer generated.")
    return state

def workflow(state_input: Dict[str, Any], embeddings_df: pd.DataFrame, model_name: str) -> Dict[str, Any]:
    print(f"Running workflow for question: {state_input['question']} with model: {model_name}")
    state = State(state_input["question"])
    state = retrieve(state, embeddings_df, model_name)
    state = generate(state)
    print(f"Workflow complete for question: {state_input['question']}.")
    return {"context": state.context, "answer": state.answer}

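# Evaluation: BERTScore and ROUGE are computed between the generated answer and the
# retrieved context, i.e. the context serves as the reference text.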
def compute_bertscore(answer: str, context: str) -> Dict[str, float]:
    P, R, F1 = bert_score([answer], [context], lang="en")
    return {
        "BERTScore Precision": P.mean().item(),
        "BERTScore Recall": R.mean().item(),
        "BERTScore F1": F1.mean().item()
    }

def compute_rouge(answer: str, context: str) -> Dict[str, float]:
    scorer = rouge_scorer.RougeScorer(['rouge1', 'rouge2', 'rougeL'])
    scores = scorer.score(context, answer)
    return {
        "ROUGE-1": scores["rouge1"].fmeasure,
        "ROUGE-2": scores["rouge2"].fmeasure,
        "ROUGE-L": scores["rougeL"].fmeasure
    }

def evaluate_answer(answer: str, context: str) -> Dict[str, Dict[str, float]]:
    return {
        "BERTScore": compute_bertscore(answer, context),
        "ROUGE": compute_rouge(answer, context)
    }

@st.cache_resource
def load_data():
    matscibert_csv = 'matscibert_embeddings.csv'
    bert_csv = 'bert_embeddings.csv'
    embeddings_df_matscibert = load_embeddings_from_csv(matscibert_csv)
    embeddings_df_bert = load_embeddings_from_csv(bert_csv)
    return embeddings_df_matscibert, embeddings_df_bert

embeddings_df_matscibert, embeddings_df_bert = load_data()

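# Run the full retrieve-and-generate pipeline once per embedding model and collect the
# context, answer, and evaluation scores for a side-by-side comparison.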
def ask_question(question: str):
    print(f"Asking question: {question}")
    matscibert_result = workflow({"question": question}, embeddings_df_matscibert, model_name="matscibert")
    bert_result = workflow({"question": question}, embeddings_df_bert, model_name="bert")

    matscibert_context = "\n\n".join(doc.page_content for doc in matscibert_result["context"])
    matscibert_answer = matscibert_result["answer"]
    matscibert_scores = evaluate_answer(matscibert_answer, matscibert_context)

    bert_context = "\n\n".join(doc.page_content for doc in bert_result["context"])
    bert_answer = bert_result["answer"]
    bert_scores = evaluate_answer(bert_answer, bert_context)

    return {
        "matscibert": {
            "Context": matscibert_context,
            "Answer": matscibert_answer,
            "Scores": matscibert_scores
        },
        "bert": {
            "Context": bert_context,
            "Answer": bert_answer,
            "Scores": bert_scores
        }
    }

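# Plotly helpers: per-model bar charts for the BERTScore and ROUGE metrics, plus a grouped
# bar chart comparing both models across all six metrics.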
def create_bertscore_chart(scores: Dict[str, float]):
    metrics = ['Precision', 'Recall', 'F1']
    values = [scores['BERTScore Precision'], scores['BERTScore Recall'], scores['BERTScore F1']]

    fig = go.Figure(data=[
        go.Bar(
            x=metrics,
            y=values,
            marker_color=['#4285F4', '#34A853', '#FBBC05'],
            text=[f"{v:.4f}" for v in values],
            textposition='auto'
        )
    ])

    fig.update_layout(
        title='BERTScore Metrics',
        yaxis=dict(range=[0, 1]),
        height=400
    )

    return fig

def create_rouge_chart(scores: Dict[str, float]):
    metrics = ['ROUGE-1', 'ROUGE-2', 'ROUGE-L']
    values = [scores['ROUGE-1'], scores['ROUGE-2'], scores['ROUGE-L']]

    fig = go.Figure(data=[
        go.Bar(
            x=metrics,
            y=values,
            marker_color=['#EA4335', '#34A853', '#FBBC05'],
            text=[f"{v:.4f}" for v in values],
            textposition='auto'
        )
    ])

    fig.update_layout(
        title='ROUGE Metrics',
        yaxis=dict(range=[0, 1]),
        height=400
    )

    return fig

def create_comparison_chart(matscibert_scores: Dict[str, Dict[str, float]], bert_scores: Dict[str, Dict[str, float]]):
    metrics = ['Precision', 'Recall', 'F1', 'ROUGE-1', 'ROUGE-2', 'ROUGE-L']
    matscibert_values = [
        matscibert_scores['BERTScore']['BERTScore Precision'],
        matscibert_scores['BERTScore']['BERTScore Recall'],
        matscibert_scores['BERTScore']['BERTScore F1'],
        matscibert_scores['ROUGE']['ROUGE-1'],
        matscibert_scores['ROUGE']['ROUGE-2'],
        matscibert_scores['ROUGE']['ROUGE-L']
    ]

    bert_values = [
        bert_scores['BERTScore']['BERTScore Precision'],
        bert_scores['BERTScore']['BERTScore Recall'],
        bert_scores['BERTScore']['BERTScore F1'],
        bert_scores['ROUGE']['ROUGE-1'],
        bert_scores['ROUGE']['ROUGE-2'],
        bert_scores['ROUGE']['ROUGE-L']
    ]

    fig = go.Figure()

    fig.add_trace(go.Bar(
        x=metrics,
        y=matscibert_values,
        name='Matscibert',
        marker_color='#4285F4'
    ))

    fig.add_trace(go.Bar(
        x=metrics,
        y=bert_values,
        name='BERT',
        marker_color='#EA4335'
    ))

    fig.update_layout(
        title='Model Comparison',
        barmode='group',
        height=500
    )

    return fig

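# Streamlit pages: 'home' (search box), 'results' (answers and metric charts),
# 'context_choice' and 'context_view' (inspect the retrieved context). Navigation is
# driven by st.session_state.page.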
def home_page():
    # CSS to center content vertically from middle to bottom
    st.markdown("""
    <style>
    .main .block-container {
        padding-top: 0;
        display: flex;
        flex-direction: column;
        justify-content: center;
        min-height: 70vh;
    }
    @media (max-height: 700px) {
        .main .block-container {
            min-height: 80vh;
        }
    }
    </style>
    """, unsafe_allow_html=True)

    # Centered heading
    st.markdown("""
    <div style='text-align: center; margin-bottom: 1rem;'>
        <h1>Welcome to the Alloy Based Chatbot</h1>
    </div>
    """, unsafe_allow_html=True)

    # Search components - centered in the middle of available space
    col1, col2, col3 = st.columns([1, 2, 1])
    with col2:
        user_input = st.text_area(
            "Enter your question about alloys:",
            key="user_input",
            value=st.session_state.question,
            height=100,
            label_visibility="collapsed",
            placeholder="Ask your question here"
        )

        submit_button = st.button(
            "Search",
            key="search_button",
            use_container_width=True
        )

        if submit_button and user_input:
            st.session_state.question = user_input
            st.session_state.results = ask_question(user_input)
            st.session_state.page = 'results'
            st.rerun()

def results_page():
    st.title("Search Results")

    if st.session_state.results:
        results = st.session_state.results

        # First show answers in columns
        st.subheader("Model Answers")
        col1, col2 = st.columns(2)

        with col1:
            with st.container(border=True):
                st.markdown("### Matscibert Answer")
                st.write(results["matscibert"]["Answer"])

        with col2:
            with st.container(border=True):
                st.markdown("### BERT Answer")
                st.write(results["bert"]["Answer"])

        # Then show the comparison chart
        st.subheader("Model Performance Comparison")
        st.plotly_chart(
            create_comparison_chart(results["matscibert"]["Scores"], results["bert"]["Scores"]),
            use_container_width=True
        )

        # Detailed metrics in tabs
        st.subheader("Detailed Metrics")
        tab1, tab2 = st.tabs(["Matscibert Metrics", "BERT Metrics"])

        with tab1:
            col1, col2 = st.columns(2)
            with col1:
                st.plotly_chart(
                    create_bertscore_chart(results["matscibert"]["Scores"]["BERTScore"]),
                    use_container_width=True
                )
            with col2:
                st.plotly_chart(
                    create_rouge_chart(results["matscibert"]["Scores"]["ROUGE"]),
                    use_container_width=True
                )

        with tab2:
            col1, col2 = st.columns(2)
            with col1:
                st.plotly_chart(
                    create_bertscore_chart(results["bert"]["Scores"]["BERTScore"]),
                    use_container_width=True
                )
            with col2:
                st.plotly_chart(
                    create_rouge_chart(results["bert"]["Scores"]["ROUGE"]),
                    use_container_width=True
                )

    # Navigation buttons at the bottom
    st.markdown("---")
    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("Start New Search", use_container_width=True):
            st.session_state.page = 'home'
            st.session_state.question = ''
            st.rerun()
    with col2:
        if st.button("View Context", use_container_width=True):
            st.session_state.page = 'context_choice'
            st.rerun()

def context_choice_page():
    st.title("Select Context to View")

    st.write("Choose which model's context you'd like to examine:")

    col1, col2 = st.columns(2)
    with col1:
        if st.button("View Matscibert Context", use_container_width=True):
            st.session_state.selected_context = "matscibert"
            st.session_state.page = 'context_view'
            st.rerun()
    with col2:
        if st.button("View BERT Context", use_container_width=True):
            st.session_state.selected_context = "bert"
            st.session_state.page = 'context_view'
            st.rerun()

    st.markdown("---")
    if st.button("Back to Results", use_container_width=True):
        st.session_state.page = 'results'
        st.rerun()

def context_view_page():
    st.title(f"{st.session_state.selected_context.capitalize()} Context")

    # Context switching buttons at top
    col1, col2 = st.columns(2)
    with col1:
        if st.button("Switch to Matscibert Context",
                     disabled=st.session_state.selected_context == "matscibert",
                     use_container_width=True):
            st.session_state.selected_context = "matscibert"
            st.rerun()
    with col2:
        if st.button("Switch to BERT Context",
                     disabled=st.session_state.selected_context == "bert",
                     use_container_width=True):
            st.session_state.selected_context = "bert"
            st.rerun()

    # Display the context in a scrollable container
    if st.session_state.results and st.session_state.selected_context:
        context = st.session_state.results[st.session_state.selected_context]["Context"]
        with st.container(height=600, border=True):
            st.markdown(f"```\n{context}\n```")

    # Navigation buttons at bottom
    st.markdown("---")
    col1, col2 = st.columns([1, 1])
    with col1:
        if st.button("Back to Results", use_container_width=True):
            st.session_state.page = 'results'
            st.rerun()
    with col2:
        if st.button("New Search", use_container_width=True):
            st.session_state.page = 'home'
            st.session_state.question = ''
            st.rerun()

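# Router: apply global CSS and dispatch to the page currently stored in session state.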
def main():
    # Add some custom CSS
    st.markdown("""
    <style>
    /* Search bar styling */
    .stTextArea textarea {
        min-height: 100px;
        border: none !important;
        box-shadow: none !important;
        padding: 12px !important;
    }
    .stTextArea div[data-baseweb="base-input"] {
        border-radius: 8px !important;
        border: none !important;
        box-shadow: none !important;
        background-color: transparent !important;
    }

    /* Button styling */
    .stButton button {
        width: 100%;
        margin-top: 0.5rem;
    }

    /* Layout adjustments */
    div[data-testid="stHorizontalBlock"] {
        gap: 0.5rem;
    }

    /* Remove extra padding */
    .main .block-container {
        padding-top: 0;
    }
    </style>
    """, unsafe_allow_html=True)

    if st.session_state.page == 'home':
        home_page()
    elif st.session_state.page == 'results':
        results_page()
    elif st.session_state.page == 'context_choice':
        context_choice_page()
    elif st.session_state.page == 'context_view':
        context_view_page()

if __name__ == "__main__":
    main()