crossroderick's picture
Added all files
0eb636f
raw
history blame contribute delete
861 Bytes
from src.utils.config import MINIDALALM_MODEL
from sentence_transformers import SentenceTransformer
class DalaEmbedder:
    """
    Simple wrapper for the MiniDalaLM embedding model.

    Loads a ``SentenceTransformer`` once at construction time and exposes
    helpers that return embeddings as plain Python lists of floats.
    """

    def __init__(self, model_path: str = MINIDALALM_MODEL):
        """
        Load the embedding model.

        Args:
            model_path: Model name or path passed straight to
                ``SentenceTransformer``; defaults to ``MINIDALALM_MODEL``
                from ``src.utils.config``.
        """
        self.model = SentenceTransformer(model_path)

    def embed_text(self, text: str) -> list[float]:
        """
        Embed a single string of text.

        Args:
            text: The string to embed.

        Returns:
            The embedding vector as a list of floats.
        """
        return self.model.encode(text, convert_to_numpy=True).tolist()

    def embed_batch(self, texts: list[str]) -> list[list[float]]:
        """
        Embed a batch of text strings.

        Args:
            texts: The strings to embed.

        Returns:
            A list of embeddings, one per input string, each a list of floats.
        """
        # ``encode`` returns a NumPy array; convert it so the result matches
        # the annotated return type and ``embed_text``'s plain-list output.
        return self.model.encode(texts, convert_to_numpy=True).tolist()

    def get_model(self) -> SentenceTransformer:
        """
        Return the underlying ``SentenceTransformer`` instance.
        """
        return self.model