Rohit1412 committed
Commit d5ba7eb · verified · 1 parent: 4d1faa0

Update app.py

Files changed (1)
  1. app.py +49 -32
app.py CHANGED
@@ -1,7 +1,9 @@
 import gradio as gr
 import torch
 from sentence_transformers import SentenceTransformer, util
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+from transformers import AutoTokenizer, AutoModelForCausalLM
+from langchain.chains import LLMChain
+from langchain.prompts import PromptTemplate
 import PyPDF2
 import os
 import time
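Note: the new `langchain` imports only resolve if the package is installed in the Space. A sketch of the requirements.txt this app would need; the file is not part of this commit and its contents here are an assumption:

```
# requirements.txt (illustrative; versions left unpinned on purpose)
gradio
torch
sentence-transformers
transformers
langchain
PyPDF2
```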
@@ -13,12 +15,21 @@ logger = logging.getLogger(__name__)
 
 # Load models
 retriever_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
-gen_tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
-gen_model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
+gen_tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-1")
+gen_model = AutoModelForCausalLM.from_pretrained("microsoft/phi-1", torch_dtype=torch.float16)
 
 # Cache for document embeddings
 embedding_cache = {}
 
+# LangChain wrapper for Phi-1
+class Phi1LLM:
+    def __call__(self, prompt, **kwargs):
+        inputs = gen_tokenizer(prompt, return_tensors="pt")
+        outputs = gen_model.generate(**inputs, max_new_tokens=150, num_beams=2)
+        return gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+phi1_llm = Phi1LLM()
+
 def extract_text_from_pdf(pdf_file):
     """Extract text from a PDF file, returning a list of page texts."""
     pages = []
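Note on the hunk above: `LLMChain(llm=...)` in classic LangChain expects an object implementing its `LLM` interface, so passing the bare `Phi1LLM` callable may fail validation. Below is a minimal sketch of an adapter, assuming an older langchain release that still exposes `langchain.llms.base.LLM`; the class name `Phi1LangChainLLM` is hypothetical and not in the commit. Separately, `torch_dtype=torch.float16` can be slow or unsupported for some ops on a CPU-only Space, so `torch.float32` is usually the safer CPU default.

```python
# Sketch only: adapt the generate() call to LangChain's LLM interface.
# Assumes an older langchain release that provides langchain.llms.base.LLM.
from langchain.llms.base import LLM

class Phi1LangChainLLM(LLM):  # hypothetical name, not part of the commit
    @property
    def _llm_type(self) -> str:
        return "phi-1"

    def _call(self, prompt: str, stop=None, **kwargs) -> str:
        inputs = gen_tokenizer(prompt, return_tensors="pt")
        outputs = gen_model.generate(**inputs, max_new_tokens=150, num_beams=2)
        # Note: for causal LMs, outputs[0] includes the prompt tokens; slicing them off
        # (outputs[0][inputs["input_ids"].shape[1]:]) would return only the new text.
        return gen_tokenizer.decode(outputs[0], skip_special_tokens=True)

phi1_llm = Phi1LangChainLLM()  # can then be passed to LLMChain(llm=phi1_llm, prompt=...)
```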
@@ -34,7 +45,7 @@ def extract_text_from_pdf(pdf_file):
         pages.append(f"Error reading PDF: {str(e)}")
     return pages
 
-def chunk_text(text, chunk_size=1500):
+def chunk_text(text, chunk_size=500):
     """Split text into chunks of approximately chunk_size characters."""
     words = text.split()
     chunks = []
@@ -46,7 +57,7 @@ def chunk_text(text, chunk_size=1500):
             current_chunk = []
             current_length = 0
         current_chunk.append(word)
-        current_length += len(word) + 1  # +1 for space
+        current_length += len(word) + 1
     if current_chunk:
         chunks.append(" ".join(current_chunk))
     return chunks
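The hunk only shows the accumulator bookkeeping; for reference, here is a minimal sketch of the whole chunking routine consistent with the visible lines (an illustration, not the committed file verbatim):

```python
def chunk_text(text, chunk_size=500):
    """Split text into chunks of approximately chunk_size characters."""
    words = text.split()
    chunks = []
    current_chunk = []
    current_length = 0
    for word in words:
        # Start a new chunk once adding the next word would exceed the budget.
        if current_chunk and current_length + len(word) + 1 > chunk_size:
            chunks.append(" ".join(current_chunk))
            current_chunk = []
            current_length = 0
        current_chunk.append(word)
        current_length += len(word) + 1  # +1 for the joining space
    if current_chunk:
        chunks.append(" ".join(current_chunk))
    return chunks
```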
@@ -64,7 +75,7 @@ def get_document_embeddings(documents):
     return torch.stack(embeddings)
 
 def rag_pipeline(question, pdf_files):
-    """Optimized RAG pipeline with improved prompting and fallback."""
+    """RAG pipeline with multi-step thinking using Phi-1 and LangChain."""
     start_time = time.time()
     documents = []
 
@@ -95,45 +106,51 @@ def rag_pipeline(question, pdf_files):
 
     # Retrieve top 3 chunks using cosine similarity
     cos_scores = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
-    top_results = torch.topk(cos_scores, k=min(5, len(documents)))
+    top_results = torch.topk(cos_scores, k=min(3, len(documents)))
     retrieved_context = ""
     for score, idx in zip(top_results.values, top_results.indices):
         retrieved_context += f"- {documents[idx]} (score: {score:.2f})\n"
 
-    # Log retrieved context for debugging
     logger.info(f"Retrieved context:\n{retrieved_context}")
 
-    # Improved prompt with fallback
-    if retrieved_context.strip():
-        prompt = (
-            f"Based on the following context, provide a concise and accurate answer to the question.\n\n"
-            f"Context:\n{retrieved_context}\n\n"
-            f"Question: {question}\n\n"
-            f"Answer:"
+    # Step 1: Initial Answer
+    initial_prompt = PromptTemplate(
+        input_variables=["context", "question"],
+        template=(
+            "Using the following context, provide a concise answer to the question:\n\n"
+            "Context:\n{context}\n\n"
+            "Question: {question}\n\n"
+            "Answer:"
         )
-    else:
-        prompt = (
-            f"No relevant context found. Provide a general answer to the question based on your knowledge.\n\n"
-            f"Question: {question}\n\n"
-            f"Answer:"
+    )
+    initial_chain = LLMChain(llm=phi1_llm, prompt=initial_prompt)
+    initial_answer = initial_chain.run(context=retrieved_context, question=question)
+
+    # Step 2: Refine Answer
+    refine_prompt = PromptTemplate(
+        input_variables=["context", "question", "initial_answer"],
+        template=(
+            "Given the context and initial answer, refine and improve the response to the question:\n\n"
+            "Context:\n{context}\n\n"
+            "Question: {question}\n\n"
+            "Initial Answer: {initial_answer}\n\n"
+            "Refined Answer:"
        )
+    )
+    refine_chain = LLMChain(llm=phi1_llm, prompt=refine_prompt)
+    refined_answer = refine_chain.run(context=retrieved_context, question=question, initial_answer=initial_answer)
 
-    # Generate answer with more tokens
-    inputs = gen_tokenizer(prompt, return_tensors="pt")
-    outputs = gen_model.generate(**inputs, max_new_tokens=1500, num_beams=2)
-    answer = gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-    # Log processing time
+    logger.info(f"Initial answer: {initial_answer}")
+    logger.info(f"Refined answer: {refined_answer}")
     logger.info(f"Processing time: {time.time() - start_time:.2f} seconds")
-    return answer if answer else "Unable to generate a meaningful response."
+    return refined_answer if refined_answer else "Unable to generate a meaningful response."
 
 # Gradio UI
 with gr.Blocks() as demo:
-    gr.Markdown("# Improved Lightweight Local RAG Pipeline with PDF Input")
+    gr.Markdown("# RAG Pipeline with microsoft/phi-1 and Multi-Step Thinking")
     gr.Markdown(
-        "Upload one or more PDF files (or leave blank for default AI/Data Science documents), enter your question, "
-        "and get an answer generated using an optimized retrieval step (all-MiniLM-L6-v2) and a small "
-        "generator model (flan-t5-small). Designed for 2 vCPUs and 16GB RAM."
+        "Upload PDFs (or use default AI/Data Science docs), ask a question, "
+        "and get refined answers using Phi-1 with multi-step reasoning on 2 vCPUs and 16GB RAM."
     )
     with gr.Row():
         with gr.Column():
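For context on the retrieval change in the hunk above (top-k lowered from 5 to 3), here is a self-contained sketch of the scoring step; the toy documents and question are placeholders:

```python
import torch
from sentence_transformers import SentenceTransformer, util

retriever_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
documents = ["chunk about embeddings", "chunk about PDF parsing",
             "chunk about Gradio", "chunk about logging"]

query_embedding = retriever_model.encode("How are documents retrieved?", convert_to_tensor=True)
doc_embeddings = retriever_model.encode(documents, convert_to_tensor=True)

# Cosine similarity between the query and every chunk, then keep the best 3.
cos_scores = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
top_results = torch.topk(cos_scores, k=min(3, len(documents)))
for score, idx in zip(top_results.values, top_results.indices):
    print(f"- {documents[idx]} (score: {score:.2f})")
```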
@@ -145,4 +162,4 @@ with gr.Blocks() as demo:
 
     submit_button.click(fn=rag_pipeline, inputs=[question_input, pdf_input], outputs=response_output)
 
-    demo.launch()
+    demo.launch(share=True, debug=True)
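One small note on the last hunk: on Hugging Face Spaces the app is served by the platform, so `share=True` is typically ignored there and mainly matters for local runs. A guarded launch is a common pattern (a sketch, not part of the commit):

```python
if __name__ == "__main__":
    demo.launch(debug=True)  # share=True only has an effect outside Spaces
```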