himel06 committed
Commit 3fb63a8 · verified · 1 Parent(s): ccea2c8

Update BanglaRAG/bangla_rag_pipeline.py

Files changed (1):
  1. BanglaRAG/bangla_rag_pipeline.py +20 -98
BanglaRAG/bangla_rag_pipeline.py CHANGED
@@ -4,7 +4,6 @@ from transformers import (
     AutoTokenizer,
     AutoModelForCausalLM,
     pipeline,
-    GenerationConfig,
     BitsAndBytesConfig,
 )
 from langchain_core.prompts import PromptTemplate
@@ -14,25 +13,12 @@ from langchain_community.vectorstores import Chroma
 from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline
 from langchain_core.runnables import RunnableParallel, RunnablePassthrough
 from langchain_core.output_parsers import StrOutputParser
-from langchain_core.prompts import ChatPromptTemplate, HumanMessagePromptTemplate
-from rich import print as rprint
-from rich.panel import Panel
-from tqdm import tqdm
 import warnings
-import re
 
 warnings.filterwarnings("ignore")
 
 class BanglaRAGChain:
-    """
-    Bangla Retrieval-Augmented Generation (RAG) Chain for question answering.
-    This class uses a HuggingFace/local language model for text generation, a Chroma vector database for
-    document retrieval, and a custom prompt template to create a RAG chain that can generate
-    responses to user queries in Bengali.
-    """
-
     def __init__(self):
-        """Initializes the BanglaRAGChain with default parameters."""
         self._device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
         self.chat_model_id = None
         self.embed_model_id = None
@@ -71,22 +57,6 @@ class BanglaRAGChain:
         chunk_overlap=150,
         hf_token=None,
     ):
-        """
-        Loads the required models and data for the RAG chain.
-        Args:
-            chat_model_id (str): The Hugging Face model ID for the chat model.
-            embed_model_id (str): The Hugging Face model ID for the embedding model.
-            text_path (str): Path to the text file to be indexed.
-            quantization (bool): Whether to quantize the model or not.
-            k (int): The number of documents to retrieve.
-            top_k (int): The top_k parameter for the generation configuration.
-            top_p (float): The top_p parameter for the generation configuration.
-            max_new_tokens (int): The maximum number of new tokens to generate.
-            temperature (float): The temperature parameter for the generation configuration.
-            chunk_size (int): The chunk size for text splitting.
-            chunk_overlap (int): The chunk overlap for text splitting.
-            hf_token (str): The Hugging Face token for authentication.
-        """
         self.chat_model_id = chat_model_id
         self.embed_model_id = embed_model_id
         self.k = k
@@ -103,26 +73,14 @@ class BanglaRAGChain:
         if self.hf_token is not None:
             os.environ["HF_TOKEN"] = str(self.hf_token)
 
-        rprint(Panel("[bold green]Loading chat models...", expand=False))
         self._load_models()
-
-        rprint(Panel("[bold green]Creating document...", expand=False))
         self._create_document()
-
-        rprint(Panel("[bold green]Updating Chroma database...", expand=False))
         self._update_chroma_db()
-
-        rprint(Panel("[bold green]Initializing retriever...", expand=False))
         self._get_retriever()
-
-        rprint(Panel("[bold green]Initializing LLM...", expand=False))
         self._get_llm()
-
-        rprint(Panel("[bold green]Creating chain...", expand=False))
         self._create_chain()
 
     def _load_models(self):
-        """Loads the chat model and tokenizer."""
         try:
             self.tokenizer = AutoTokenizer.from_pretrained(self.chat_model_id)
             bnb_config = None
@@ -133,28 +91,23 @@ class BanglaRAGChain:
                     bnb_4bit_quant_type="nf4",
                     bnb_4bit_compute_dtype=torch.float16,
                 )
-                rprint(Panel("[bold green]Applying 4bit quantization...", expand=False))
                 self.chat_model = AutoModelForCausalLM.from_pretrained(
                     self.chat_model_id,
-                    torch_dtype=torch.float16,
-                    low_cpu_mem_usage=True,
-                    quantization_config=bnb_config,
+                    load_in_8bit=True,
+                    torch_dtype=torch.bfloat16,
                     device_map="auto",
+                    quantization_config=bnb_config,
                 )
-                rprint(Panel("[bold green]Applied 4bit quantization successfully", expand=False))
             else:
                 self.chat_model = AutoModelForCausalLM.from_pretrained(
                     self.chat_model_id,
-                    torch_dtype=torch.float16,
-                    low_cpu_mem_usage=True,
+                    torch_dtype=torch.bfloat16,
                     device_map="auto",
                 )
-                rprint(Panel("[bold green]Chat Model loaded successfully!", expand=False))
         except Exception as e:
-            rprint(Panel(f"[red]Error loading chat model: {e}", expand=False))
+            raise RuntimeError(f"Error loading chat model: {e}")
 
     def _create_document(self):
-        """Splits the input text into chunks using RecursiveCharacterTextSplitter."""
         try:
             with open(self.text_path, "r", encoding="utf-8") as file:
                 self._text_content = file.read()
@@ -163,44 +116,21 @@ class BanglaRAGChain:
                 chunk_size=self.chunk_size,
                 chunk_overlap=self.chunk_overlap,
             )
-            self._documents = list(
-                tqdm(
-                    character_splitter.split_text(self._text_content),
-                    desc="Chunking text",
-                )
-            )
-            print(f"Number of chunks: {len(self._documents)}")
-            if False:
-                for i, chunk in enumerate(self._documents):
-                    if i > 5:
-                        break
-                    print(f"Chunk {i}: {chunk}")
-            rprint(Panel("[bold green]Document created successfully!", expand=False))
+            self._documents = character_splitter.split_text(self._text_content)
         except Exception as e:
-            rprint(Panel(f"[red]Chunking failed: {e}", expand=False))
+            raise RuntimeError(f"Chunking failed: {e}")
 
     def _update_chroma_db(self):
-        """Updates the Chroma vector database with the text chunks."""
         try:
-            try:
-                rprint(Panel(f"[bold green]Loading embedding model...",expand=False))
-                model_kwargs = {"device": self._device}
-                embeddings = HuggingFaceEmbeddings(
-                    model_name=self.embed_model_id, model_kwargs=model_kwargs
-                )
-                rprint(Panel(f"[bold green]Loaded embedding model successfully!", expand=False))
-            except Exception as e:
-                rprint(Panel(f"[red]embedding model loading failed: {e}", expand=False))
-
-            self._db = Chroma.from_texts(texts=self._documents, embedding=embeddings)
-            rprint(
-                Panel("[bold green]Chroma database updated successfully!", expand=False)
+            model_kwargs = {"device": self._device}
+            embeddings = HuggingFaceEmbeddings(
+                model_name=self.embed_model_id, model_kwargs=model_kwargs
             )
+            self._db = Chroma.from_texts(texts=self._documents, embedding=embeddings)
         except Exception as e:
-            rprint(Panel(f"[red]Vector DB initialization failed: {e}", expand=False))
+            raise RuntimeError(f"Vector DB initialization failed: {e}")
 
     def _create_chain(self):
-        """Creates the retrieval-augmented generation (RAG) chain."""
         template = """Below is an instruction in Bengali language that describes a task, paired with an input also in Bengali language that provides further context. Write a response in Bengali that appropriately completes the request.
 ### Instruction:
 {question}
@@ -242,22 +172,18 @@ class BanglaRAGChain:
             ).assign(answer=rag_chain_from_docs)
 
             self._chain = rag_chain_with_source
-            rprint(Panel("[bold green]Chain created successfully!", expand=False))
         except Exception as e:
-            rprint(Panel(f"[red]Chain creation failed: {e}", expand=False))
+            raise RuntimeError(f"Chain creation failed: {e}")
 
     def _get_retriever(self):
-        """Creates a retriever for the vector database."""
         try:
             self._retriever = self._db.as_retriever(
                 search_type="similarity", search_kwargs={"k": self.k}
             )
-            rprint(Panel("[bold green]Retriever created successfully!", expand=False))
         except Exception as e:
-            rprint(Panel(f"[red]Retriever creation failed: {e}", expand=False))
+            raise RuntimeError(f"Retriever creation failed: {e}")
 
     def _get_llm(self):
-        """Initializes the language model using the Hugging Face pipeline."""
         try:
             pipe = pipeline(
                 "text-generation",
@@ -271,26 +197,22 @@ class BanglaRAGChain:
                 top_p=self.top_p,
                 top_k=self.top_k,
                 repetition_penalty=1.2,
-                torch_dtype=torch.float16,
+                torch_dtype=torch.bfloat16,
             )
 
             self._llm = HuggingFacePipeline(pipeline=pipe)
-            rprint(Panel("[bold green]LLM initialized successfully!", expand=False))
         except Exception as e:
-            rprint(Panel(f"[red]LLM initialization failed: {e}", expand=False))
-            self._llm = None  # Ensure it’s set to None on failure
+            raise RuntimeError(f"LLM initialization failed: {e}")
+            self._llm = None
 
     def __call__(self, query):
-        """Runs the RAG chain on a user query and returns the generated answer."""
         if not self._chain:
             raise ValueError("The chain has not been initialized.")
-        if self._chain:
-            result = self._chain.invoke({"question": query})
-            return result["answer"], result["context"]
+        result = self._chain.invoke({"question": query})
+        return result["answer"], result["context"]
 
     def _format_docs(self, docs):
-        """Formats retrieved documents into a string format."""
         context = ""
         for i, doc in enumerate(docs):
             context += f"\nDocument {i + 1}:\n{doc.page_content}\n\n"
-        return context
+        return context
 
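For reference, the 4-bit path in `_load_models` can be exercised standalone. Below is a minimal sketch, assuming `load_in_4bit=True` in the `BitsAndBytesConfig` (that line sits just outside the visible diff context) and a placeholder model ID. Note that recent versions of transformers reject `load_in_8bit=True` passed alongside an explicit `quantization_config`, as the new `_load_models` does, so the sketch relies on the config alone:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig

# 4-bit NF4 quantization, mirroring the config in the diff.
# load_in_4bit=True is assumed; it falls outside the visible hunk.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.float16,
)

chat_model_id = "your-org/your-bangla-llm"  # placeholder Hub model ID

tokenizer = AutoTokenizer.from_pretrained(chat_model_id)
model = AutoModelForCausalLM.from_pretrained(
    chat_model_id,
    quantization_config=bnb_config,  # the config alone selects 4-bit loading
    device_map="auto",
)
```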
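The indexing steps (`_create_document`, `_update_chroma_db`, `_get_retriever`) reduce to the standalone sketch below. `chunk_overlap=150` and the retriever settings come from the diff; `chunk_size=500`, the file name, and the embedding model ID are assumptions:

```python
from langchain_text_splitters import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.vectorstores import Chroma

# Split the corpus into overlapping chunks; chunk_overlap=150 matches
# load()'s default, chunk_size=500 is an assumed value.
splitter = RecursiveCharacterTextSplitter(chunk_size=500, chunk_overlap=150)
with open("corpus.txt", encoding="utf-8") as f:  # placeholder file
    chunks = splitter.split_text(f.read())

# Embed the chunks and build an in-memory Chroma index.
embeddings = HuggingFaceEmbeddings(
    model_name="some-org/bangla-sentence-embedder",  # placeholder model ID
    model_kwargs={"device": "cpu"},
)
db = Chroma.from_texts(texts=chunks, embedding=embeddings)

# Top-k similarity retriever, as in _get_retriever().
retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 4})
```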
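Only the edges of `_create_chain` are visible in the hunks, but its identifiers (`rag_chain_from_docs`, `RunnableParallel`, `.assign(answer=...)`) match LangChain's standard RAG-with-sources composition. A sketch under that assumption follows, reusing `retriever` from the previous block and `model`/`tokenizer` from the quantization sketch; the `### Input`/`### Response` tail of the prompt template is assumed, since only its head and `{question}` appear in the diff:

```python
from transformers import pipeline
from langchain_core.prompts import PromptTemplate
from langchain_core.runnables import RunnableParallel, RunnablePassthrough
from langchain_core.output_parsers import StrOutputParser
from langchain_community.llms.huggingface_pipeline import HuggingFacePipeline

# Wrap a transformers pipeline for LangChain, as in _get_llm().
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    max_new_tokens=256,      # assumed value
    repetition_penalty=1.2,  # from the diff
)
llm = HuggingFacePipeline(pipeline=pipe)

# Prompt head is from the diff; the Input/Response tail is assumed.
template = """Below is an instruction in Bengali language that describes a task, paired with an input also in Bengali language that provides further context. Write a response in Bengali that appropriately completes the request.
### Instruction:
{question}
### Input:
{context}
### Response:
"""
prompt = PromptTemplate(template=template, input_variables=["context", "question"])

def format_docs(docs):
    # Mirrors _format_docs(): number each retrieved chunk.
    return "".join(f"\nDocument {i + 1}:\n{d.page_content}\n\n" for i, d in enumerate(docs))

# Branch that renders the prompt and generates the answer text.
rag_chain_from_docs = (
    RunnablePassthrough.assign(context=lambda x: format_docs(x["context"]))
    | prompt
    | llm
    | StrOutputParser()
)

# Wrapper that returns the retrieved documents alongside the answer.
rag_chain_with_source = RunnableParallel(
    {"context": retriever, "question": RunnablePassthrough()}
).assign(answer=rag_chain_from_docs)

result = rag_chain_with_source.invoke("বাংলা প্রশ্ন")  # query as a plain string
answer, context_docs = result["answer"], result["context"]
```

The class's `__call__` wraps the query as `{"question": query}` before invoking the chain; how the inner chain unpacks that dict falls outside the visible hunks, so this sketch feeds the query string directly.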
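End to end, the class is driven through `load()` and `__call__`, whose parameters are documented in the removed docstring. A minimal usage sketch; the model IDs, file path, and sampling values are placeholders:

```python
from BanglaRAG.bangla_rag_pipeline import BanglaRAGChain

rag = BanglaRAGChain()
rag.load(
    chat_model_id="your-org/your-bangla-llm",            # placeholder
    embed_model_id="some-org/bangla-sentence-embedder",  # placeholder
    text_path="corpus.txt",  # Bengali text file to index
    quantization=True,       # take the BitsAndBytesConfig path above
    k=4,                     # number of chunks to retrieve
    top_k=50,                # assumed sampling values
    top_p=0.95,
    max_new_tokens=256,
    temperature=0.7,
    chunk_size=500,
    chunk_overlap=150,
    hf_token=None,           # or a Hugging Face access token
)

# After this commit, setup failures raise RuntimeError instead of printing
# rich panels, and __call__ raises ValueError if load() was never run.
answer, context = rag("এখানে বাংলায় আপনার প্রশ্ন লিখুন")
print(answer)
```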