ruslanmv commited on
Commit
178e9cd
·
1 Parent(s): 907b637

Update webchat.py

Browse files
Files changed (1) hide show
  1. webchat.py +10 -7
webchat.py CHANGED
@@ -79,18 +79,21 @@ def get_model_test(model_type, max_tokens, min_tokens, decoding, temperature):
79
 
80
  return model
81
 
82
- import os
83
- # Get the current working directory
84
  current_dir = os.getcwd()
85
  cache_dir = os.path.join(current_dir, ".cache")
86
  if not os.path.exists(cache_dir):
87
- os.makedirs(cache_dir)
88
- # Embedding function
89
- class MiniLML6V2EmbeddingFunction(EmbeddingFunction):
90
- #MODEL = SentenceTransformer('all-MiniLM-L6-v2')
91
- MODEL = SentenceTransformer('all-MiniLM-L6-v2', cache_dir=cache_dir)
92
 
 
 
93
 
 
 
 
 
 
 
94
  def __call__(self, texts):
95
  return MiniLML6V2EmbeddingFunction.MODEL.encode(texts).tolist()
96
 
 
79
 
80
  return model
81
 
82
+ # Set up cache directory
 
83
  current_dir = os.getcwd()
84
  cache_dir = os.path.join(current_dir, ".cache")
85
  if not os.path.exists(cache_dir):
86
+ os.makedirs(cache_dir)
 
 
 
 
87
 
88
+ # Set the TRANSFORMERS_CACHE environment variable
89
+ os.environ['TRANSFORMERS_CACHE'] = cache_dir
90
 
91
+ # Download the model first
92
+ model = SentenceTransformer('all-MiniLM-L6-v2')
93
+
94
+ # Embedding function
95
+ class MiniLML6V2EmbeddingFunction(EmbeddingFunction):
96
+ MODEL = model
97
  def __call__(self, texts):
98
  return MiniLML6V2EmbeddingFunction.MODEL.encode(texts).tolist()
99