snsynth commited on
Commit
533f28c
·
1 Parent(s): 62ec630

compute probabilities

Browse files
Files changed (1) hide show
  1. rag_app/rag_2.py +39 -13
rag_app/rag_2.py CHANGED
@@ -1,12 +1,13 @@
1
- import re
2
  import os
3
- from llama_cpp import Llama, LlamaGrammar
 
 
4
  from llama_index.llms.llama_cpp import LlamaCPP
5
  from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
6
  from llama_index.retrievers.bm25 import BM25Retriever
7
  from llama_index.core.retrievers import QueryFusionRetriever
8
  from llama_index.core.query_engine import RetrieverQueryEngine
9
- from llama_index.core import StorageContext, load_index_from_storage
10
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
11
  from llama_index.core.postprocessor import LLMRerank
12
  from llama_index.core.node_parser import TokenTextSplitter
@@ -37,9 +38,12 @@ llm = LlamaCPP(
37
  temperature=0.1,
38
  max_new_tokens=256,
39
  context_window=16384,
40
- model_kwargs={"n_gpu_layers":-1},
41
  messages_to_prompt=messages_to_prompt,
42
- completion_to_prompt=completion_to_prompt)
 
 
 
43
 
44
 
45
  embedding_model = HuggingFaceEmbedding(
@@ -86,13 +90,39 @@ def is_relevant(query, index, threshold=0.7):
86
  return not similarity <= threshold
87
 
88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  def answer_question(query):
 
 
 
90
  print("loading bm25 retriever")
91
  bm25_retriever = BM25Retriever.from_persist_dir("models/bm25_retriever")
92
  print("loading saved vector index")
93
  storage_context = StorageContext.from_defaults(persist_dir="models/precomputed_index")
94
  index = load_index_from_storage(storage_context)
95
 
 
 
 
96
  retriever = QueryFusionRetriever(
97
  [
98
  index.as_retriever(similarity_top_k=5, verbose=True),
@@ -111,12 +141,8 @@ def answer_question(query):
111
  retriever=retriever,
112
  node_postprocessors=[reranker],
113
  )
114
-
115
- if is_harmful(query):
116
- return "This query has been flagged as unsafe."
117
-
118
- if not is_relevant(query, index, 0.2):
119
- return "This query doesn't appear relevant to finance."
120
-
121
  response = keyword_query_engine.query(query)
122
- return str(response)
 
 
 
 
 
1
  import os
2
+ import math
3
+ import numpy as np
4
+ from llama_cpp import Llama
5
  from llama_index.llms.llama_cpp import LlamaCPP
6
  from llama_index.core import VectorStoreIndex, SimpleDirectoryReader, Settings
7
  from llama_index.retrievers.bm25 import BM25Retriever
8
  from llama_index.core.retrievers import QueryFusionRetriever
9
  from llama_index.core.query_engine import RetrieverQueryEngine
10
+ from llama_index.core import StorageContext, load_index_from_storage, QueryBundle
11
  from llama_index.embeddings.huggingface import HuggingFaceEmbedding
12
  from llama_index.core.postprocessor import LLMRerank
13
  from llama_index.core.node_parser import TokenTextSplitter
 
38
  temperature=0.1,
39
  max_new_tokens=256,
40
  context_window=16384,
41
+ model_kwargs={"n_gpu_layers":-1, 'logits_all': True, 'logprobs': True,},
42
  messages_to_prompt=messages_to_prompt,
43
+ completion_to_prompt=completion_to_prompt,)
44
+
45
+ llm2 = Llama(model_path="models/Llama-3.2-1B-Instruct-Q4_K_M.gguf",
46
+ n_gpu_layers=-1, n_ctx=8000, logits_all=True)
47
 
48
 
49
  embedding_model = HuggingFaceEmbedding(
 
90
  return not similarity <= threshold
91
 
92
 
93
+ def get_sequence_probability(llm, input_sequence):
94
+ input_tokens = llm.tokenize(input_sequence.encode("utf-8"))
95
+ sequence_logits = []
96
+ sequence_logprobs = []
97
+
98
+ eval_tokens = input_tokens[:1]
99
+
100
+ for token in input_tokens[1:]:
101
+ llm.eval(eval_tokens)
102
+
103
+ probs = llm.logits_to_logprobs(llm.eval_logits)
104
+ sequence_logits.append(llm.eval_logits[-1][token])
105
+ sequence_logprobs.append(probs[-1][token])
106
+ eval_tokens.append(token)
107
+
108
+ total_log_prob = sum(sequence_logprobs)
109
+ sequence_probability = math.exp(total_log_prob)
110
+ return sequence_probability
111
+
112
+
113
  def answer_question(query):
114
+ if is_harmful(query):
115
+ return "This query has been flagged as unsafe."
116
+
117
  print("loading bm25 retriever")
118
  bm25_retriever = BM25Retriever.from_persist_dir("models/bm25_retriever")
119
  print("loading saved vector index")
120
  storage_context = StorageContext.from_defaults(persist_dir="models/precomputed_index")
121
  index = load_index_from_storage(storage_context)
122
 
123
+ if not is_relevant(query, index, 0.2):
124
+ return "This query doesn't appear relevant to finance."
125
+
126
  retriever = QueryFusionRetriever(
127
  [
128
  index.as_retriever(similarity_top_k=5, verbose=True),
 
141
  retriever=retriever,
142
  node_postprocessors=[reranker],
143
  )
 
 
 
 
 
 
 
144
  response = keyword_query_engine.query(query)
145
+ response_text = str(response)
146
+ response_prob = get_sequence_probability(llm2, response_text)
147
+ print(f"Output probability: {response_prob}")
148
+ return response_text