---
pipeline_tag: sentence-similarity
tags:
  - sentence-transformers
  - feature-extraction
  - sentence-similarity
  - mteb
  - arctic
  - embedding
  - snowflake2_m_uint8
  - snowflake
  - transformers.js
license: apache-2.0
language:
  - af
  - ar
  - az
  - be
  - bg
  - bn
  - ca
  - ceb
  - cs
  - cy
  - da
  - de
  - el
  - en
  - es
  - et
  - eu
  - fa
  - fi
  - fr
  - gl
  - gu
  - he
  - hi
  - hr
  - ht
  - hu
  - hy
  - id
  - is
  - it
  - ja
  - jv
  - ka
  - kk
  - km
  - kn
  - ko
  - ky
  - lo
  - lt
  - lv
  - mk
  - ml
  - mn
  - mr
  - ms
  - my
  - ne
  - nl
  - pa
  - pl
  - pt
  - qu
  - ro
  - ru
  - si
  - sk
  - sl
  - so
  - sq
  - sr
  - sv
  - sw
  - ta
  - te
  - th
  - tl
  - tr
  - uk
  - ur
  - vi
  - yo
  - zh
---

# Update

I've updated this model to be compatible with Fastembed.

I removed the `sentence_embedding` output and quantized the main model output instead, so the model now outputs a 768-dimensional multivector.

To use the output, apply CLS pooling with normalization disabled.
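
For reference, CLS pooling on the raw multivector just means taking the first token's vector and leaving it unnormalized. A minimal numpy sketch (the array here is fabricated for illustration; in practice it would come from the model's output):

```python
import numpy as np

# token_embeddings: (batch, seq_len, 768) uint8 multivector, as the model emits.
# Fabricated batch of 4 tokens purely for demonstration.
token_embeddings = np.random.randint(0, 256, size=(1, 4, 768), dtype=np.uint8)

# CLS pooling: take the vector at token position 0; no normalization step.
cls_embedding = token_embeddings[:, 0, :]
print(cls_embedding.shape)  # (1, 768)
```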

# snowflake2_m_uint8

This is a slightly modified version of the uint8-quantized ONNX model from https://huggingface.co/Snowflake/snowflake-arctic-embed-m-v2.0

I have added a linear quantization node before the `token_embeddings` output so that it directly outputs a 768-dimensional uint8 multivector.

This is compatible with the Qdrant uint8 datatype for collections.

I took the liberty of removing the `sentence_embedding` output, since I would've had to re-create it; I can add it back if anybody wants it.

## Quantization method

Linear quantization over the range -7 to 7.
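
I haven't restated the node's exact parameters here, but linear quantization from the range [-7, 7] onto uint8 generally amounts to something like the following sketch:

```python
import numpy as np

def quantize_uint8(x: np.ndarray, lo: float = -7.0, hi: float = 7.0) -> np.ndarray:
    """Map floats in [lo, hi] linearly onto [0, 255]; values outside are clipped."""
    scale = 255.0 / (hi - lo)
    return np.clip(np.round((x - lo) * scale), 0, 255).astype(np.uint8)

print(quantize_uint8(np.array([-7.0, 0.0, 7.0])))  # [  0 128 255]
```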

Here's what the graph of the original output looks like:

*(image: original model output graph)*

Here's what the new graph in this model looks like:

*(image: modified model output graph)*

## Benchmark

I used `beir-qdrant` with the `scifact` dataset.

quantized output (this model):

```
ndcg: {'NDCG@1': 0.59333, 'NDCG@3': 0.64619, 'NDCG@5': 0.6687, 'NDCG@10': 0.69228, 'NDCG@100': 0.72204, 'NDCG@1000': 0.72747}
recall: {'Recall@1': 0.56094, 'Recall@3': 0.68394, 'Recall@5': 0.73983, 'Recall@10': 0.80689, 'Recall@100': 0.94833, 'Recall@1000': 0.99333}
precision: {'P@1': 0.59333, 'P@3': 0.25, 'P@5': 0.16467, 'P@10': 0.09167, 'P@100': 0.01077, 'P@1000': 0.00112}
```

unquantized output (model_uint8.onnx):

```
ndcg: {'NDCG@1': 0.59333, 'NDCG@3': 0.65417, 'NDCG@5': 0.6741, 'NDCG@10': 0.69675, 'NDCG@100': 0.7242, 'NDCG@1000': 0.7305}
recall: {'Recall@1': 0.56094, 'Recall@3': 0.69728, 'Recall@5': 0.74817, 'Recall@10': 0.81356, 'Recall@100': 0.945, 'Recall@1000': 0.99667}
precision: {'P@1': 0.59333, 'P@3': 0.25444, 'P@5': 0.16667, 'P@10': 0.09233, 'P@100': 0.01073, 'P@1000': 0.00113}
```

## Example inference/benchmark code and how to use the model with Fastembed

After installing `beir-qdrant`, make sure to upgrade `fastembed`.

```python
# pip install qdrant_client beir-qdrant
# pip install -U fastembed
from fastembed import TextEmbedding
from fastembed.common.model_description import PoolingType, ModelSource
from beir import util
from beir.datasets.data_loader import GenericDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from qdrant_client import QdrantClient
from qdrant_client.models import Datatype
from beir_qdrant.retrieval.models.fastembed import DenseFastEmbedModelAdapter
from beir_qdrant.retrieval.search.dense import DenseQdrantSearch

# Register this model with fastembed: CLS pooling, normalization disabled.
TextEmbedding.add_custom_model(
    model="electroglyph/snowflake2_m_uint8",
    pooling=PoolingType.CLS,
    normalization=False,
    sources=ModelSource(hf="electroglyph/snowflake2_m_uint8"),
    dim=768,
    model_file="snowflake2_m_uint8.onnx",
)

# Download and load the BEIR scifact dataset.
dataset = "scifact"
url = "https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{}.zip".format(dataset)
data_path = util.download_and_unzip(url, "datasets")
corpus, queries, qrels = GenericDataLoader(data_folder=data_path).load(split="test")

qdrant_client = QdrantClient("http://localhost:6333")

# Index into a Qdrant collection that stores vectors as uint8.
model = DenseQdrantSearch(
    qdrant_client,
    model=DenseFastEmbedModelAdapter(
        model_name="electroglyph/snowflake2_m_uint8"
    ),
    collection_name="scifact-uint8",
    initialize=True,
    datatype=Datatype.UINT8,
)
retriever = EvaluateRetrieval(model)
results = retriever.retrieve(corpus, queries)

ndcg, _map, recall, precision = retriever.evaluate(qrels, results, retriever.k_values)
print(f"ndcg: {ndcg}\nrecall: {recall}\nprecision: {precision}")
```