bourdoiscatie's picture
Update app.py
0452c49 verified
raw
history blame
7.2 kB
import time
import gradio as gr
from datasets import load_dataset
import pandas as pd
from sentence_transformers import SentenceTransformer
from sentence_transformers.quantization import quantize_embeddings
import faiss
from usearch.index import Index
# Load titles and texts
title_text_dataset = load_dataset("bourdoiscatie/wikipedia_fr_2022_250K", split="train", num_proc=4).select_columns(["title", "text"])
# Load the int8 and binary indices. Int8 is loaded as a view to save memory, as we never actually perform search with it.
int8_view = Index.restore("wikipedia_fr_2022_250K_int8_usearch.index", view=True)
binary_index: faiss.IndexBinaryFlat = faiss.read_index_binary("wikipedia_fr_2022_250K_ubinary_faiss.index")
# binary_ivf: faiss.IndexBinaryIVF = faiss.read_index_binary("wikipedia_fr_2022_250K_ubinary_ivf_faiss.index")
# Load the SentenceTransformer model for embedding the queries
model = SentenceTransformer("OrdalieTech/Solon-embeddings-large-0.1")
def search(query, top_k: int = 20, rescore_multiplier: int = 1, use_approx: bool = False):
# 1. Embed the query as float32
start_time = time.time()
query_embedding = model.encode(query, prompt="query: ")
embed_time = time.time() - start_time
# 2. Quantize the query to ubinary
start_time = time.time()
query_embedding_ubinary = quantize_embeddings(query_embedding.reshape(1, -1), "ubinary")
quantize_time = time.time() - start_time
# 3. Search the binary index (either exact or approximate)
index = binary_index # binary_ivf if use_approx else binary_index
start_time = time.time()
_scores, binary_ids = index.search(query_embedding_ubinary, top_k * rescore_multiplier)
binary_ids = binary_ids[0]
search_time = time.time() - start_time
# 4. Load the corresponding int8 embeddings
start_time = time.time()
int8_embeddings = int8_view[binary_ids].astype(int)
load_time = time.time() - start_time
# 5. Rescore the top_k * rescore_multiplier using the float32 query embedding and the int8 document embeddings
start_time = time.time()
scores = query_embedding @ int8_embeddings.T
rescore_time = time.time() - start_time
# 6. Sort the scores and return the top_k
start_time = time.time()
indices = scores.argsort()[::-1][:top_k]
top_k_indices = binary_ids[indices]
top_k_scores = scores[indices]
top_k_titles, top_k_texts = zip(*[(title_text_dataset[idx]["title"], title_text_dataset[idx]["text"]) for idx in top_k_indices.tolist()])
df = pd.DataFrame({"Score": [round(value, 2) for value in top_k_scores], "Titre": top_k_titles, "Texte": top_k_texts})
sort_time = time.time() - start_time
return df, {
"Temps pour enchâsser la requête ": f"{embed_time:.4f} s",
"Temps pour la quantisation ": f"{quantize_time:.4f} s",
"Temps pour effectuer la recherche ": f"{search_time:.4f} s",
"Temps de chargement ": f"{load_time:.4f} s",
"Temps de rescorage ": f"{rescore_time:.4f} s",
"Temps pour trier les résustats ": f"{sort_time:.4f} s",
"Temps total pour la recherche ": f"{quantize_time + search_time + load_time + rescore_time + sort_time:.4f} s",
}
with gr.Blocks(title="Requêter Wikipedia en temps réel") as demo:
gr.Markdown(
"""
## Requêter Wikipedia en temps réel
Effectuer une requête dans un corpus composé de 250K paragraphes tirés d'articles de Wikipédia.
Les résultats sont renvoyés en temps réel via une architecture tournant sur un CPU 🚀
<details><summary>Détails du processus</summary>
Détails :
1. La requête est enchâssée en float32 à l'aide du modèle [`OrdalieTech/Solon-embeddings-large-0.1`](https://hf.co/OrdalieTech/Solon-embeddings-large-0.1).
2. La requête est quantizée en binaire à l'aide de la fonction `quantize_embeddings` de la bibliothèque [SentenceTransformers](https://sbert.net/).
3. Un index binaire (XXM d'enchâssements binaires pesant XXGB de mémoire/espace disque) est requêté (en binaire si l'option approximatie sélectionnée, en int8 si l'option exacte est sélectionnée).
4. Les *n* textes demandés par l'utilisateur jugés les plus pertinents sont chargés à la volée à partir d'un index int8 sur disque (XXM d'enchâssements int8 ; 0 bytes de mémoire, XXGB d'espace disque).
5. Les *n* textes sont rescorés en utilisant la requête en float32 et les enchâssements en int8.
6. Les *n* premiers textes sont triés par score et affichés.
Ce processus est conçu pour être rapide et efficace en termes de mémoire : l'index binaire étant suffisamment petit pour tenir dans la mémoire et l'index int8 étant chargé en tant que vue pour économiser de la mémoire.
Au total, ce processus nécessite de conserver 1) le modèle en mémoire, 2) l'index binaire en mémoire et 3) l'index int8 sur le disque.
Avec une dimension de 1024, nous avons besoin de `1024 / 8 * num_docs` octets pour l'index binaire et de `1024 * num_docs` octets pour l'index int8.
C'est nettement moins cher que de faire le même processus avec des enchâssements en float32 qui nécessiterait `4 * 1024 * num_docs` octets de mémoire/espace disque pour l'index float32, soit 32x plus de mémoire et 4x plus d'espace disque.
De plus, l'index binaire est beaucoup plus rapide (jusqu'à 32x) à rechercher que l'index float32, tandis que le rescorage est également extrêmement efficace.
En conclusion, ce processus permet une recherche rapide, évolutive, peu coûteuse et efficace en termes de mémoire.
</details>
"""
)
with gr.Row():
with gr.Column(scale=75):
query = gr.Textbox(
label="Recherche d'articles dans le Wikipédia francophone",
placeholder="Saisissez une requête pour rechercher des textes pertinents dans Wikipédia.",
)
with gr.Column(scale=25):
use_approx = gr.Radio(
choices=[("Recherche exacte", False), ("Recherche approximative", True)],
value=True,
label="Index de recherche",
)
with gr.Row():
with gr.Column(scale=2):
top_k = gr.Slider(
minimum=10,
maximum=200,
step=5,
value=20,
label="Nombre de documents à rechercher",
info="Recherche effectué via un bi-encodeur binaire",
)
with gr.Column(scale=2):
rescore_multiplier = gr.Slider(
minimum=1,
maximum=10,
step=1,
value=1,
label="Coefficient de rescorage",
info="Reranking via le coefficient`",
)
search_button = gr.Button(value="Search")
with gr.Row():
with gr.Column(scale=4):
output = gr.Dataframe(headers=["Score", "Titre", "Texte"])
with gr.Column(scale=1):
json = gr.JSON()
query.submit(search, inputs=[query, top_k, rescore_multiplier, use_approx], outputs=[output, json])
search_button.click(search, inputs=[query, top_k, rescore_multiplier, use_approx], outputs=[output, json])
demo.queue()
demo.launch()