Rohit1412 commited on
Commit
4d1faa0
·
verified ·
1 Parent(s): 4688d6b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -34,7 +34,7 @@ def extract_text_from_pdf(pdf_file):
34
  pages.append(f"Error reading PDF: {str(e)}")
35
  return pages
36
 
37
- def chunk_text(text, chunk_size=500):
38
  """Split text into chunks of approximately chunk_size characters."""
39
  words = text.split()
40
  chunks = []
@@ -95,7 +95,7 @@ def rag_pipeline(question, pdf_files):
95
 
96
  # Retrieve top 3 chunks using cosine similarity
97
  cos_scores = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
98
- top_results = torch.topk(cos_scores, k=min(3, len(documents)))
99
  retrieved_context = ""
100
  for score, idx in zip(top_results.values, top_results.indices):
101
  retrieved_context += f"- {documents[idx]} (score: {score:.2f})\n"
@@ -120,7 +120,7 @@ def rag_pipeline(question, pdf_files):
120
 
121
  # Generate answer with more tokens
122
  inputs = gen_tokenizer(prompt, return_tensors="pt")
123
- outputs = gen_model.generate(**inputs, max_new_tokens=150, num_beams=2)
124
  answer = gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
125
 
126
  # Log processing time
 
34
  pages.append(f"Error reading PDF: {str(e)}")
35
  return pages
36
 
37
+ def chunk_text(text, chunk_size=1500):
38
  """Split text into chunks of approximately chunk_size characters."""
39
  words = text.split()
40
  chunks = []
 
95
 
96
  # Retrieve top 3 chunks using cosine similarity
97
  cos_scores = util.pytorch_cos_sim(query_embedding, doc_embeddings)[0]
98
+ top_results = torch.topk(cos_scores, k=min(5, len(documents)))
99
  retrieved_context = ""
100
  for score, idx in zip(top_results.values, top_results.indices):
101
  retrieved_context += f"- {documents[idx]} (score: {score:.2f})\n"
 
120
 
121
  # Generate answer with more tokens
122
  inputs = gen_tokenizer(prompt, return_tensors="pt")
123
+ outputs = gen_model.generate(**inputs, max_new_tokens=1500, num_beams=2)
124
  answer = gen_tokenizer.decode(outputs[0], skip_special_tokens=True)
125
 
126
  # Log processing time