YuITC
commited on
Commit
·
0063d17
1
Parent(s):
819910a
feat: initial project upload after testing
Browse files- .gitignore +8 -0
- Dockerfile +21 -0
- LICENSE +21 -0
- README.md +4 -0
- main.py +65 -0
- requirements.txt +12 -0
- results/no_model_name_available/no_revision_available/BKAILegalDocRetrieval.json +158 -0
- results/no_model_name_available/no_revision_available/model_meta.json +1 -0
- settings.py +44 -0
- step_01_Prepare_Data.ipynb +411 -0
- step_02_Finetune_SBERT.ipynb +580 -0
- step_03_Eval_with_MTEB.ipynb +479 -0
- step_04_Retrieval.ipynb +383 -0
.gitignore
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__/
|
2 |
+
.gradio/
|
3 |
+
cache/
|
4 |
+
data/original/
|
5 |
+
models/
|
6 |
+
data/
|
7 |
+
tmp/
|
8 |
+
.env
|
Dockerfile
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM continuumio/miniconda3
|
2 |
+
|
3 |
+
WORKDIR /app
|
4 |
+
|
5 |
+
# Get dependencies
|
6 |
+
COPY requirements.txt .
|
7 |
+
|
8 |
+
RUN conda install -y python=3.10 \
|
9 |
+
pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia \
|
10 |
+
faiss-gpu=1.9.0 -c pytorch -c nvidia && \
|
11 |
+
conda clean -afy
|
12 |
+
|
13 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
14 |
+
|
15 |
+
# Copy the rest of the code
|
16 |
+
COPY . /app
|
17 |
+
|
18 |
+
# Run the application
|
19 |
+
EXPOSE 7860
|
20 |
+
|
21 |
+
CMD ["python", "main.py"]
|
LICENSE
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
MIT License
|
2 |
+
|
3 |
+
Copyright (c) 2025 Nguyen Phu Tai
|
4 |
+
|
5 |
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6 |
+
of this software and associated documentation files (the "Software"), to deal
|
7 |
+
in the Software without restriction, including without limitation the rights
|
8 |
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9 |
+
copies of the Software, and to permit persons to whom the Software is
|
10 |
+
furnished to do so, subject to the following conditions:
|
11 |
+
|
12 |
+
The above copyright notice and this permission notice shall be included in all
|
13 |
+
copies or substantial portions of the Software.
|
14 |
+
|
15 |
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16 |
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17 |
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18 |
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19 |
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20 |
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21 |
+
SOFTWARE.
|
README.md
CHANGED
@@ -6,6 +6,10 @@ colorTo: pink
|
|
6 |
sdk: docker
|
7 |
pinned: false
|
8 |
short_description: Fine-tuned Retrieval System for Vietnamese Legal Documents
|
|
|
|
|
|
|
|
|
9 |
---
|
10 |
|
11 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
6 |
sdk: docker
|
7 |
pinned: false
|
8 |
short_description: Fine-tuned Retrieval System for Vietnamese Legal Documents
|
9 |
+
models:
|
10 |
+
- YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs
|
11 |
+
datasets:
|
12 |
+
- YuITC/Vietnamese-Legal-Doc-Retrieval-Data
|
13 |
---
|
14 |
|
15 |
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
main.py
ADDED
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
import pandas as pd
|
4 |
+
import gradio as gr
|
5 |
+
|
6 |
+
import faiss
|
7 |
+
from sentence_transformers import SentenceTransformer
|
8 |
+
from settings import OUTPUT_DIR, DEVICE
|
9 |
+
os.environ['WANDB_DISABLED'] = 'true'
|
10 |
+
|
11 |
+
|
12 |
+
fine_tuned_model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)
|
13 |
+
passages = pd.read_parquet('data/processed/corpus_data.parquet')['text'].tolist()
|
14 |
+
legal_index = faiss.read_index('data/retrieval/legal_faiss.index')
|
15 |
+
|
16 |
+
def retrieval(emb_model, query, index, top_k=10):
|
17 |
+
q_emb = emb_model.encode(
|
18 |
+
query,
|
19 |
+
convert_to_numpy=True,
|
20 |
+
normalize_embeddings=True,
|
21 |
+
).astype(np.float32).reshape(1, -1)
|
22 |
+
|
23 |
+
scores, indices = index.search(q_emb, top_k) # shape: (1, top_k)
|
24 |
+
|
25 |
+
cand_idxs = indices[0]
|
26 |
+
cand_scores = scores[0]
|
27 |
+
cand_texts = [passages[i] for i in cand_idxs]
|
28 |
+
|
29 |
+
results = [{
|
30 |
+
'index': int(cand_idxs[i]),
|
31 |
+
'score': float(cand_scores[i]),
|
32 |
+
'text': cand_texts[i]
|
33 |
+
} for i in range(len(cand_idxs))]
|
34 |
+
|
35 |
+
return results
|
36 |
+
|
37 |
+
def get_results(query, top_k):
|
38 |
+
hits = retrieval(fine_tuned_model, query, legal_index, top_k=top_k)
|
39 |
+
|
40 |
+
result = ""
|
41 |
+
for rank, h in enumerate(hits, start=1):
|
42 |
+
result += f"[Kết quả {rank} - Độ tin cậy={h['score']:.4f}]\n\n{h['text']}\n{'-'*100}\n"
|
43 |
+
return result
|
44 |
+
|
45 |
+
|
46 |
+
demo = gr.Interface(
|
47 |
+
'huggingface/YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs',
|
48 |
+
fn=get_results,
|
49 |
+
inputs=[
|
50 |
+
gr.Textbox(lines=2, placeholder='Nhập câu hỏi pháp lý của bạn...', label='Câu hỏi'),
|
51 |
+
gr.Slider(minimum=5, maximum=20, value=10, step=1, label='Số lượng kết quả'),
|
52 |
+
],
|
53 |
+
outputs=gr.Textbox(lines=20, label='Kết quả'),
|
54 |
+
title='Vietnamese Legal Document Retrieval System',
|
55 |
+
description='🔍 Nhập câu hỏi pháp lý của bạn bằng tiếng Việt để nhận các đoạn văn bản pháp luật liên quan.',
|
56 |
+
examples=[
|
57 |
+
['Tội xúc phạm danh dự?'],
|
58 |
+
['Quyền lợi của người lao động?'],
|
59 |
+
['Thủ tục đăng ký kết hôn?'],
|
60 |
+
],
|
61 |
+
flagging_mode='never'
|
62 |
+
)
|
63 |
+
|
64 |
+
if __name__ == '__main__':
|
65 |
+
demo.launch(share=True)
|
requirements.txt
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# !conda install pytorch torchvision torchaudio pytorch-cuda=12.1 -c pytorch -c nvidia
|
2 |
+
# !conda install faiss-gpu=1.9.0 -c pytorch -c nvidia
|
3 |
+
|
4 |
+
transformers
|
5 |
+
sentence-transformers
|
6 |
+
accelerate
|
7 |
+
datasets
|
8 |
+
mteb
|
9 |
+
tqdm
|
10 |
+
pandas
|
11 |
+
gradio
|
12 |
+
huggingface-hub
|
results/no_model_name_available/no_revision_available/BKAILegalDocRetrieval.json
ADDED
@@ -0,0 +1,158 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"dataset_revision": "d4c5a8ba10ae71224752c727094ac4c46947fa29",
|
3 |
+
"task_name": "BKAILegalDocRetrieval",
|
4 |
+
"mteb_version": "1.38.1",
|
5 |
+
"scores": {
|
6 |
+
"test": [
|
7 |
+
{
|
8 |
+
"ndcg_at_1": 0.42425,
|
9 |
+
"ndcg_at_3": 0.53937,
|
10 |
+
"ndcg_at_5": 0.57387,
|
11 |
+
"ndcg_at_10": 0.60389,
|
12 |
+
"ndcg_at_20": 0.6216,
|
13 |
+
"ndcg_at_100": 0.63894,
|
14 |
+
"ndcg_at_1000": 0.64436,
|
15 |
+
"map_at_1": 0.40328,
|
16 |
+
"map_at_3": 0.50314,
|
17 |
+
"map_at_5": 0.52297,
|
18 |
+
"map_at_10": 0.53608,
|
19 |
+
"map_at_20": 0.54136,
|
20 |
+
"map_at_100": 0.54418,
|
21 |
+
"map_at_1000": 0.54442,
|
22 |
+
"recall_at_1": 0.40328,
|
23 |
+
"recall_at_3": 0.62323,
|
24 |
+
"recall_at_5": 0.70466,
|
25 |
+
"recall_at_10": 0.79407,
|
26 |
+
"recall_at_20": 0.86112,
|
27 |
+
"recall_at_100": 0.94805,
|
28 |
+
"recall_at_1000": 0.98787,
|
29 |
+
"precision_at_1": 0.42425,
|
30 |
+
"precision_at_3": 0.22147,
|
31 |
+
"precision_at_5": 0.15119,
|
32 |
+
"precision_at_10": 0.08587,
|
33 |
+
"precision_at_20": 0.04687,
|
34 |
+
"precision_at_100": 0.01045,
|
35 |
+
"precision_at_1000": 0.0011,
|
36 |
+
"mrr_at_1": 0.424183,
|
37 |
+
"mrr_at_3": 0.524672,
|
38 |
+
"mrr_at_5": 0.543372,
|
39 |
+
"mrr_at_10": 0.555102,
|
40 |
+
"mrr_at_20": 0.559556,
|
41 |
+
"mrr_at_100": 0.561719,
|
42 |
+
"mrr_at_1000": 0.561878,
|
43 |
+
"nauc_ndcg_at_1_max": 0.252524,
|
44 |
+
"nauc_ndcg_at_1_std": -0.130263,
|
45 |
+
"nauc_ndcg_at_1_diff1": 0.488176,
|
46 |
+
"nauc_ndcg_at_3_max": 0.298482,
|
47 |
+
"nauc_ndcg_at_3_std": -0.120077,
|
48 |
+
"nauc_ndcg_at_3_diff1": 0.423316,
|
49 |
+
"nauc_ndcg_at_5_max": 0.307625,
|
50 |
+
"nauc_ndcg_at_5_std": -0.110964,
|
51 |
+
"nauc_ndcg_at_5_diff1": 0.419743,
|
52 |
+
"nauc_ndcg_at_10_max": 0.312344,
|
53 |
+
"nauc_ndcg_at_10_std": -0.101157,
|
54 |
+
"nauc_ndcg_at_10_diff1": 0.419576,
|
55 |
+
"nauc_ndcg_at_20_max": 0.31366,
|
56 |
+
"nauc_ndcg_at_20_std": -0.093809,
|
57 |
+
"nauc_ndcg_at_20_diff1": 0.423325,
|
58 |
+
"nauc_ndcg_at_100_max": 0.308888,
|
59 |
+
"nauc_ndcg_at_100_std": -0.091458,
|
60 |
+
"nauc_ndcg_at_100_diff1": 0.428327,
|
61 |
+
"nauc_ndcg_at_1000_max": 0.303777,
|
62 |
+
"nauc_ndcg_at_1000_std": -0.098258,
|
63 |
+
"nauc_ndcg_at_1000_diff1": 0.430885,
|
64 |
+
"nauc_map_at_1_max": 0.238748,
|
65 |
+
"nauc_map_at_1_std": -0.133375,
|
66 |
+
"nauc_map_at_1_diff1": 0.476974,
|
67 |
+
"nauc_map_at_3_max": 0.28179,
|
68 |
+
"nauc_map_at_3_std": -0.124789,
|
69 |
+
"nauc_map_at_3_diff1": 0.435363,
|
70 |
+
"nauc_map_at_5_max": 0.286506,
|
71 |
+
"nauc_map_at_5_std": -0.120112,
|
72 |
+
"nauc_map_at_5_diff1": 0.433864,
|
73 |
+
"nauc_map_at_10_max": 0.288218,
|
74 |
+
"nauc_map_at_10_std": -0.116509,
|
75 |
+
"nauc_map_at_10_diff1": 0.434401,
|
76 |
+
"nauc_map_at_20_max": 0.288517,
|
77 |
+
"nauc_map_at_20_std": -0.114629,
|
78 |
+
"nauc_map_at_20_diff1": 0.435545,
|
79 |
+
"nauc_map_at_100_max": 0.287963,
|
80 |
+
"nauc_map_at_100_std": -0.114142,
|
81 |
+
"nauc_map_at_100_diff1": 0.436281,
|
82 |
+
"nauc_map_at_1000_max": 0.287808,
|
83 |
+
"nauc_map_at_1000_std": -0.114315,
|
84 |
+
"nauc_map_at_1000_diff1": 0.436377,
|
85 |
+
"nauc_recall_at_1_max": 0.238748,
|
86 |
+
"nauc_recall_at_1_std": -0.133375,
|
87 |
+
"nauc_recall_at_1_diff1": 0.476974,
|
88 |
+
"nauc_recall_at_3_max": 0.330773,
|
89 |
+
"nauc_recall_at_3_std": -0.107907,
|
90 |
+
"nauc_recall_at_3_diff1": 0.366506,
|
91 |
+
"nauc_recall_at_5_max": 0.36187,
|
92 |
+
"nauc_recall_at_5_std": -0.080013,
|
93 |
+
"nauc_recall_at_5_diff1": 0.345161,
|
94 |
+
"nauc_recall_at_10_max": 0.399711,
|
95 |
+
"nauc_recall_at_10_std": -0.026693,
|
96 |
+
"nauc_recall_at_10_diff1": 0.318554,
|
97 |
+
"nauc_recall_at_20_max": 0.445634,
|
98 |
+
"nauc_recall_at_20_std": 0.057536,
|
99 |
+
"nauc_recall_at_20_diff1": 0.30652,
|
100 |
+
"nauc_recall_at_100_max": 0.544189,
|
101 |
+
"nauc_recall_at_100_std": 0.325327,
|
102 |
+
"nauc_recall_at_100_diff1": 0.272927,
|
103 |
+
"nauc_recall_at_1000_max": 0.578666,
|
104 |
+
"nauc_recall_at_1000_std": 0.566039,
|
105 |
+
"nauc_recall_at_1000_diff1": 0.23906,
|
106 |
+
"nauc_precision_at_1_max": 0.252524,
|
107 |
+
"nauc_precision_at_1_std": -0.130263,
|
108 |
+
"nauc_precision_at_1_diff1": 0.488176,
|
109 |
+
"nauc_precision_at_3_max": 0.343321,
|
110 |
+
"nauc_precision_at_3_std": -0.090953,
|
111 |
+
"nauc_precision_at_3_diff1": 0.354789,
|
112 |
+
"nauc_precision_at_5_max": 0.356368,
|
113 |
+
"nauc_precision_at_5_std": -0.05169,
|
114 |
+
"nauc_precision_at_5_diff1": 0.308044,
|
115 |
+
"nauc_precision_at_10_max": 0.338907,
|
116 |
+
"nauc_precision_at_10_std": 0.01503,
|
117 |
+
"nauc_precision_at_10_diff1": 0.230763,
|
118 |
+
"nauc_precision_at_20_max": 0.299075,
|
119 |
+
"nauc_precision_at_20_std": 0.08907,
|
120 |
+
"nauc_precision_at_20_diff1": 0.154507,
|
121 |
+
"nauc_precision_at_100_max": 0.148044,
|
122 |
+
"nauc_precision_at_100_std": 0.170043,
|
123 |
+
"nauc_precision_at_100_diff1": 0.008958,
|
124 |
+
"nauc_precision_at_1000_max": 0.011265,
|
125 |
+
"nauc_precision_at_1000_std": 0.110291,
|
126 |
+
"nauc_precision_at_1000_diff1": -0.064328,
|
127 |
+
"nauc_mrr_at_1_max": 0.252492,
|
128 |
+
"nauc_mrr_at_1_std": -0.130181,
|
129 |
+
"nauc_mrr_at_1_diff1": 0.488352,
|
130 |
+
"nauc_mrr_at_3_max": 0.295039,
|
131 |
+
"nauc_mrr_at_3_std": -0.119392,
|
132 |
+
"nauc_mrr_at_3_diff1": 0.445442,
|
133 |
+
"nauc_mrr_at_5_max": 0.298742,
|
134 |
+
"nauc_mrr_at_5_std": -0.114751,
|
135 |
+
"nauc_mrr_at_5_diff1": 0.444417,
|
136 |
+
"nauc_mrr_at_10_max": 0.299598,
|
137 |
+
"nauc_mrr_at_10_std": -0.111479,
|
138 |
+
"nauc_mrr_at_10_diff1": 0.444763,
|
139 |
+
"nauc_mrr_at_20_max": 0.299328,
|
140 |
+
"nauc_mrr_at_20_std": -0.110211,
|
141 |
+
"nauc_mrr_at_20_diff1": 0.44605,
|
142 |
+
"nauc_mrr_at_100_max": 0.298458,
|
143 |
+
"nauc_mrr_at_100_std": -0.110232,
|
144 |
+
"nauc_mrr_at_100_diff1": 0.446632,
|
145 |
+
"nauc_mrr_at_1000_max": 0.298311,
|
146 |
+
"nauc_mrr_at_1000_std": -0.110415,
|
147 |
+
"nauc_mrr_at_1000_diff1": 0.446697,
|
148 |
+
"main_score": 0.60389,
|
149 |
+
"hf_subset": "default",
|
150 |
+
"languages": [
|
151 |
+
"vi"
|
152 |
+
]
|
153 |
+
}
|
154 |
+
]
|
155 |
+
},
|
156 |
+
"evaluation_time": 3061.7869832515717,
|
157 |
+
"kg_co2_emissions": null
|
158 |
+
}
|
results/no_model_name_available/no_revision_available/model_meta.json
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
{"name": "no_model_name_available", "revision": "no_revision_available", "release_date": null, "languages": [], "n_parameters": null, "memory_usage_mb": null, "max_tokens": null, "embed_dim": null, "license": null, "open_weights": true, "public_training_code": null, "public_training_data": null, "framework": ["Sentence Transformers"], "reference": null, "similarity_fn_name": "cosine", "use_instructions": null, "training_datasets": null, "adapted_from": null, "superseded_by": null, "is_cross_encoder": null, "modalities": ["text"], "loader": null}
|
settings.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
# Data settings
|
8 |
+
os.makedirs('data', exist_ok=True)
|
9 |
+
os.makedirs('data/original', exist_ok=True)
|
10 |
+
os.makedirs('data/processed', exist_ok=True)
|
11 |
+
os.makedirs('data/retrieval', exist_ok=True)
|
12 |
+
|
13 |
+
|
14 |
+
# Model settings
|
15 |
+
MODEL_ID = 'google-bert/bert-base-multilingual-cased'
|
16 |
+
MODEL_NAME = 'VN-legalDocs-SBERT'
|
17 |
+
|
18 |
+
CACHE_DIR = f"cache/{MODEL_NAME}"
|
19 |
+
OUTPUT_DIR = f"models/{MODEL_NAME}"
|
20 |
+
|
21 |
+
os.makedirs(CACHE_DIR, exist_ok=True)
|
22 |
+
os.makedirs(OUTPUT_DIR, exist_ok=True)
|
23 |
+
|
24 |
+
|
25 |
+
# Reproducibility
|
26 |
+
SEED = 42
|
27 |
+
random.seed(SEED)
|
28 |
+
np.random.seed(SEED)
|
29 |
+
torch.manual_seed(SEED)
|
30 |
+
torch.cuda.manual_seed_all(SEED)
|
31 |
+
|
32 |
+
# Reproducibility: deterministic=True, benchmark=False
|
33 |
+
# Optimize inference/training speed: deterministic=False, benchmark=True
|
34 |
+
torch.backends.cudnn.deterministic = False
|
35 |
+
torch.backends.cudnn.benchmark = True
|
36 |
+
|
37 |
+
|
38 |
+
# Hyperparameters
|
39 |
+
MAX_SEQ_LEN = 512
|
40 |
+
EPOCHS = 5
|
41 |
+
LR = 3e-5
|
42 |
+
BATCH_SIZE = 128
|
43 |
+
DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
|
44 |
+
print(f"Using device: {DEVICE}")
|
step_01_Prepare_Data.ipynb
ADDED
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 3,
|
6 |
+
"id": "29a91458",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"Using device: cuda\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"!python settings.py"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 4,
|
24 |
+
"id": "97c0ec5c",
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"import os\n",
|
29 |
+
"import zipfile\n",
|
30 |
+
"import requests\n",
|
31 |
+
"import pandas as pd"
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"cell_type": "code",
|
36 |
+
"execution_count": 5,
|
37 |
+
"id": "f7b1ed51",
|
38 |
+
"metadata": {},
|
39 |
+
"outputs": [],
|
40 |
+
"source": [
|
41 |
+
"# Download the dataset\n",
|
42 |
+
"url = 'https://huggingface.co/datasets/tmnam20/BKAI-Legal-Retrieval/resolve/main/archive.zip'\n",
|
43 |
+
"zip_path = 'data/original/archive.zip'\n",
|
44 |
+
"\n",
|
45 |
+
"response = requests.get(url)\n",
|
46 |
+
"with open(zip_path, 'wb') as f:\n",
|
47 |
+
" f.write(response.content)\n",
|
48 |
+
"\n",
|
49 |
+
"with zipfile.ZipFile(zip_path, 'r') as zip_ref:\n",
|
50 |
+
" zip_ref.extractall('data/original')\n",
|
51 |
+
" \n",
|
52 |
+
"os.remove(zip_path)"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "code",
|
57 |
+
"execution_count": 6,
|
58 |
+
"id": "4fe0c4f8",
|
59 |
+
"metadata": {},
|
60 |
+
"outputs": [
|
61 |
+
{
|
62 |
+
"name": "stdout",
|
63 |
+
"output_type": "stream",
|
64 |
+
"text": [
|
65 |
+
"Train split data: 89592\n",
|
66 |
+
"Test split data : 29864\n"
|
67 |
+
]
|
68 |
+
}
|
69 |
+
],
|
70 |
+
"source": [
|
71 |
+
"corpus_data = pd.read_csv('data/original/corpus.csv')\n",
|
72 |
+
"train_split = pd.read_csv('data/original/train_split.csv')\n",
|
73 |
+
"test_split = pd.read_csv('data/original/val_split.csv')\n",
|
74 |
+
"\n",
|
75 |
+
"print(f\"Train split data: {len(train_split)}\")\n",
|
76 |
+
"print(f\"Test split data : {len(test_split)}\")"
|
77 |
+
]
|
78 |
+
},
|
79 |
+
{
|
80 |
+
"cell_type": "code",
|
81 |
+
"execution_count": 7,
|
82 |
+
"id": "6e3fbd6e",
|
83 |
+
"metadata": {},
|
84 |
+
"outputs": [
|
85 |
+
{
|
86 |
+
"data": {
|
87 |
+
"text/html": [
|
88 |
+
"<div>\n",
|
89 |
+
"<style scoped>\n",
|
90 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
91 |
+
" vertical-align: middle;\n",
|
92 |
+
" }\n",
|
93 |
+
"\n",
|
94 |
+
" .dataframe tbody tr th {\n",
|
95 |
+
" vertical-align: top;\n",
|
96 |
+
" }\n",
|
97 |
+
"\n",
|
98 |
+
" .dataframe thead th {\n",
|
99 |
+
" text-align: right;\n",
|
100 |
+
" }\n",
|
101 |
+
"</style>\n",
|
102 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
103 |
+
" <thead>\n",
|
104 |
+
" <tr style=\"text-align: right;\">\n",
|
105 |
+
" <th></th>\n",
|
106 |
+
" <th>text</th>\n",
|
107 |
+
" <th>cid</th>\n",
|
108 |
+
" </tr>\n",
|
109 |
+
" </thead>\n",
|
110 |
+
" <tbody>\n",
|
111 |
+
" <tr>\n",
|
112 |
+
" <th>0</th>\n",
|
113 |
+
" <td>Thông tư này hướng dẫn tuần tra, canh gác bảo ...</td>\n",
|
114 |
+
" <td>0</td>\n",
|
115 |
+
" </tr>\n",
|
116 |
+
" <tr>\n",
|
117 |
+
" <th>1</th>\n",
|
118 |
+
" <td>1. Hàng năm trước mùa mưa, lũ, Ủy ban nhân dân...</td>\n",
|
119 |
+
" <td>1</td>\n",
|
120 |
+
" </tr>\n",
|
121 |
+
" <tr>\n",
|
122 |
+
" <th>2</th>\n",
|
123 |
+
" <td>Tiêu chuẩn của các thành viên thuộc lực lượng ...</td>\n",
|
124 |
+
" <td>2</td>\n",
|
125 |
+
" </tr>\n",
|
126 |
+
" <tr>\n",
|
127 |
+
" <th>3</th>\n",
|
128 |
+
" <td>Nhiệm vụ của lực lượng tuần tra, canh gác đê\\n...</td>\n",
|
129 |
+
" <td>3</td>\n",
|
130 |
+
" </tr>\n",
|
131 |
+
" <tr>\n",
|
132 |
+
" <th>4</th>\n",
|
133 |
+
" <td>Phù hiệu của lực lượng tuần tra, canh gác đê\\n...</td>\n",
|
134 |
+
" <td>4</td>\n",
|
135 |
+
" </tr>\n",
|
136 |
+
" </tbody>\n",
|
137 |
+
"</table>\n",
|
138 |
+
"</div>"
|
139 |
+
],
|
140 |
+
"text/plain": [
|
141 |
+
" text cid\n",
|
142 |
+
"0 Thông tư này hướng dẫn tuần tra, canh gác bảo ... 0\n",
|
143 |
+
"1 1. Hàng năm trước mùa mưa, lũ, Ủy ban nhân dân... 1\n",
|
144 |
+
"2 Tiêu chuẩn của các thành viên thuộc lực lượng ... 2\n",
|
145 |
+
"3 Nhiệm vụ của lực lượng tuần tra, canh gác đê\\n... 3\n",
|
146 |
+
"4 Phù hiệu của lực lượng tuần tra, canh gác đê\\n... 4"
|
147 |
+
]
|
148 |
+
},
|
149 |
+
"execution_count": 7,
|
150 |
+
"metadata": {},
|
151 |
+
"output_type": "execute_result"
|
152 |
+
}
|
153 |
+
],
|
154 |
+
"source": [
|
155 |
+
"corpus_data.head()"
|
156 |
+
]
|
157 |
+
},
|
158 |
+
{
|
159 |
+
"cell_type": "code",
|
160 |
+
"execution_count": 8,
|
161 |
+
"id": "3d32d13a",
|
162 |
+
"metadata": {},
|
163 |
+
"outputs": [
|
164 |
+
{
|
165 |
+
"data": {
|
166 |
+
"text/html": [
|
167 |
+
"<div>\n",
|
168 |
+
"<style scoped>\n",
|
169 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
170 |
+
" vertical-align: middle;\n",
|
171 |
+
" }\n",
|
172 |
+
"\n",
|
173 |
+
" .dataframe tbody tr th {\n",
|
174 |
+
" vertical-align: top;\n",
|
175 |
+
" }\n",
|
176 |
+
"\n",
|
177 |
+
" .dataframe thead th {\n",
|
178 |
+
" text-align: right;\n",
|
179 |
+
" }\n",
|
180 |
+
"</style>\n",
|
181 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
182 |
+
" <thead>\n",
|
183 |
+
" <tr style=\"text-align: right;\">\n",
|
184 |
+
" <th></th>\n",
|
185 |
+
" <th>question</th>\n",
|
186 |
+
" <th>context</th>\n",
|
187 |
+
" <th>cid</th>\n",
|
188 |
+
" <th>qid</th>\n",
|
189 |
+
" <th>context_list</th>\n",
|
190 |
+
" </tr>\n",
|
191 |
+
" </thead>\n",
|
192 |
+
" <tbody>\n",
|
193 |
+
" <tr>\n",
|
194 |
+
" <th>0</th>\n",
|
195 |
+
" <td>Liên đoàn Luật sư Việt Nam là tổ chức xã hội –...</td>\n",
|
196 |
+
" <td>['“Điều 2. Địa vị pháp lý của Liên đoàn Luật s...</td>\n",
|
197 |
+
" <td>[142820]</td>\n",
|
198 |
+
" <td>72600</td>\n",
|
199 |
+
" <td>[“Điều 2. Địa vị pháp lý của Liên đoàn Luật sư...</td>\n",
|
200 |
+
" </tr>\n",
|
201 |
+
" <tr>\n",
|
202 |
+
" <th>1</th>\n",
|
203 |
+
" <td>Tên hợp tác xã bị rơi vào trường hợp cấm thì c...</td>\n",
|
204 |
+
" <td>['Tên hợp tác xã, liên hiệp hợp tác xã\\n1. Tên...</td>\n",
|
205 |
+
" <td>[27817, 72117]</td>\n",
|
206 |
+
" <td>147562</td>\n",
|
207 |
+
" <td>[\"Điều 7. Tên hợp tác xã, liên hiệp hợp tác xã...</td>\n",
|
208 |
+
" </tr>\n",
|
209 |
+
" <tr>\n",
|
210 |
+
" <th>2</th>\n",
|
211 |
+
" <td>Tài xế lái xe ô tô khách 50 chỗ ngồi bao lâu t...</td>\n",
|
212 |
+
" <td>['\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiê...</td>\n",
|
213 |
+
" <td>[33215, 56201]</td>\n",
|
214 |
+
" <td>142107</td>\n",
|
215 |
+
" <td>[\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiêu...</td>\n",
|
216 |
+
" </tr>\n",
|
217 |
+
" <tr>\n",
|
218 |
+
" <th>3</th>\n",
|
219 |
+
" <td>Các bước chuẩn bị thủ thuật bó bột Cravate sẽ ...</td>\n",
|
220 |
+
" <td>['BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thự...</td>\n",
|
221 |
+
" <td>[148158]</td>\n",
|
222 |
+
" <td>77353</td>\n",
|
223 |
+
" <td>[BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thực...</td>\n",
|
224 |
+
" </tr>\n",
|
225 |
+
" <tr>\n",
|
226 |
+
" <th>4</th>\n",
|
227 |
+
" <td>Viên chức Hộ sinh hạng 4 có những nhiệm vụ gì ...</td>\n",
|
228 |
+
" <td>['Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệ...</td>\n",
|
229 |
+
" <td>[188132]</td>\n",
|
230 |
+
" <td>113090</td>\n",
|
231 |
+
" <td>[Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệm...</td>\n",
|
232 |
+
" </tr>\n",
|
233 |
+
" </tbody>\n",
|
234 |
+
"</table>\n",
|
235 |
+
"</div>"
|
236 |
+
],
|
237 |
+
"text/plain": [
|
238 |
+
" question \\\n",
|
239 |
+
"0 Liên đoàn Luật sư Việt Nam là tổ chức xã hội –... \n",
|
240 |
+
"1 Tên hợp tác xã bị rơi vào trường hợp cấm thì c... \n",
|
241 |
+
"2 Tài xế lái xe ô tô khách 50 chỗ ngồi bao lâu t... \n",
|
242 |
+
"3 Các bước chuẩn bị thủ thuật bó bột Cravate sẽ ... \n",
|
243 |
+
"4 Viên chức Hộ sinh hạng 4 có những nhiệm vụ gì ... \n",
|
244 |
+
"\n",
|
245 |
+
" context cid qid \\\n",
|
246 |
+
"0 ['“Điều 2. Địa vị pháp lý của Liên đoàn Luật s... [142820] 72600 \n",
|
247 |
+
"1 ['Tên hợp tác xã, liên hiệp hợp tác xã\\n1. Tên... [27817, 72117] 147562 \n",
|
248 |
+
"2 ['\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiê... [33215, 56201] 142107 \n",
|
249 |
+
"3 ['BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thự... [148158] 77353 \n",
|
250 |
+
"4 ['Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệ... [188132] 113090 \n",
|
251 |
+
"\n",
|
252 |
+
" context_list \n",
|
253 |
+
"0 [“Điều 2. Địa vị pháp lý của Liên đoàn Luật sư... \n",
|
254 |
+
"1 [\"Điều 7. Tên hợp tác xã, liên hiệp hợp tác xã... \n",
|
255 |
+
"2 [\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiêu... \n",
|
256 |
+
"3 [BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thực... \n",
|
257 |
+
"4 [Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệm... "
|
258 |
+
]
|
259 |
+
},
|
260 |
+
"execution_count": 8,
|
261 |
+
"metadata": {},
|
262 |
+
"output_type": "execute_result"
|
263 |
+
}
|
264 |
+
],
|
265 |
+
"source": [
|
266 |
+
"# 'cid' column: '[1 2 3]'\n",
|
267 |
+
"train_split['cid'] = train_split['cid'].apply(lambda x: [int(i) for i in x[1:-1].split()])\n",
|
268 |
+
"test_split['cid'] = test_split['cid'].apply(lambda x: [int(i) for i in x[1:-1].split()])\n",
|
269 |
+
"\n",
|
270 |
+
"\n",
|
271 |
+
"# Mapping from corpus \n",
|
272 |
+
"mapping = dict(zip(corpus_data['cid'], corpus_data['text']))\n",
|
273 |
+
"\n",
|
274 |
+
"def get_context_list(cid_list):\n",
|
275 |
+
" return [mapping[cid] for cid in cid_list if cid in mapping]\n",
|
276 |
+
"\n",
|
277 |
+
"train_split['context_list'] = train_split['cid'].apply(get_context_list)\n",
|
278 |
+
"test_split['context_list'] = test_split['cid'].apply(get_context_list)\n",
|
279 |
+
"\n",
|
280 |
+
"train_split.head()"
|
281 |
+
]
|
282 |
+
},
|
283 |
+
{
|
284 |
+
"cell_type": "code",
|
285 |
+
"execution_count": 9,
|
286 |
+
"id": "e0450414",
|
287 |
+
"metadata": {},
|
288 |
+
"outputs": [
|
289 |
+
{
|
290 |
+
"name": "stdout",
|
291 |
+
"output_type": "stream",
|
292 |
+
"text": [
|
293 |
+
"430 99 331\n",
|
294 |
+
"question <class 'str'>\n",
|
295 |
+
"context <class 'str'>\n",
|
296 |
+
"cid <class 'list'>\n",
|
297 |
+
"qid <class 'numpy.int64'>\n",
|
298 |
+
"context_list <class 'list'>\n"
|
299 |
+
]
|
300 |
+
}
|
301 |
+
],
|
302 |
+
"source": [
|
303 |
+
"# Debug\n",
|
304 |
+
"print(\n",
|
305 |
+
" len(train_split[train_split['context_list'].apply(len) != train_split['cid'].apply(len)]),\n",
|
306 |
+
" \n",
|
307 |
+
" len(\n",
|
308 |
+
" train_split[\n",
|
309 |
+
" (train_split['context_list'].apply(len) != train_split['cid'].apply(len)) &\n",
|
310 |
+
" (train_split['context_list'].apply(len) != 0)\n",
|
311 |
+
" ]\n",
|
312 |
+
" ),\n",
|
313 |
+
" \n",
|
314 |
+
" len(\n",
|
315 |
+
" train_split[\n",
|
316 |
+
" (train_split['context_list'].apply(len) != train_split['cid'].apply(len)) &\n",
|
317 |
+
" (train_split['context_list'].apply(len) == 0)\n",
|
318 |
+
" ]\n",
|
319 |
+
" )\n",
|
320 |
+
")\n",
|
321 |
+
"\n",
|
322 |
+
"for col in train_split.columns:\n",
|
323 |
+
" print(col, type(train_split[col][0]))"
|
324 |
+
]
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"cell_type": "code",
|
328 |
+
"execution_count": 10,
|
329 |
+
"id": "fd1eb4a2",
|
330 |
+
"metadata": {},
|
331 |
+
"outputs": [],
|
332 |
+
"source": [
|
333 |
+
"# Drop invalid data\n",
|
334 |
+
"train_data = train_split.loc[\n",
|
335 |
+
" ~(train_split['context_list'].apply(len) != train_split['cid'].apply(len)), \n",
|
336 |
+
" ['question', 'context_list', 'qid', 'cid']\n",
|
337 |
+
"]\n",
|
338 |
+
"\n",
|
339 |
+
"test_data = test_split.loc[\n",
|
340 |
+
" ~(test_split['context_list'].apply(len) != test_split['cid'].apply(len)), \n",
|
341 |
+
" ['question', 'context_list', 'qid', 'cid']\n",
|
342 |
+
"]"
|
343 |
+
]
|
344 |
+
},
|
345 |
+
{
|
346 |
+
"cell_type": "code",
|
347 |
+
"execution_count": 11,
|
348 |
+
"id": "3661c9cb",
|
349 |
+
"metadata": {},
|
350 |
+
"outputs": [
|
351 |
+
{
|
352 |
+
"name": "stdout",
|
353 |
+
"output_type": "stream",
|
354 |
+
"text": [
|
355 |
+
"Train data saved: 89162\n",
|
356 |
+
"Test data saved : 29723\n"
|
357 |
+
]
|
358 |
+
}
|
359 |
+
],
|
360 |
+
"source": [
|
361 |
+
"# Save the processed data to parquet files\n",
|
362 |
+
"corpus_data.to_parquet('data/processed/corpus_data.parquet', index=False)\n",
|
363 |
+
"train_data.to_parquet('data/processed/train_data.parquet', index=False)\n",
|
364 |
+
"test_data.to_parquet('data/processed/test_data.parquet', index=False)\n",
|
365 |
+
"\n",
|
366 |
+
"print(f\"Train data saved: {len(train_data)}\")\n",
|
367 |
+
"print(f\"Test data saved : {len(test_data)}\")"
|
368 |
+
]
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"cell_type": "code",
|
372 |
+
"execution_count": 12,
|
373 |
+
"id": "6382a715",
|
374 |
+
"metadata": {},
|
375 |
+
"outputs": [],
|
376 |
+
"source": [
|
377 |
+
"# # Get demo data\n",
|
378 |
+
"# os.makedirs('data/demo', exist_ok=True)\n",
|
379 |
+
"\n",
|
380 |
+
"# demo_corpus_data = corpus_data.sample(10, random_state=42).reset_index(drop=True)\n",
|
381 |
+
"# demo_train_data = train_data.sample(10, random_state=42).reset_index(drop=True)\n",
|
382 |
+
"# demo_test_data = test_data.sample(10, random_state=42).reset_index(drop=True)\n",
|
383 |
+
"\n",
|
384 |
+
"# demo_corpus_data.to_csv('data/demo/demo_corpus_data.csv', index=False)\n",
|
385 |
+
"# demo_train_data.to_csv('data/demo/demo_train_data.csv', index=False)\n",
|
386 |
+
"# demo_test_data.to_csv('data/demo/demo_test_data.csv', index=False)"
|
387 |
+
]
|
388 |
+
}
|
389 |
+
],
|
390 |
+
"metadata": {
|
391 |
+
"kernelspec": {
|
392 |
+
"display_name": "legal_doc_retrieval",
|
393 |
+
"language": "python",
|
394 |
+
"name": "python3"
|
395 |
+
},
|
396 |
+
"language_info": {
|
397 |
+
"codemirror_mode": {
|
398 |
+
"name": "ipython",
|
399 |
+
"version": 3
|
400 |
+
},
|
401 |
+
"file_extension": ".py",
|
402 |
+
"mimetype": "text/x-python",
|
403 |
+
"name": "python",
|
404 |
+
"nbconvert_exporter": "python",
|
405 |
+
"pygments_lexer": "ipython3",
|
406 |
+
"version": "3.10.16"
|
407 |
+
}
|
408 |
+
},
|
409 |
+
"nbformat": 4,
|
410 |
+
"nbformat_minor": 5
|
411 |
+
}
|
step_02_Finetune_SBERT.ipynb
ADDED
@@ -0,0 +1,580 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "24106202",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"Using device: cuda\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"!python settings.py"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": null,
|
24 |
+
"id": "0086aabe",
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [
|
27 |
+
{
|
28 |
+
"name": "stdout",
|
29 |
+
"output_type": "stream",
|
30 |
+
"text": [
|
31 |
+
"Using device: cuda\n"
|
32 |
+
]
|
33 |
+
}
|
34 |
+
],
|
35 |
+
"source": [
|
36 |
+
"import os\n",
|
37 |
+
"import pandas as pd\n",
|
38 |
+
"from datasets import Dataset\n",
|
39 |
+
"from tqdm.autonotebook import tqdm\n",
|
40 |
+
"\n",
|
41 |
+
"from sentence_transformers import (\n",
|
42 |
+
" SentenceTransformer,\n",
|
43 |
+
" SentenceTransformerTrainer,\n",
|
44 |
+
" SentenceTransformerTrainingArguments,\n",
|
45 |
+
" SentenceTransformerModelCardData,\n",
|
46 |
+
")\n",
|
47 |
+
"from sentence_transformers.readers import InputExample\n",
|
48 |
+
"from sentence_transformers.models import Transformer, Pooling\n",
|
49 |
+
"from sentence_transformers.losses import CachedMultipleNegativesRankingLoss\n",
|
50 |
+
"from sentence_transformers.training_args import BatchSamplers\n",
|
51 |
+
"\n",
|
52 |
+
"from settings import MODEL_ID, MODEL_NAME, CACHE_DIR, OUTPUT_DIR, MAX_SEQ_LEN, EPOCHS, LR, BATCH_SIZE, DEVICE\n",
|
53 |
+
"\n",
|
54 |
+
"os.environ['WANDB_DISABLED'] = 'true'"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"cell_type": "code",
|
59 |
+
"execution_count": 3,
|
60 |
+
"id": "3a5cc53d",
|
61 |
+
"metadata": {},
|
62 |
+
"outputs": [],
|
63 |
+
"source": [
|
64 |
+
"data = {\n",
|
65 |
+
" 'corpus': pd.read_parquet('data/processed/corpus_data.parquet'),\n",
|
66 |
+
" 'train' : pd.read_parquet('data/processed/train_data.parquet'),\n",
|
67 |
+
" 'test' : pd.read_parquet('data/processed/test_data.parquet')\n",
|
68 |
+
"}\n",
|
69 |
+
"for split in ['train', 'test']:\n",
|
70 |
+
" data[split]['cid'] = data[split]['cid'].apply(lambda x: x.tolist())\n",
|
71 |
+
" data[split]['context_list'] = data[split]['context_list'].apply(lambda x: x.tolist())\n",
|
72 |
+
" \n",
|
73 |
+
"examples = {'train': [], 'test': []}"
|
74 |
+
]
|
75 |
+
},
|
76 |
+
{
|
77 |
+
"cell_type": "code",
|
78 |
+
"execution_count": 4,
|
79 |
+
"id": "30ebbd40",
|
80 |
+
"metadata": {},
|
81 |
+
"outputs": [
|
82 |
+
{
|
83 |
+
"data": {
|
84 |
+
"text/html": [
|
85 |
+
"<div>\n",
|
86 |
+
"<style scoped>\n",
|
87 |
+
" .dataframe tbody tr th:only-of-type {\n",
|
88 |
+
" vertical-align: middle;\n",
|
89 |
+
" }\n",
|
90 |
+
"\n",
|
91 |
+
" .dataframe tbody tr th {\n",
|
92 |
+
" vertical-align: top;\n",
|
93 |
+
" }\n",
|
94 |
+
"\n",
|
95 |
+
" .dataframe thead th {\n",
|
96 |
+
" text-align: right;\n",
|
97 |
+
" }\n",
|
98 |
+
"</style>\n",
|
99 |
+
"<table border=\"1\" class=\"dataframe\">\n",
|
100 |
+
" <thead>\n",
|
101 |
+
" <tr style=\"text-align: right;\">\n",
|
102 |
+
" <th></th>\n",
|
103 |
+
" <th>question</th>\n",
|
104 |
+
" <th>context_list</th>\n",
|
105 |
+
" <th>qid</th>\n",
|
106 |
+
" <th>cid</th>\n",
|
107 |
+
" </tr>\n",
|
108 |
+
" </thead>\n",
|
109 |
+
" <tbody>\n",
|
110 |
+
" <tr>\n",
|
111 |
+
" <th>0</th>\n",
|
112 |
+
" <td>Liên đoàn Luật sư Việt Nam là tổ chức xã hội –...</td>\n",
|
113 |
+
" <td>[“Điều 2. Địa vị pháp lý của Liên đoàn Luật sư...</td>\n",
|
114 |
+
" <td>72600</td>\n",
|
115 |
+
" <td>[142820]</td>\n",
|
116 |
+
" </tr>\n",
|
117 |
+
" <tr>\n",
|
118 |
+
" <th>1</th>\n",
|
119 |
+
" <td>Tên hợp tác xã bị rơi vào trường hợp cấm thì c...</td>\n",
|
120 |
+
" <td>[\"Điều 7. Tên hợp tác xã, liên hiệp hợp tác xã...</td>\n",
|
121 |
+
" <td>147562</td>\n",
|
122 |
+
" <td>[27817, 72117]</td>\n",
|
123 |
+
" </tr>\n",
|
124 |
+
" <tr>\n",
|
125 |
+
" <th>2</th>\n",
|
126 |
+
" <td>Tài xế lái xe ô tô khách 50 chỗ ngồi bao lâu t...</td>\n",
|
127 |
+
" <td>[\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiêu...</td>\n",
|
128 |
+
" <td>142107</td>\n",
|
129 |
+
" <td>[33215, 56201]</td>\n",
|
130 |
+
" </tr>\n",
|
131 |
+
" <tr>\n",
|
132 |
+
" <th>3</th>\n",
|
133 |
+
" <td>Các bước chuẩn bị thủ thuật bó bột Cravate sẽ ...</td>\n",
|
134 |
+
" <td>[BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thực...</td>\n",
|
135 |
+
" <td>77353</td>\n",
|
136 |
+
" <td>[148158]</td>\n",
|
137 |
+
" </tr>\n",
|
138 |
+
" <tr>\n",
|
139 |
+
" <th>4</th>\n",
|
140 |
+
" <td>Viên chức Hộ sinh hạng 4 có những nhiệm vụ gì ...</td>\n",
|
141 |
+
" <td>[Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệm...</td>\n",
|
142 |
+
" <td>113090</td>\n",
|
143 |
+
" <td>[188132]</td>\n",
|
144 |
+
" </tr>\n",
|
145 |
+
" </tbody>\n",
|
146 |
+
"</table>\n",
|
147 |
+
"</div>"
|
148 |
+
],
|
149 |
+
"text/plain": [
|
150 |
+
" question \\\n",
|
151 |
+
"0 Liên đoàn Luật sư Việt Nam là tổ chức xã hội –... \n",
|
152 |
+
"1 Tên hợp tác xã bị rơi vào trường hợp cấm thì c... \n",
|
153 |
+
"2 Tài xế lái xe ô tô khách 50 chỗ ngồi bao lâu t... \n",
|
154 |
+
"3 Các bước chuẩn bị thủ thuật bó bột Cravate sẽ ... \n",
|
155 |
+
"4 Viên chức Hộ sinh hạng 4 có những nhiệm vụ gì ... \n",
|
156 |
+
"\n",
|
157 |
+
" context_list qid cid \n",
|
158 |
+
"0 [“Điều 2. Địa vị pháp lý của Liên đoàn Luật sư... 72600 [142820] \n",
|
159 |
+
"1 [\"Điều 7. Tên hợp tác xã, liên hiệp hợp tác xã... 147562 [27817, 72117] \n",
|
160 |
+
"2 [\"1. Sử dụng lái xe bảo đảm sức khỏe theo tiêu... 142107 [33215, 56201] \n",
|
161 |
+
"3 [BỘT CRAVATE\\n...\\nIV. CHUẨN BỊ\\n1. Người thực... 77353 [148158] \n",
|
162 |
+
"4 [Hộ sinh hạng IV - Mã số: V.08.06.16\\n1. Nhiệm... 113090 [188132] "
|
163 |
+
]
|
164 |
+
},
|
165 |
+
"execution_count": 4,
|
166 |
+
"metadata": {},
|
167 |
+
"output_type": "execute_result"
|
168 |
+
}
|
169 |
+
],
|
170 |
+
"source": [
|
171 |
+
"data['train'].head()"
|
172 |
+
]
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"cell_type": "code",
|
176 |
+
"execution_count": 5,
|
177 |
+
"id": "943bf8ce",
|
178 |
+
"metadata": {},
|
179 |
+
"outputs": [
|
180 |
+
{
|
181 |
+
"name": "stdout",
|
182 |
+
"output_type": "stream",
|
183 |
+
"text": [
|
184 |
+
"question <class 'str'>\n",
|
185 |
+
"context_list <class 'list'>\n",
|
186 |
+
"qid <class 'numpy.int64'>\n",
|
187 |
+
"cid <class 'list'>\n",
|
188 |
+
"True\n"
|
189 |
+
]
|
190 |
+
}
|
191 |
+
],
|
192 |
+
"source": [
|
193 |
+
"# Debug\n",
|
194 |
+
"for col in data['test'].columns:\n",
|
195 |
+
" print(col, type(data['test'][col][0]))\n",
|
196 |
+
" \n",
|
197 |
+
"print((data['test']['cid'].apply(len) == data['test']['context_list'].apply(len)).all())"
|
198 |
+
]
|
199 |
+
},
|
200 |
+
{
|
201 |
+
"cell_type": "code",
|
202 |
+
"execution_count": 6,
|
203 |
+
"id": "2c751cf4",
|
204 |
+
"metadata": {},
|
205 |
+
"outputs": [
|
206 |
+
{
|
207 |
+
"data": {
|
208 |
+
"application/vnd.jupyter.widget-view+json": {
|
209 |
+
"model_id": "509893cf5cfd4a8d9e18bba47561a41c",
|
210 |
+
"version_major": 2,
|
211 |
+
"version_minor": 0
|
212 |
+
},
|
213 |
+
"text/plain": [
|
214 |
+
"Processing train: 0%| | 0/89162 [00:00<?, ?rows/s]"
|
215 |
+
]
|
216 |
+
},
|
217 |
+
"metadata": {},
|
218 |
+
"output_type": "display_data"
|
219 |
+
},
|
220 |
+
{
|
221 |
+
"data": {
|
222 |
+
"application/vnd.jupyter.widget-view+json": {
|
223 |
+
"model_id": "12f4fcee4e4244128d8fb472881862ae",
|
224 |
+
"version_major": 2,
|
225 |
+
"version_minor": 0
|
226 |
+
},
|
227 |
+
"text/plain": [
|
228 |
+
"Processing test: 0%| | 0/29723 [00:00<?, ?rows/s]"
|
229 |
+
]
|
230 |
+
},
|
231 |
+
"metadata": {},
|
232 |
+
"output_type": "display_data"
|
233 |
+
},
|
234 |
+
{
|
235 |
+
"name": "stdout",
|
236 |
+
"output_type": "stream",
|
237 |
+
"text": [
|
238 |
+
"Training examples: 99580\n"
|
239 |
+
]
|
240 |
+
}
|
241 |
+
],
|
242 |
+
"source": [
|
243 |
+
"for split in ['train', 'test']:\n",
|
244 |
+
" rows = list(data[split].itertuples(index=False))\n",
|
245 |
+
" \n",
|
246 |
+
" for row in tqdm(rows, desc=f\"Processing {split}\", unit='rows'):\n",
|
247 |
+
" q = row.question\n",
|
248 |
+
" for c in row.context_list:\n",
|
249 |
+
" examples[split].append(InputExample(texts=[q, c]))\n",
|
250 |
+
"\n",
|
251 |
+
"print(f\"Training examples: {len(examples['train'])}\") # Compare with sum(data['train']['cid'].apply(len))"
|
252 |
+
]
|
253 |
+
},
|
254 |
+
{
|
255 |
+
"cell_type": "code",
|
256 |
+
"execution_count": 7,
|
257 |
+
"id": "aadda6e7",
|
258 |
+
"metadata": {},
|
259 |
+
"outputs": [],
|
260 |
+
"source": [
|
261 |
+
"embedding_model = Transformer(MODEL_ID, max_seq_length=MAX_SEQ_LEN, cache_dir=CACHE_DIR)\n",
|
262 |
+
"pooling_model = Pooling(\n",
|
263 |
+
" embedding_model.get_word_embedding_dimension(), \n",
|
264 |
+
" pooling_mode_mean_tokens=True\n",
|
265 |
+
")\n",
|
266 |
+
"\n",
|
267 |
+
"model = SentenceTransformer(\n",
|
268 |
+
" modules=[embedding_model, pooling_model], device=DEVICE, \n",
|
269 |
+
" cache_folder=CACHE_DIR,\n",
|
270 |
+
" model_card_data=SentenceTransformerModelCardData(\n",
|
271 |
+
" model_id=MODEL_ID, \n",
|
272 |
+
" model_name=MODEL_NAME, \n",
|
273 |
+
" language='vi',\n",
|
274 |
+
" license='mit',\n",
|
275 |
+
" )\n",
|
276 |
+
")"
|
277 |
+
]
|
278 |
+
},
|
279 |
+
{
|
280 |
+
"cell_type": "code",
|
281 |
+
"execution_count": null,
|
282 |
+
"id": "8967eb55",
|
283 |
+
"metadata": {},
|
284 |
+
"outputs": [
|
285 |
+
{
|
286 |
+
"name": "stderr",
|
287 |
+
"output_type": "stream",
|
288 |
+
"text": [
|
289 |
+
"Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n"
|
290 |
+
]
|
291 |
+
}
|
292 |
+
],
|
293 |
+
"source": [
|
294 |
+
"loss = CachedMultipleNegativesRankingLoss(model=model)\n",
|
295 |
+
"\n",
|
296 |
+
"args = SentenceTransformerTrainingArguments(\n",
|
297 |
+
" output_dir=OUTPUT_DIR,\n",
|
298 |
+
" num_train_epochs=EPOCHS,\n",
|
299 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
300 |
+
" learning_rate=LR,\n",
|
301 |
+
" warmup_ratio=0.1,\n",
|
302 |
+
" fp16=True,\n",
|
303 |
+
" batch_sampler=BatchSamplers.NO_DUPLICATES,\n",
|
304 |
+
" logging_steps=100\n",
|
305 |
+
")"
|
306 |
+
]
|
307 |
+
},
|
308 |
+
{
|
309 |
+
"cell_type": "code",
|
310 |
+
"execution_count": 9,
|
311 |
+
"id": "8bb935fe",
|
312 |
+
"metadata": {},
|
313 |
+
"outputs": [
|
314 |
+
{
|
315 |
+
"name": "stderr",
|
316 |
+
"output_type": "stream",
|
317 |
+
"text": [
|
318 |
+
"Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).\n"
|
319 |
+
]
|
320 |
+
},
|
321 |
+
{
|
322 |
+
"data": {
|
323 |
+
"application/vnd.jupyter.widget-view+json": {
|
324 |
+
"model_id": "3d68dc4ff84244488d9de723e68b37ca",
|
325 |
+
"version_major": 2,
|
326 |
+
"version_minor": 0
|
327 |
+
},
|
328 |
+
"text/plain": [
|
329 |
+
"Computing widget examples: 0%| | 0/1 [00:00<?, ?example/s]"
|
330 |
+
]
|
331 |
+
},
|
332 |
+
"metadata": {},
|
333 |
+
"output_type": "display_data"
|
334 |
+
},
|
335 |
+
{
|
336 |
+
"data": {
|
337 |
+
"text/html": [
|
338 |
+
"\n",
|
339 |
+
" <div>\n",
|
340 |
+
" \n",
|
341 |
+
" <progress value='3890' max='3890' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
342 |
+
" [3890/3890 3:32:33, Epoch 5/5]\n",
|
343 |
+
" </div>\n",
|
344 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
345 |
+
" <thead>\n",
|
346 |
+
" <tr style=\"text-align: left;\">\n",
|
347 |
+
" <th>Step</th>\n",
|
348 |
+
" <th>Training Loss</th>\n",
|
349 |
+
" </tr>\n",
|
350 |
+
" </thead>\n",
|
351 |
+
" <tbody>\n",
|
352 |
+
" <tr>\n",
|
353 |
+
" <td>100</td>\n",
|
354 |
+
" <td>1.882700</td>\n",
|
355 |
+
" </tr>\n",
|
356 |
+
" <tr>\n",
|
357 |
+
" <td>200</td>\n",
|
358 |
+
" <td>0.442800</td>\n",
|
359 |
+
" </tr>\n",
|
360 |
+
" <tr>\n",
|
361 |
+
" <td>300</td>\n",
|
362 |
+
" <td>0.356400</td>\n",
|
363 |
+
" </tr>\n",
|
364 |
+
" <tr>\n",
|
365 |
+
" <td>400</td>\n",
|
366 |
+
" <td>0.285600</td>\n",
|
367 |
+
" </tr>\n",
|
368 |
+
" <tr>\n",
|
369 |
+
" <td>500</td>\n",
|
370 |
+
" <td>0.244500</td>\n",
|
371 |
+
" </tr>\n",
|
372 |
+
" <tr>\n",
|
373 |
+
" <td>600</td>\n",
|
374 |
+
" <td>0.224100</td>\n",
|
375 |
+
" </tr>\n",
|
376 |
+
" <tr>\n",
|
377 |
+
" <td>700</td>\n",
|
378 |
+
" <td>0.193800</td>\n",
|
379 |
+
" </tr>\n",
|
380 |
+
" <tr>\n",
|
381 |
+
" <td>800</td>\n",
|
382 |
+
" <td>0.189400</td>\n",
|
383 |
+
" </tr>\n",
|
384 |
+
" <tr>\n",
|
385 |
+
" <td>900</td>\n",
|
386 |
+
" <td>0.143200</td>\n",
|
387 |
+
" </tr>\n",
|
388 |
+
" <tr>\n",
|
389 |
+
" <td>1000</td>\n",
|
390 |
+
" <td>0.143200</td>\n",
|
391 |
+
" </tr>\n",
|
392 |
+
" <tr>\n",
|
393 |
+
" <td>1100</td>\n",
|
394 |
+
" <td>0.134100</td>\n",
|
395 |
+
" </tr>\n",
|
396 |
+
" <tr>\n",
|
397 |
+
" <td>1200</td>\n",
|
398 |
+
" <td>0.131100</td>\n",
|
399 |
+
" </tr>\n",
|
400 |
+
" <tr>\n",
|
401 |
+
" <td>1300</td>\n",
|
402 |
+
" <td>0.124900</td>\n",
|
403 |
+
" </tr>\n",
|
404 |
+
" <tr>\n",
|
405 |
+
" <td>1400</td>\n",
|
406 |
+
" <td>0.122700</td>\n",
|
407 |
+
" </tr>\n",
|
408 |
+
" <tr>\n",
|
409 |
+
" <td>1500</td>\n",
|
410 |
+
" <td>0.124100</td>\n",
|
411 |
+
" </tr>\n",
|
412 |
+
" <tr>\n",
|
413 |
+
" <td>1600</td>\n",
|
414 |
+
" <td>0.102800</td>\n",
|
415 |
+
" </tr>\n",
|
416 |
+
" <tr>\n",
|
417 |
+
" <td>1700</td>\n",
|
418 |
+
" <td>0.085200</td>\n",
|
419 |
+
" </tr>\n",
|
420 |
+
" <tr>\n",
|
421 |
+
" <td>1800</td>\n",
|
422 |
+
" <td>0.085000</td>\n",
|
423 |
+
" </tr>\n",
|
424 |
+
" <tr>\n",
|
425 |
+
" <td>1900</td>\n",
|
426 |
+
" <td>0.082000</td>\n",
|
427 |
+
" </tr>\n",
|
428 |
+
" <tr>\n",
|
429 |
+
" <td>2000</td>\n",
|
430 |
+
" <td>0.080000</td>\n",
|
431 |
+
" </tr>\n",
|
432 |
+
" <tr>\n",
|
433 |
+
" <td>2100</td>\n",
|
434 |
+
" <td>0.082400</td>\n",
|
435 |
+
" </tr>\n",
|
436 |
+
" <tr>\n",
|
437 |
+
" <td>2200</td>\n",
|
438 |
+
" <td>0.080200</td>\n",
|
439 |
+
" </tr>\n",
|
440 |
+
" <tr>\n",
|
441 |
+
" <td>2300</td>\n",
|
442 |
+
" <td>0.082200</td>\n",
|
443 |
+
" </tr>\n",
|
444 |
+
" <tr>\n",
|
445 |
+
" <td>2400</td>\n",
|
446 |
+
" <td>0.063300</td>\n",
|
447 |
+
" </tr>\n",
|
448 |
+
" <tr>\n",
|
449 |
+
" <td>2500</td>\n",
|
450 |
+
" <td>0.061500</td>\n",
|
451 |
+
" </tr>\n",
|
452 |
+
" <tr>\n",
|
453 |
+
" <td>2600</td>\n",
|
454 |
+
" <td>0.061200</td>\n",
|
455 |
+
" </tr>\n",
|
456 |
+
" <tr>\n",
|
457 |
+
" <td>2700</td>\n",
|
458 |
+
" <td>0.058000</td>\n",
|
459 |
+
" </tr>\n",
|
460 |
+
" <tr>\n",
|
461 |
+
" <td>2800</td>\n",
|
462 |
+
" <td>0.056600</td>\n",
|
463 |
+
" </tr>\n",
|
464 |
+
" <tr>\n",
|
465 |
+
" <td>2900</td>\n",
|
466 |
+
" <td>0.052100</td>\n",
|
467 |
+
" </tr>\n",
|
468 |
+
" <tr>\n",
|
469 |
+
" <td>3000</td>\n",
|
470 |
+
" <td>0.054800</td>\n",
|
471 |
+
" </tr>\n",
|
472 |
+
" <tr>\n",
|
473 |
+
" <td>3100</td>\n",
|
474 |
+
" <td>0.054700</td>\n",
|
475 |
+
" </tr>\n",
|
476 |
+
" <tr>\n",
|
477 |
+
" <td>3200</td>\n",
|
478 |
+
" <td>0.047900</td>\n",
|
479 |
+
" </tr>\n",
|
480 |
+
" <tr>\n",
|
481 |
+
" <td>3300</td>\n",
|
482 |
+
" <td>0.044900</td>\n",
|
483 |
+
" </tr>\n",
|
484 |
+
" <tr>\n",
|
485 |
+
" <td>3400</td>\n",
|
486 |
+
" <td>0.044000</td>\n",
|
487 |
+
" </tr>\n",
|
488 |
+
" <tr>\n",
|
489 |
+
" <td>3500</td>\n",
|
490 |
+
" <td>0.043900</td>\n",
|
491 |
+
" </tr>\n",
|
492 |
+
" <tr>\n",
|
493 |
+
" <td>3600</td>\n",
|
494 |
+
" <td>0.044400</td>\n",
|
495 |
+
" </tr>\n",
|
496 |
+
" <tr>\n",
|
497 |
+
" <td>3700</td>\n",
|
498 |
+
" <td>0.045700</td>\n",
|
499 |
+
" </tr>\n",
|
500 |
+
" <tr>\n",
|
501 |
+
" <td>3800</td>\n",
|
502 |
+
" <td>0.046100</td>\n",
|
503 |
+
" </tr>\n",
|
504 |
+
" </tbody>\n",
|
505 |
+
"</table><p>"
|
506 |
+
],
|
507 |
+
"text/plain": [
|
508 |
+
"<IPython.core.display.HTML object>"
|
509 |
+
]
|
510 |
+
},
|
511 |
+
"metadata": {},
|
512 |
+
"output_type": "display_data"
|
513 |
+
},
|
514 |
+
{
|
515 |
+
"data": {
|
516 |
+
"text/plain": [
|
517 |
+
"TrainOutput(global_step=3890, training_loss=0.1604946916084976, metrics={'train_runtime': 12756.5123, 'train_samples_per_second': 39.031, 'train_steps_per_second': 0.305, 'total_flos': 0.0, 'train_loss': 0.1604946916084976, 'epoch': 5.0})"
|
518 |
+
]
|
519 |
+
},
|
520 |
+
"execution_count": 9,
|
521 |
+
"metadata": {},
|
522 |
+
"output_type": "execute_result"
|
523 |
+
}
|
524 |
+
],
|
525 |
+
"source": [
|
526 |
+
"def to_frame(ex_list):\n",
|
527 |
+
" rows = [(ex.texts[0], ex.texts[1]) for ex in ex_list]\n",
|
528 |
+
" return pd.DataFrame(rows, columns=['text_0', 'text_1'])\n",
|
529 |
+
"\n",
|
530 |
+
"train_ds = Dataset.from_pandas(to_frame(examples['train']))\n",
|
531 |
+
"\n",
|
532 |
+
"trainer = SentenceTransformerTrainer(\n",
|
533 |
+
" model=model,\n",
|
534 |
+
" args=args,\n",
|
535 |
+
" train_dataset=train_ds,\n",
|
536 |
+
" loss=loss,\n",
|
537 |
+
")\n",
|
538 |
+
"trainer.train()"
|
539 |
+
]
|
540 |
+
},
|
541 |
+
{
|
542 |
+
"cell_type": "code",
|
543 |
+
"execution_count": null,
|
544 |
+
"id": "f47a01a1",
|
545 |
+
"metadata": {},
|
546 |
+
"outputs": [],
|
547 |
+
"source": [
|
548 |
+
"model.save_pretrained(OUTPUT_DIR)\n",
|
549 |
+
"# model.push_to_hub(\n",
|
550 |
+
"# repo_id='YuITC/bert-base-multilingual-cased-finetuned-VNLegalDocs', \n",
|
551 |
+
"# commit_message='Update README.md',\n",
|
552 |
+
"# exist_ok=True,\n",
|
553 |
+
"# replace_model_card=False,\n",
|
554 |
+
"# train_datasets=['tmnam20/BKAI-Legal-Retrieval']\n",
|
555 |
+
"# )"
|
556 |
+
]
|
557 |
+
}
|
558 |
+
],
|
559 |
+
"metadata": {
|
560 |
+
"kernelspec": {
|
561 |
+
"display_name": "legal_doc_retrieval",
|
562 |
+
"language": "python",
|
563 |
+
"name": "python3"
|
564 |
+
},
|
565 |
+
"language_info": {
|
566 |
+
"codemirror_mode": {
|
567 |
+
"name": "ipython",
|
568 |
+
"version": 3
|
569 |
+
},
|
570 |
+
"file_extension": ".py",
|
571 |
+
"mimetype": "text/x-python",
|
572 |
+
"name": "python",
|
573 |
+
"nbconvert_exporter": "python",
|
574 |
+
"pygments_lexer": "ipython3",
|
575 |
+
"version": "3.10.16"
|
576 |
+
}
|
577 |
+
},
|
578 |
+
"nbformat": 4,
|
579 |
+
"nbformat_minor": 5
|
580 |
+
}
|
step_03_Eval_with_MTEB.ipynb
ADDED
@@ -0,0 +1,479 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 1,
|
6 |
+
"id": "b41fd227",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"Using device: cuda\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"!python settings.py"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 2,
|
24 |
+
"id": "b5fd917b",
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [
|
27 |
+
{
|
28 |
+
"name": "stdout",
|
29 |
+
"output_type": "stream",
|
30 |
+
"text": [
|
31 |
+
"📦 PyTorch version: 2.5.1\n",
|
32 |
+
"🚀 CUDA available : True\n",
|
33 |
+
"🧠 GPU Name : NVIDIA RTX A4000\n",
|
34 |
+
"📦 FAISS version : 1.9.0\n",
|
35 |
+
"🚀 FAISS is using GPU ✅\n"
|
36 |
+
]
|
37 |
+
}
|
38 |
+
],
|
39 |
+
"source": [
|
40 |
+
"import torch\n",
|
41 |
+
"\n",
|
42 |
+
"print(\"📦 PyTorch version:\", torch.__version__)\n",
|
43 |
+
"print(\"🚀 CUDA available :\", torch.cuda.is_available())\n",
|
44 |
+
"if torch.cuda.is_available():\n",
|
45 |
+
" print(\"🧠 GPU Name :\", torch.cuda.get_device_name(0))\n",
|
46 |
+
" \n",
|
47 |
+
"import faiss\n",
|
48 |
+
"\n",
|
49 |
+
"print(\"📦 FAISS version :\", faiss.__version__)\n",
|
50 |
+
"\n",
|
51 |
+
"# Kiểm tra module FAISS-GPU có hoạt động không\n",
|
52 |
+
"try:\n",
|
53 |
+
" res = faiss.StandardGpuResources() # Nếu không lỗi là có GPU\n",
|
54 |
+
" print(\"🚀 FAISS is using GPU ✅\")\n",
|
55 |
+
"except Exception as e:\n",
|
56 |
+
" print(\"❌ FAISS is NOT using GPU:\", str(e))"
|
57 |
+
]
|
58 |
+
},
|
59 |
+
{
|
60 |
+
"cell_type": "code",
|
61 |
+
"execution_count": 3,
|
62 |
+
"id": "030016c2",
|
63 |
+
"metadata": {},
|
64 |
+
"outputs": [
|
65 |
+
{
|
66 |
+
"name": "stderr",
|
67 |
+
"output_type": "stream",
|
68 |
+
"text": [
|
69 |
+
"C:\\Users\\Administrator\\AppData\\Local\\Temp\\2\\ipykernel_648\\3951191562.py:5: TqdmExperimentalWarning: Using `tqdm.autonotebook.tqdm` in notebook mode. Use `tqdm.tqdm` instead to force console mode (e.g. in jupyter console)\n",
|
70 |
+
" from tqdm.autonotebook import tqdm\n"
|
71 |
+
]
|
72 |
+
},
|
73 |
+
{
|
74 |
+
"name": "stdout",
|
75 |
+
"output_type": "stream",
|
76 |
+
"text": [
|
77 |
+
"Using device: cuda\n"
|
78 |
+
]
|
79 |
+
}
|
80 |
+
],
|
81 |
+
"source": [
|
82 |
+
"import os\n",
|
83 |
+
"import json\n",
|
84 |
+
"import pandas as pd\n",
|
85 |
+
"from pprint import pprint\n",
|
86 |
+
"from tqdm.autonotebook import tqdm\n",
|
87 |
+
"\n",
|
88 |
+
"from sentence_transformers import SentenceTransformer\n",
|
89 |
+
"from mteb import MTEB\n",
|
90 |
+
"from mteb.abstasks.TaskMetadata import TaskMetadata\n",
|
91 |
+
"from mteb.abstasks.AbsTaskRetrieval import AbsTaskRetrieval\n",
|
92 |
+
"\n",
|
93 |
+
"from settings import MODEL_NAME, OUTPUT_DIR, DEVICE, BATCH_SIZE\n",
|
94 |
+
"\n",
|
95 |
+
"os.environ['WANDB_DISABLED'] = 'true'"
|
96 |
+
]
|
97 |
+
},
|
98 |
+
{
|
99 |
+
"cell_type": "code",
|
100 |
+
"execution_count": 4,
|
101 |
+
"id": "dd3f53a3",
|
102 |
+
"metadata": {},
|
103 |
+
"outputs": [],
|
104 |
+
"source": [
|
105 |
+
"data = {\n",
|
106 |
+
" 'corpus': pd.read_parquet('data/processed/corpus_data.parquet'),\n",
|
107 |
+
" 'train' : pd.read_parquet('data/processed/train_data.parquet'),\n",
|
108 |
+
" 'test' : pd.read_parquet('data/processed/test_data.parquet')\n",
|
109 |
+
"}\n",
|
110 |
+
"for split in ['train', 'test']:\n",
|
111 |
+
" data[split]['cid'] = data[split]['cid'].apply(lambda x: x.tolist())\n",
|
112 |
+
" data[split]['context_list'] = data[split]['context_list'].apply(lambda x: x.tolist())"
|
113 |
+
]
|
114 |
+
},
|
115 |
+
{
|
116 |
+
"cell_type": "code",
|
117 |
+
"execution_count": 5,
|
118 |
+
"id": "41ffd5ce",
|
119 |
+
"metadata": {},
|
120 |
+
"outputs": [],
|
121 |
+
"source": [
|
122 |
+
"class BKAILegalDocRetrievalTask(AbsTaskRetrieval):\n",
|
123 |
+
" # Metadata definition used by MTEB benchmark\n",
|
124 |
+
" metadata = TaskMetadata(name='BKAILegalDocRetrieval',\n",
|
125 |
+
" description='',\n",
|
126 |
+
" reference='https://github.com/embeddings-benchmark/mteb/blob/main/docs/adding_a_dataset.md',\n",
|
127 |
+
" type='Retrieval',\n",
|
128 |
+
" category='s2p',\n",
|
129 |
+
" modalities=['text'],\n",
|
130 |
+
" eval_splits=['test'],\n",
|
131 |
+
" eval_langs=['vi'],\n",
|
132 |
+
" main_score='ndcg_at_10',\n",
|
133 |
+
" other_scores=['recall_at_10', 'precision_at_10', 'map'],\n",
|
134 |
+
" dataset={\n",
|
135 |
+
" 'path' : 'data',\n",
|
136 |
+
" 'revision': 'd4c5a8ba10ae71224752c727094ac4c46947fa29',\n",
|
137 |
+
" },\n",
|
138 |
+
" date=('2012-01-01', '2020-01-01'),\n",
|
139 |
+
" form='Written',\n",
|
140 |
+
" domains=['Academic', 'Non-fiction'],\n",
|
141 |
+
" task_subtypes=['Scientific Reranking'],\n",
|
142 |
+
" license='cc-by-nc-4.0',\n",
|
143 |
+
" annotations_creators='derived',\n",
|
144 |
+
" dialect=[],\n",
|
145 |
+
" text_creation='found',\n",
|
146 |
+
" bibtex_citation=''\n",
|
147 |
+
" )\n",
|
148 |
+
"\n",
|
149 |
+
" data_loaded = True # Flag\n",
|
150 |
+
"\n",
|
151 |
+
" def __init__(self, **kwargs):\n",
|
152 |
+
" super().__init__(**kwargs)\n",
|
153 |
+
"\n",
|
154 |
+
" self.corpus = {}\n",
|
155 |
+
" self.queries = {}\n",
|
156 |
+
" self.relevant_docs = {}\n",
|
157 |
+
"\n",
|
158 |
+
" shared_corpus = {}\n",
|
159 |
+
" for _, row in data['corpus'].iterrows():\n",
|
160 |
+
" shared_corpus[f\"c{row['cid']}\"] = {\n",
|
161 |
+
" 'text': row['text'],\n",
|
162 |
+
" '_id' : row['cid']\n",
|
163 |
+
" }\n",
|
164 |
+
" \n",
|
165 |
+
" for split in ['train', 'test']:\n",
|
166 |
+
" self.corpus[split] = shared_corpus\n",
|
167 |
+
" self.queries[split] = {}\n",
|
168 |
+
" self.relevant_docs[split] = {}\n",
|
169 |
+
"\n",
|
170 |
+
" for split in ['train', 'test']:\n",
|
171 |
+
" for _, row in data[split].iterrows():\n",
|
172 |
+
" qid, cids = row['qid'], row['cid']\n",
|
173 |
+
" \n",
|
174 |
+
" qid_str = f'q{qid}'\n",
|
175 |
+
" cids_str = [f'c{cid}' for cid in cids]\n",
|
176 |
+
" \n",
|
177 |
+
" self.queries[split][qid_str] = row['question']\n",
|
178 |
+
" \n",
|
179 |
+
" if qid_str not in self.relevant_docs[split]:\n",
|
180 |
+
" self.relevant_docs[split][qid_str] = {}\n",
|
181 |
+
" \n",
|
182 |
+
" for cid_str in cids_str:\n",
|
183 |
+
" self.relevant_docs[split][qid_str][cid_str] = 1\n",
|
184 |
+
" \n",
|
185 |
+
" self.data_loaded = True"
|
186 |
+
]
|
187 |
+
},
|
188 |
+
{
|
189 |
+
"cell_type": "code",
|
190 |
+
"execution_count": 6,
|
191 |
+
"id": "8c212fe9",
|
192 |
+
"metadata": {},
|
193 |
+
"outputs": [],
|
194 |
+
"source": [
|
195 |
+
"fine_tuned_model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)"
|
196 |
+
]
|
197 |
+
},
|
198 |
+
{
|
199 |
+
"cell_type": "code",
|
200 |
+
"execution_count": 7,
|
201 |
+
"id": "aae09322",
|
202 |
+
"metadata": {},
|
203 |
+
"outputs": [
|
204 |
+
{
|
205 |
+
"name": "stderr",
|
206 |
+
"output_type": "stream",
|
207 |
+
"text": [
|
208 |
+
"The `batch_size` argument is deprecated and will be removed in the next release. Please use `encode_kwargs = {'batch_size': ...}` to set the batch size instead.\n"
|
209 |
+
]
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"data": {
|
213 |
+
"text/html": [
|
214 |
+
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #262626; text-decoration-color: #262626\">───────────────────────────────────────────────── </span><span style=\"font-weight: bold\">Selected tasks </span><span style=\"color: #262626; text-decoration-color: #262626\"> ─────────────────────────────────────────────────</span>\n",
|
215 |
+
"</pre>\n"
|
216 |
+
],
|
217 |
+
"text/plain": [
|
218 |
+
"\u001b[38;5;235m───────────────────────────────────────────────── \u001b[0m\u001b[1mSelected tasks \u001b[0m\u001b[38;5;235m ─────────────────────────────────────────────────\u001b[0m\n"
|
219 |
+
]
|
220 |
+
},
|
221 |
+
"metadata": {},
|
222 |
+
"output_type": "display_data"
|
223 |
+
},
|
224 |
+
{
|
225 |
+
"data": {
|
226 |
+
"text/html": [
|
227 |
+
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Retrieval</span>\n",
|
228 |
+
"</pre>\n"
|
229 |
+
],
|
230 |
+
"text/plain": [
|
231 |
+
"\u001b[1mRetrieval\u001b[0m\n"
|
232 |
+
]
|
233 |
+
},
|
234 |
+
"metadata": {},
|
235 |
+
"output_type": "display_data"
|
236 |
+
},
|
237 |
+
{
|
238 |
+
"data": {
|
239 |
+
"text/html": [
|
240 |
+
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"> - BKAILegalDocRetrieval, <span style=\"color: #626262; text-decoration-color: #626262; font-style: italic\">s2p</span>\n",
|
241 |
+
"</pre>\n"
|
242 |
+
],
|
243 |
+
"text/plain": [
|
244 |
+
" - BKAILegalDocRetrieval, \u001b[3;38;5;241ms2p\u001b[0m\n"
|
245 |
+
]
|
246 |
+
},
|
247 |
+
"metadata": {},
|
248 |
+
"output_type": "display_data"
|
249 |
+
},
|
250 |
+
{
|
251 |
+
"data": {
|
252 |
+
"text/html": [
|
253 |
+
"<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
|
254 |
+
"\n",
|
255 |
+
"</pre>\n"
|
256 |
+
],
|
257 |
+
"text/plain": [
|
258 |
+
"\n",
|
259 |
+
"\n"
|
260 |
+
]
|
261 |
+
},
|
262 |
+
"metadata": {},
|
263 |
+
"output_type": "display_data"
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"data": {
|
267 |
+
"application/vnd.jupyter.widget-view+json": {
|
268 |
+
"model_id": "53778754caf4456f8e140cfa58b60709",
|
269 |
+
"version_major": 2,
|
270 |
+
"version_minor": 0
|
271 |
+
},
|
272 |
+
"text/plain": [
|
273 |
+
"Batches: 0%| | 0/233 [00:00<?, ?it/s]"
|
274 |
+
]
|
275 |
+
},
|
276 |
+
"metadata": {},
|
277 |
+
"output_type": "display_data"
|
278 |
+
},
|
279 |
+
{
|
280 |
+
"data": {
|
281 |
+
"application/vnd.jupyter.widget-view+json": {
|
282 |
+
"model_id": "f9b27ae885fc46ad83f332f222a76381",
|
283 |
+
"version_major": 2,
|
284 |
+
"version_minor": 0
|
285 |
+
},
|
286 |
+
"text/plain": [
|
287 |
+
"Batches: 0%| | 0/391 [00:00<?, ?it/s]"
|
288 |
+
]
|
289 |
+
},
|
290 |
+
"metadata": {},
|
291 |
+
"output_type": "display_data"
|
292 |
+
},
|
293 |
+
{
|
294 |
+
"data": {
|
295 |
+
"application/vnd.jupyter.widget-view+json": {
|
296 |
+
"model_id": "0b6e38b0d54a4b429db05158604d24a5",
|
297 |
+
"version_major": 2,
|
298 |
+
"version_minor": 0
|
299 |
+
},
|
300 |
+
"text/plain": [
|
301 |
+
"Batches: 0%| | 0/391 [00:00<?, ?it/s]"
|
302 |
+
]
|
303 |
+
},
|
304 |
+
"metadata": {},
|
305 |
+
"output_type": "display_data"
|
306 |
+
},
|
307 |
+
{
|
308 |
+
"data": {
|
309 |
+
"application/vnd.jupyter.widget-view+json": {
|
310 |
+
"model_id": "20ec5df7261c43a7921abc968cc5e3a6",
|
311 |
+
"version_major": 2,
|
312 |
+
"version_minor": 0
|
313 |
+
},
|
314 |
+
"text/plain": [
|
315 |
+
"Batches: 0%| | 0/391 [00:02<?, ?it/s]"
|
316 |
+
]
|
317 |
+
},
|
318 |
+
"metadata": {},
|
319 |
+
"output_type": "display_data"
|
320 |
+
},
|
321 |
+
{
|
322 |
+
"data": {
|
323 |
+
"application/vnd.jupyter.widget-view+json": {
|
324 |
+
"model_id": "5f365f06d3de4becb965adb801aeee60",
|
325 |
+
"version_major": 2,
|
326 |
+
"version_minor": 0
|
327 |
+
},
|
328 |
+
"text/plain": [
|
329 |
+
"Batches: 0%| | 0/391 [00:00<?, ?it/s]"
|
330 |
+
]
|
331 |
+
},
|
332 |
+
"metadata": {},
|
333 |
+
"output_type": "display_data"
|
334 |
+
},
|
335 |
+
{
|
336 |
+
"data": {
|
337 |
+
"application/vnd.jupyter.widget-view+json": {
|
338 |
+
"model_id": "a43b764ac83e43aeb754c1e60771fd5c",
|
339 |
+
"version_major": 2,
|
340 |
+
"version_minor": 0
|
341 |
+
},
|
342 |
+
"text/plain": [
|
343 |
+
"Batches: 0%| | 0/391 [00:02<?, ?it/s]"
|
344 |
+
]
|
345 |
+
},
|
346 |
+
"metadata": {},
|
347 |
+
"output_type": "display_data"
|
348 |
+
},
|
349 |
+
{
|
350 |
+
"data": {
|
351 |
+
"application/vnd.jupyter.widget-view+json": {
|
352 |
+
"model_id": "ae46c8f76bc64eac8ca475d13f312875",
|
353 |
+
"version_major": 2,
|
354 |
+
"version_minor": 0
|
355 |
+
},
|
356 |
+
"text/plain": [
|
357 |
+
"Batches: 0%| | 0/91 [00:02<?, ?it/s]"
|
358 |
+
]
|
359 |
+
},
|
360 |
+
"metadata": {},
|
361 |
+
"output_type": "display_data"
|
362 |
+
},
|
363 |
+
{
|
364 |
+
"data": {
|
365 |
+
"text/plain": [
|
366 |
+
"[TaskResult(task_name=BKAILegalDocRetrieval, scores=...)]"
|
367 |
+
]
|
368 |
+
},
|
369 |
+
"execution_count": 7,
|
370 |
+
"metadata": {},
|
371 |
+
"output_type": "execute_result"
|
372 |
+
}
|
373 |
+
],
|
374 |
+
"source": [
|
375 |
+
"custom_task = BKAILegalDocRetrievalTask()\n",
|
376 |
+
"evaluation = MTEB(tasks=[custom_task])\n",
|
377 |
+
"evaluation.run(fine_tuned_model, batch_size=BATCH_SIZE)"
|
378 |
+
]
|
379 |
+
},
|
380 |
+
{
|
381 |
+
"cell_type": "code",
|
382 |
+
"execution_count": 8,
|
383 |
+
"id": "004e6930",
|
384 |
+
"metadata": {},
|
385 |
+
"outputs": [
|
386 |
+
{
|
387 |
+
"name": "stdout",
|
388 |
+
"output_type": "stream",
|
389 |
+
"text": [
|
390 |
+
"Main Evaluation Metrics (Top-K = 10):\n",
|
391 |
+
"{'evaluation_time (s)': 3061.7869832515717,\n",
|
392 |
+
" 'main_score': 0.60389,\n",
|
393 |
+
" 'mrr@10': 0.555102,\n",
|
394 |
+
" 'precision@10': 0.08587,\n",
|
395 |
+
" 'recall@10': 0.79407}\n"
|
396 |
+
]
|
397 |
+
}
|
398 |
+
],
|
399 |
+
"source": [
|
400 |
+
"file_path = f\"results/no_model_name_available/no_revision_available/BKAILegalDocRetrieval.json\"\n",
|
401 |
+
"\n",
|
402 |
+
"with open(file_path, 'r', encoding='utf-8') as f:\n",
|
403 |
+
" eval_data = json.load(f)\n",
|
404 |
+
"\n",
|
405 |
+
"scores = eval_data[\"scores\"][\"test\"][0]\n",
|
406 |
+
"main_metrics = {\n",
|
407 |
+
" 'main_score' : scores.get('ndcg_at_10'),\n",
|
408 |
+
" 'recall@10' : scores.get('recall_at_10'),\n",
|
409 |
+
" 'precision@10' : scores.get('precision_at_10'),\n",
|
410 |
+
" 'mrr@10' : scores.get('mrr_at_10'),\n",
|
411 |
+
" 'evaluation_time (s)': eval_data.get('evaluation_time')\n",
|
412 |
+
"}\n",
|
413 |
+
"\n",
|
414 |
+
"print('Main Evaluation Metrics (Top-K = 10):')\n",
|
415 |
+
"pprint(main_metrics)"
|
416 |
+
]
|
417 |
+
},
|
418 |
+
{
|
419 |
+
"cell_type": "code",
|
420 |
+
"execution_count": 9,
|
421 |
+
"id": "672ebc32",
|
422 |
+
"metadata": {},
|
423 |
+
"outputs": [
|
424 |
+
{
|
425 |
+
"name": "stdout",
|
426 |
+
"output_type": "stream",
|
427 |
+
"text": [
|
428 |
+
"\n",
|
429 |
+
"Evaluation Scores by K:\n",
|
430 |
+
"metric map mrr ndcg precision recall\n",
|
431 |
+
"k \n",
|
432 |
+
"1 0.4033 0.4242 0.4242 0.4242 0.4033\n",
|
433 |
+
"3 0.5031 0.5247 0.5394 0.2215 0.6232\n",
|
434 |
+
"5 0.5230 0.5434 0.5739 0.1512 0.7047\n",
|
435 |
+
"10 0.5361 0.5551 0.6039 0.0859 0.7941\n",
|
436 |
+
"20 0.5414 0.5596 0.6216 0.0469 0.8611\n",
|
437 |
+
"100 0.5442 0.5617 0.6389 0.0104 0.9480\n",
|
438 |
+
"1000 0.5444 0.5619 0.6444 0.0011 0.9879\n"
|
439 |
+
]
|
440 |
+
}
|
441 |
+
],
|
442 |
+
"source": [
|
443 |
+
"metrics = {k: v for k, v in scores.items() if '_at_' in k and not k.startswith('nauc')}\n",
|
444 |
+
"\n",
|
445 |
+
"parsed_metrics = []\n",
|
446 |
+
"for key, value in metrics.items():\n",
|
447 |
+
" metric, at_k = key.split('_at_')\n",
|
448 |
+
" parsed_metrics.append({'metric': metric, 'k': int(at_k), 'score': value})\n",
|
449 |
+
"\n",
|
450 |
+
"df_metrics = pd.DataFrame(parsed_metrics).pivot(index='k', columns='metric', values='score')\n",
|
451 |
+
"df_metrics = df_metrics.sort_index()\n",
|
452 |
+
"\n",
|
453 |
+
"print(\"\\nEvaluation Scores by K:\")\n",
|
454 |
+
"print(df_metrics.round(4))"
|
455 |
+
]
|
456 |
+
}
|
457 |
+
],
|
458 |
+
"metadata": {
|
459 |
+
"kernelspec": {
|
460 |
+
"display_name": "legal_doc_retrieval",
|
461 |
+
"language": "python",
|
462 |
+
"name": "python3"
|
463 |
+
},
|
464 |
+
"language_info": {
|
465 |
+
"codemirror_mode": {
|
466 |
+
"name": "ipython",
|
467 |
+
"version": 3
|
468 |
+
},
|
469 |
+
"file_extension": ".py",
|
470 |
+
"mimetype": "text/x-python",
|
471 |
+
"name": "python",
|
472 |
+
"nbconvert_exporter": "python",
|
473 |
+
"pygments_lexer": "ipython3",
|
474 |
+
"version": "3.10.16"
|
475 |
+
}
|
476 |
+
},
|
477 |
+
"nbformat": 4,
|
478 |
+
"nbformat_minor": 5
|
479 |
+
}
|
step_04_Retrieval.ipynb
ADDED
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 11,
|
6 |
+
"id": "1195e917",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [
|
9 |
+
{
|
10 |
+
"name": "stdout",
|
11 |
+
"output_type": "stream",
|
12 |
+
"text": [
|
13 |
+
"Using device: cuda\n"
|
14 |
+
]
|
15 |
+
}
|
16 |
+
],
|
17 |
+
"source": [
|
18 |
+
"!python settings.py"
|
19 |
+
]
|
20 |
+
},
|
21 |
+
{
|
22 |
+
"cell_type": "code",
|
23 |
+
"execution_count": 12,
|
24 |
+
"id": "01589fc8",
|
25 |
+
"metadata": {},
|
26 |
+
"outputs": [],
|
27 |
+
"source": [
|
28 |
+
"import os\n",
|
29 |
+
"import numpy as np\n",
|
30 |
+
"import pandas as pd\n",
|
31 |
+
"from tqdm.autonotebook import tqdm\n",
|
32 |
+
"\n",
|
33 |
+
"import faiss\n",
|
34 |
+
"from sentence_transformers import SentenceTransformer, CrossEncoder\n",
|
35 |
+
"\n",
|
36 |
+
"from settings import OUTPUT_DIR, DEVICE\n",
|
37 |
+
"\n",
|
38 |
+
"os.environ['WANDB_DISABLED'] = 'true'\n",
|
39 |
+
"\n",
|
40 |
+
"from transformers import logging\n",
|
41 |
+
"logging.set_verbosity_error()"
|
42 |
+
]
|
43 |
+
},
|
44 |
+
{
|
45 |
+
"cell_type": "code",
|
46 |
+
"execution_count": 13,
|
47 |
+
"id": "057e852f",
|
48 |
+
"metadata": {},
|
49 |
+
"outputs": [],
|
50 |
+
"source": [
|
51 |
+
"# data = {\n",
|
52 |
+
"# 'corpus': pd.read_parquet('data/processed/corpus_data.parquet'),\n",
|
53 |
+
"# 'train' : pd.read_parquet('data/processed/train_data.parquet'),\n",
|
54 |
+
"# 'test' : pd.read_parquet('data/processed/test_data.parquet')\n",
|
55 |
+
"# }\n",
|
56 |
+
"# for split in ['train', 'test']:\n",
|
57 |
+
"# data[split]['cid'] = data[split]['cid'].apply(lambda x: x.tolist())\n",
|
58 |
+
"# data[split]['context_list'] = data[split]['context_list'].apply(lambda x: x.tolist())"
|
59 |
+
]
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"cell_type": "code",
|
63 |
+
"execution_count": 14,
|
64 |
+
"id": "5634b72a",
|
65 |
+
"metadata": {},
|
66 |
+
"outputs": [
|
67 |
+
{
|
68 |
+
"data": {
|
69 |
+
"text/plain": [
|
70 |
+
"SentenceTransformer(\n",
|
71 |
+
" (0): Transformer({'max_seq_length': 512, 'do_lower_case': False}) with Transformer model: BertModel \n",
|
72 |
+
" (1): Pooling({'word_embedding_dimension': 768, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False, 'pooling_mode_weightedmean_tokens': False, 'pooling_mode_lasttoken': False, 'include_prompt': True})\n",
|
73 |
+
")"
|
74 |
+
]
|
75 |
+
},
|
76 |
+
"execution_count": 14,
|
77 |
+
"metadata": {},
|
78 |
+
"output_type": "execute_result"
|
79 |
+
}
|
80 |
+
],
|
81 |
+
"source": [
|
82 |
+
"fine_tuned_model = SentenceTransformer(OUTPUT_DIR, device=DEVICE)\n",
|
83 |
+
"fine_tuned_model.half()"
|
84 |
+
]
|
85 |
+
},
|
86 |
+
{
|
87 |
+
"cell_type": "code",
|
88 |
+
"execution_count": 15,
|
89 |
+
"id": "62cc0ead",
|
90 |
+
"metadata": {},
|
91 |
+
"outputs": [],
|
92 |
+
"source": [
|
93 |
+
"passages = pd.read_parquet('data/processed/corpus_data.parquet')['text'].tolist()\n",
|
94 |
+
"# corpus_embeddings = fine_tuned_model.encode(\n",
|
95 |
+
"# passages, \n",
|
96 |
+
"# batch_size=128,\n",
|
97 |
+
"# convert_to_numpy=True, \n",
|
98 |
+
"# normalize_embeddings=True,\n",
|
99 |
+
"# show_progress_bar=True, \n",
|
100 |
+
"# device=DEVICE,\n",
|
101 |
+
"# ).astype(np.float32)"
|
102 |
+
]
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"execution_count": 16,
|
107 |
+
"id": "465e8d2a",
|
108 |
+
"metadata": {},
|
109 |
+
"outputs": [],
|
110 |
+
"source": [
|
111 |
+
"# d = corpus_embeddings.shape[1] # 768\n",
|
112 |
+
"# cpu_index = faiss.IndexFlatIP(d)\n",
|
113 |
+
"\n",
|
114 |
+
"# res = faiss.StandardGpuResources()\n",
|
115 |
+
"# gpu_index = faiss.index_cpu_to_gpu(res, 0, cpu_index)\n",
|
116 |
+
"# gpu_index.add(corpus_embeddings)"
|
117 |
+
]
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"cell_type": "code",
|
121 |
+
"execution_count": 17,
|
122 |
+
"id": "af365371",
|
123 |
+
"metadata": {},
|
124 |
+
"outputs": [],
|
125 |
+
"source": [
|
126 |
+
"# final_cpu_index = faiss.index_gpu_to_cpu(gpu_index)\n",
|
127 |
+
"# faiss.write_index(final_cpu_index, 'data/retrieval/legal_faiss.index')"
|
128 |
+
]
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"cell_type": "code",
|
132 |
+
"execution_count": 18,
|
133 |
+
"id": "9251d0db",
|
134 |
+
"metadata": {},
|
135 |
+
"outputs": [],
|
136 |
+
"source": [
|
137 |
+
"legal_index = faiss.read_index('data/retrieval/legal_faiss.index')"
|
138 |
+
]
|
139 |
+
},
|
140 |
+
{
|
141 |
+
"cell_type": "code",
|
142 |
+
"execution_count": 19,
|
143 |
+
"id": "9f54c596",
|
144 |
+
"metadata": {},
|
145 |
+
"outputs": [],
|
146 |
+
"source": [
|
147 |
+
"def retrieval(emb_model, query, index, top_k=10):\n",
|
148 |
+
" q_emb = emb_model.encode(\n",
|
149 |
+
" query, \n",
|
150 |
+
" convert_to_numpy=True, \n",
|
151 |
+
" normalize_embeddings=True,\n",
|
152 |
+
" ).astype(np.float32).reshape(1, -1)\n",
|
153 |
+
" \n",
|
154 |
+
" scores, indices = index.search(q_emb, top_k) # shape: (1, top_k)\n",
|
155 |
+
" \n",
|
156 |
+
" cand_idxs = indices[0]\n",
|
157 |
+
" cand_scores = scores[0]\n",
|
158 |
+
" cand_texts = [passages[i] for i in cand_idxs]\n",
|
159 |
+
"\n",
|
160 |
+
" results = [{\n",
|
161 |
+
" 'index': int(cand_idxs[i]),\n",
|
162 |
+
" 'score': float(cand_scores[i]),\n",
|
163 |
+
" 'text': cand_texts[i]\n",
|
164 |
+
" } for i in range(len(cand_idxs))]\n",
|
165 |
+
" \n",
|
166 |
+
" return results"
|
167 |
+
]
|
168 |
+
},
|
169 |
+
{
|
170 |
+
"cell_type": "code",
|
171 |
+
"execution_count": 22,
|
172 |
+
"id": "ece21ef6",
|
173 |
+
"metadata": {},
|
174 |
+
"outputs": [
|
175 |
+
{
|
176 |
+
"name": "stdout",
|
177 |
+
"output_type": "stream",
|
178 |
+
"text": [
|
179 |
+
"[Rank 1] index=76423, score=0.6417\n",
|
180 |
+
"Tội làm nhục người khác\n",
|
181 |
+
"1. Người nào xúc phạm nghiêm trọng nhân phẩm, danh dự của người khác, thì bị phạt cảnh cáo, phạt tiền từ 10.000.000 đồng đến 30.000.000 đồng hoặc phạt cải tạo không giam giữ đến 03 năm.\n",
|
182 |
+
"...\n",
|
183 |
+
"--------------------------------------------------------------------------------\n",
|
184 |
+
"[Rank 2] index=99131, score=0.6155\n",
|
185 |
+
"“Người nào có hành vi xâm phạm danh dự, nhân phẩm của người khác mà gây thiệt hại thì phải bồi thường.”\n",
|
186 |
+
"--------------------------------------------------------------------------------\n",
|
187 |
+
"[Rank 3] index=228550, score=0.5932\n",
|
188 |
+
"i) Điều 353, các khoản 2, 3 và 4 (tội tham ô tài sản); Điều 354, các khoản 2, 3 và 4 (tội nhận hối lộ); Điều 355, các khoản 2, 3 và 4 (tội lạm dụng chức vụ, quyền hạn chiếm đoạt tài sản); Điều 356, các khoản 2 và 3 (tội lợi dụng chức vụ, quyền hạn trong khi thi hành công vụ); Điều 357, các khoản 2 và 3 (tội lạm quyền trong khi thi hành công vụ); Điều 358, các khoản 2, 3 và 4 (tội lợi dụng chức vụ, quyền hạn gây ảnh hưởng đối với người khác để trục lợi); Điều 359, các khoản 2, 3 và 4 (tội giả mạo trong công tác); Điều 364, các khoản 2, 3 và 4 (tội đưa hối lộ); Điều 365, các khoản 2, 3 và 4 (tội làm môi giới hối lộ);\n",
|
189 |
+
"k) Điều 373, các khoản 3 và 4 (tội dùng nhục hình); Điều 374, các khoản 3 và 4 (tội bức cung); Điều 386, khoản 2 (tội trốn khỏi nơi giam, giữ hoặc trốn khi đang bị áp giải, dẫn giải, đang bị xét xử);\n",
|
190 |
+
"l) Các điều từ Điều 421 đến Điều 425 về các tội phá hoại hòa bình, chống loài người và tội phạm chiến tranh.\n",
|
191 |
+
"2. Phạm tội trong trường hợp lợi dụng chức vụ, quyền hạn cản trở việc phát hiện tội phạm hoặc có những hành vi khác bao che người phạm tội, thì bị phạt tù từ 02 năm đến 07 năm.\n",
|
192 |
+
"Điều 390. Tội không tố giác tội phạm\n",
|
193 |
+
"1. Người nào biết rõ một trong các tội phạm được quy định tại Điều 389 của Bộ luật này đang được chuẩn bị, đang hoặc đã được thực hiện mà không tố giác, nếu không thuộc trường hợp quy định tại khoản 2 Điều 19 của Bộ luật này, thì bị phạt cảnh cáo, phạt cải tạo không giam giữ đến 03 năm hoặc phạt tù từ 06 tháng đến 03 năm.\n",
|
194 |
+
"2. Người không tố giác nếu đã có hành động can ngăn người phạm tội hoặc hạn chế tác hại của tội phạm, thì có thể được miễn trách nhiệm hình sự hoặc miễn hình phạt.\n",
|
195 |
+
"Điều 391. Tội gây rối trật tự phiên tòa\n",
|
196 |
+
"1. Người nào tại phiên tòa mà thóa mạ, xúc phạm nghiêm trọng danh dự, nhân phẩm thành viên Hội đồng xét xử, những người khác có mặt tại phiên tòa hoặc có hành vi đập phá tài sản thì bị phạt tiền từ 10.000.000 đồng đến 100.000.000 đồng, phạt cải tạo không giam giữ đến 01 năm hoặc phạt tù từ 03 tháng đến 01 năm.\n",
|
197 |
+
"2. Phạm tội thuộc một trong các trường hợp sau đây, thì bị phạt tù từ 01 năm đến 03 năm:\n",
|
198 |
+
"a) Gây náo loạn phiên tòa dẫn đến phải dừng phiên tòa;\n",
|
199 |
+
"b) Hành hung thành viên Hội đồng xét xử.\n",
|
200 |
+
"--------------------------------------------------------------------------------\n",
|
201 |
+
"[Rank 4] index=228404, score=0.5660\n",
|
202 |
+
"Điều 155. Tội làm nhục người khác\n",
|
203 |
+
"1. Người nào xúc phạm nghiêm trọng nhân phẩm, danh dự của người khác, thì bị phạt cảnh cáo, phạt tiền từ 10.000.000 đồng đến 30.000.000 đồng hoặc phạt cải tạo không giam giữ đến 03 năm.\n",
|
204 |
+
"2. Phạm tội thuộc một trong các trường hợp sau đây, thì bị phạt tù từ 03 tháng đến 02 năm:\n",
|
205 |
+
"a) Phạm tội 02 lần trở lên;\n",
|
206 |
+
"b) Đối với 02 người trở lên;\n",
|
207 |
+
"c) Lợi dụng chức vụ, quyền hạn;\n",
|
208 |
+
"d) Đối với người đang thi hành công vụ;\n",
|
209 |
+
"đ) Đối với người dạy dỗ, nuôi dưỡng, chăm sóc, chữa bệnh cho mình;\n",
|
210 |
+
"e) Sử dụng mạng máy tính hoặc mạng viễn thông, phương tiện điện tử để phạm tội;\n",
|
211 |
+
"g) Gây rối loạn tâm thần và hành vi của nạn nhân từ 11% đến 45%.\n",
|
212 |
+
"3. Phạm tội thuộc một trong các trường hợp sau đây, thì bị phạt tù từ 02 năm đến 05 năm:\n",
|
213 |
+
"a) Gây rối loạn tâm thần và hành vi của nạn nhân 46% trở lên;\n",
|
214 |
+
"b) Làm nạn nhân tự sát.\n",
|
215 |
+
"4. Người phạm tội còn có thể bị cấm đảm nhiệm chức vụ, cấm hành nghề hoặc làm công việc nhất định từ 01 năm đến 05 năm.\n",
|
216 |
+
"--------------------------------------------------------------------------------\n",
|
217 |
+
"[Rank 5] index=143035, score=0.5470\n",
|
218 |
+
"Khoản 4. Người có hành vi xâm phạm thân thể, sức khỏe, tính mạng hoặc xúc phạm danh dự, nhân phẩm của người hành nghề và người khác làm việc tại cơ sở khám bệnh, chữa bệnh thì tùy theo tính chất, mức độ vi phạm mà bị xử lý vi phạm hành chính hoặc bị truy cứu trách nhiệm hình sự theo quy định của pháp luật.\n",
|
219 |
+
"--------------------------------------------------------------------------------\n",
|
220 |
+
"[Rank 6] index=57787, score=0.5443\n",
|
221 |
+
"\"Điều 7. Vi phạm quy định về trật tự công cộng\n",
|
222 |
+
"..\n",
|
223 |
+
"3. Phạt tiền từ 2.000.000 đồng đến 3.000.000 đồng đối với một trong những hành vi sau đây:\n",
|
224 |
+
"a) Có hành vi khiêu khích, trêu ghẹo, xúc phạm, lăng mạ, bôi nhọ danh dự, nhân phẩm của người khác, trừ trường hợp quy định tại điểm b khoản 2 Điều 21 và Điều 54 Nghị định này;\n",
|
225 |
+
"...\n",
|
226 |
+
"14. Biện pháp khắc phục hậu quả:\n",
|
227 |
+
"a) Buộc khôi phục lại tình trạng ban đầu đối với hành vi vi phạm quy định tại điểm c khoản 1, điểm l khoản 2 và điểm e khoản 4 Điều này;\n",
|
228 |
+
"b) Buộc cải chính thông tin sai sự thật hoặc gây nhầm lẫn đối với hành vi vi phạm quy định tại điểm a khoản 3 và điểm i khoản 4 Điều này;\n",
|
229 |
+
"c) Buộc xin lỗi công khai đối với hành vi vi phạm quy định tại điểm a khoản 3, các điểm d và đ khoản 5 Điều này trừ trường hợp nạn nhân có đơn không yêu cầu;\n",
|
230 |
+
"d) Buộc thực hiện biện pháp khắc phục tình trạng ô nhiễm môi trường đối với hành vi vi phạm quy định tại điểm h khoản 5 Điều này;\n",
|
231 |
+
"đ) Buộc chi trả toàn bộ chi phí khám bệnh, chữa bệnh đối với hành vi vi phạm quy định tại điểm d khoản 1, điểm c khoản 2, điểm b khoản 3 và điểm a khoản 5 Điều này.\"\n",
|
232 |
+
"--------------------------------------------------------------------------------\n",
|
233 |
+
"[Rank 7] index=57120, score=0.5337\n",
|
234 |
+
"Vi phạm quy định về trật tự công cộng\n",
|
235 |
+
"...\n",
|
236 |
+
"2. Phạt tiền từ 1.000.000 đồng đến 2.000.000 đồng đối với một trong những hành vi sau đây:\n",
|
237 |
+
"...\n",
|
238 |
+
"b) Tổ chức, tham gia tụ tập nhiều người ở nơi công cộng gây mất trật tự công cộng;\n",
|
239 |
+
"...\n",
|
240 |
+
"3. Phạt tiền từ 2.000.000 đồng đến 3.000.000 đồng đối với một trong những hành vi sau đây:\n",
|
241 |
+
"a) Có hành vi khiêu khích, trêu ghẹo, xúc phạm, lăng mạ, bôi nhọ danh dự, nhân phẩm của người khác, trừ trường hợp quy định tại điểm b khoản 2 Điều 21 và Điều 54 Nghị định này;\n",
|
242 |
+
"b) Tổ chức, thuê, xúi giục, lôi kéo, dụ dỗ, kích động người khác cố ý gây thương tích hoặc gây tổn hại cho sức khỏe người khác hoặc xâm phạm danh dự, nhân phẩm của người khác nhưng không bị truy cứu trách nhiệm hình sự;\n",
|
243 |
+
"...\n",
|
244 |
+
"4. Phạt tiền từ 3.000.000 đồng đến 5.000.000 đồng đối với một trong những hành vi sau đây:\n",
|
245 |
+
"a) Tổ chức thuê, xúi giục, lôi kéo, dụ dỗ hoặc kích động người khác gây rối, làm mất trật tự công cộng;\n",
|
246 |
+
"b) Mang theo trong người hoặc tàng trữ, cất giấu các loại vũ khí thô sơ, công cụ hỗ trợ hoặc các loại công cụ, phương tiện khác có khả năng sát thương; đồ vật, phương tiện giao thông nhằm mục đích gây rối trật tự công cộng, cố ý gây thương tích cho người khác;\n",
|
247 |
+
"...\n",
|
248 |
+
"5. Phạt tiền từ 5.000.000 đồng đến 8.000.000 đồng đối với một trong những hành vi sau đây:\n",
|
249 |
+
"a) Cố ý gây thương tích hoặc gây tổn hại cho sức khỏe của người khác nhưng không bị truy cứu trách nhiệm hình sự;\n",
|
250 |
+
"b) Gây rối trật tự công cộng mà có mang theo các loại vũ khí thô sơ, công cụ hỗ trợ hoặc công cụ, đồ vật, phương tiện khác có khả năng sát thương;\n",
|
251 |
+
"...\n",
|
252 |
+
"13. Hình thức xử phạt bổ sung:\n",
|
253 |
+
"a) Tịch thu tang vật, phương tiện vi phạm hành chính đối với hành vi vi phạm quy định tại các điểm a, d, đ và g khoản 2; điểm đ khoản 3; các đi��m b, e và i khoản 4; các điểm a, b và c khoản 5; các khoản 6 và 10 Điều này;\n",
|
254 |
+
"...\n",
|
255 |
+
"14. Biện pháp khắc phục hậu quả:\n",
|
256 |
+
"...\n",
|
257 |
+
"b) Buộc cải chính thông tin sai sự thật hoặc gây nhầm lẫn đối với hành vi vi phạm quy định tại điểm a khoản 3 và điểm i khoản 4 Điều này;\n",
|
258 |
+
"c) Buộc xin lỗi công khai đối với hành vi vi phạm quy định tại điểm a khoản 3, các điểm d và đ khoản 5 Điều này trừ trường hợp nạn nhân có đơn không yêu cầu;\n",
|
259 |
+
"...\n",
|
260 |
+
"đ) Buộc chi trả toàn bộ chi phí khám bệnh, chữa bệnh đối với hành vi vi phạm quy định tại điểm d khoản 1, điểm c khoản 2, điểm b khoản 3 và điểm a khoản 5 Điều này.\n",
|
261 |
+
"--------------------------------------------------------------------------------\n",
|
262 |
+
"[Rank 8] index=56183, score=0.5270\n",
|
263 |
+
"\"Điều 155. Tội làm nhục người khác\n",
|
264 |
+
"1. Người nào xúc phạm nghiêm trọng nhân phẩm, danh dự của người khác, thì bị phạt cảnh cáo, phạt tiền từ 10.000.000 đồng đến 30.000.000 đồng hoặc phạt cải tạo không giam giữ đến 03 năm.\n",
|
265 |
+
"2. Phạm tội thuộc một trong các trường hợp sau đây, thì bị phạt tù từ 03 tháng đến 02 năm:\n",
|
266 |
+
"a) Phạm tội 02 lần trở lên;\n",
|
267 |
+
"b) Đối với 02 người trở lên;\n",
|
268 |
+
"c) Lợi dụng chức vụ, quyền hạn;\n",
|
269 |
+
"d) Đối với người đang thi hành công vụ;\n",
|
270 |
+
"đ) Đối với người dạy dỗ, nuôi dưỡng, chăm sóc, chữa bệnh cho mình;\n",
|
271 |
+
"e) Sử dụng mạng máy tính hoặc mạng viễn thông, phương tiện điện tử để phạm tội;\n",
|
272 |
+
"g) Gây rối loạn tâm thần và hành vi của nạn nhân mà tỷ lệ tổn thương cơ thể từ 31% đến 60%”.\n",
|
273 |
+
"3. Phạm tội thuộc một trong các trường hợp sau đây, thì bị phạt tù từ 02 năm đến 05 năm:\n",
|
274 |
+
"a) Gây rối loạn tâm thần và hành vi của nạn nhân mà tỷ lệ tổn thương cơ thể 61% trở lên”;\n",
|
275 |
+
"b) Làm nạn nhân tự sát.\n",
|
276 |
+
"4. Người phạm tội còn có thể bị cấm đảm nhiệm chức vụ, cấm hành nghề hoặc làm công việc nhất định từ 01 năm đến 05 năm.\n",
|
277 |
+
"Điều 156. Tội vu khống\n",
|
278 |
+
"1. Người nào thực hiện một trong các hành vi sau đây, thì bị phạt tiền từ 10.000.000 đồng đến 50.000.000 đồng, phạt cải tạo không giam giữ đến 02 năm hoặc phạt tù từ 03 tháng đến 01 năm:\n",
|
279 |
+
"a) Bịa đặt hoặc loan truyền những điều biết rõ là sai sự thật nhằm xúc phạm nghiêm trọng nhân phẩm, danh dự hoặc gây thiệt hại đến quyền, lợi ích hợp pháp của người khác;\n",
|
280 |
+
"b) Bịa đặt người khác phạm tội và tố cáo họ trước cơ quan có thẩm quyền.\n",
|
281 |
+
"2. Phạm tội thuộc một trong các trường hợp sau đây, thì bị phạt tù từ 01 năm đến 03 năm:\n",
|
282 |
+
"a) Có tổ chức;\n",
|
283 |
+
"b) Lợi dụng chức vụ, quyền hạn;\n",
|
284 |
+
"c) Đối với 02 người trở lên;\n",
|
285 |
+
"d) Đối với ông, bà, cha, mẹ, người dạy dỗ, nuôi dưỡng, chăm sóc, giáo dục, chữa bệnh cho mình;\n",
|
286 |
+
"đ) Đối với người đang thi hành công vụ;\n",
|
287 |
+
"e) Sử dụng mạng máy tính hoặc mạng viễn thông, phương tiện điện tử để phạm tội;\n",
|
288 |
+
"g) Gây rối loạn tâm thần và hành vi của nạn nhân mà tỷ lệ tổn thương cơ thể từ 31% đến 60%;\n",
|
289 |
+
"h) Vu khống người khác phạm tội rất nghiêm trọng hoặc đặc biệt nghiêm trọng.\n",
|
290 |
+
"3. Phạm tội thuộc một trong các trường hợp sau đây, thì bị phạt tù từ 03 năm đến 07 năm:\n",
|
291 |
+
"a) Vì động cơ đê hèn;\n",
|
292 |
+
"b) Gây rối loạn tâm thần và hành vi của nạn nhân mà tỷ lệ tổn thương cơ thể 61% trở lên;\n",
|
293 |
+
"c) Làm nạn nhân tự sát.\n",
|
294 |
+
"4. Người phạm tội còn có thể bị phạt tiền từ 10.000.000 đồng đến 50.000.000 đồng, cấm đảm nhiệm chức vụ, cấm hành nghề hoặc làm công việc nhất định từ 01 năm đến 05 năm.\"\n",
|
295 |
+
"--------------------------------------------------------------------------------\n",
|
296 |
+
"[Rank 9] index=80022, score=0.5218\n",
|
297 |
+
"\"Điều 20.\n",
|
298 |
+
"1. Mọi người có quyền bất khả xâm phạm về thân thể, được pháp luật bảo hộ về sức khoẻ, danh dự và nhân phẩm; không bị tra tấn, bạo lực, truy bức, nhục hình hay bất kỳ hình thức đối xử nào khác xâm phạm thân thể, sức khỏe, xúc phạm danh dự, nhân phẩm.\"\n",
|
299 |
+
"--------------------------------------------------------------------------------\n",
|
300 |
+
"[Rank 10] index=52682, score=0.5203\n",
|
301 |
+
"\"Điều 589. Thiệt hại do tài sản bị xâm phạm\n",
|
302 |
+
"Thiệt hại do tài sản bị xâm phạm bao gồm:\n",
|
303 |
+
"1. Tài sản bị mất, bị hủy hoại hoặc bị hư hỏng.\n",
|
304 |
+
"2. Lợi ích gắn liền với việc sử dụng, khai thác tài sản bị mất, bị giảm sút.\n",
|
305 |
+
"3. Chi phí hợp lý để ngăn chặn, hạn chế và khắc phục thiệt hại.\n",
|
306 |
+
"4. Thiệt hại khác do luật quy định.\"\n",
|
307 |
+
"--------------------------------------------------------------------------------\n"
|
308 |
+
]
|
309 |
+
}
|
310 |
+
],
|
311 |
+
"source": [
|
312 |
+
"query = 'Tội xúc phạm danh dự'\n",
|
313 |
+
"hits = retrieval(fine_tuned_model, query, legal_index, top_k=10)\n",
|
314 |
+
"\n",
|
315 |
+
"for h in hits:\n",
|
316 |
+
" print(f\"[Rank {hits.index(h)+1}] index={h['index']}, score={h['score']:.4f}\")\n",
|
317 |
+
" print(f\"{h['text']}\\n{'-'*80}\")"
|
318 |
+
]
|
319 |
+
},
|
320 |
+
{
|
321 |
+
"cell_type": "code",
|
322 |
+
"execution_count": null,
|
323 |
+
"id": "1bedd1a7",
|
324 |
+
"metadata": {},
|
325 |
+
"outputs": [],
|
326 |
+
"source": [
|
327 |
+
"# def search(model, query, index, k=10):\n",
|
328 |
+
"# query_embedding = model.encode(\n",
|
329 |
+
"# query, \n",
|
330 |
+
"# convert_to_numpy=True, \n",
|
331 |
+
"# normalize_embeddings=True,\n",
|
332 |
+
"# ).astype(np.float32).reshape(1, -1)\n",
|
333 |
+
"\n",
|
334 |
+
"# scores, indices = index.search(query_embedding, k*3)\n",
|
335 |
+
"# hits = [{'score': scores[0][i], 'index': indices[0][i]} for i in range(len(scores[0]))]\n",
|
336 |
+
"# return hits"
|
337 |
+
]
|
338 |
+
},
|
339 |
+
{
|
340 |
+
"cell_type": "code",
|
341 |
+
"execution_count": null,
|
342 |
+
"id": "4ef857db",
|
343 |
+
"metadata": {},
|
344 |
+
"outputs": [],
|
345 |
+
"source": [
|
346 |
+
"# hits = search(\n",
|
347 |
+
"# model=fine_tuned_model, \n",
|
348 |
+
"# query='Hợp đồng lao động là gì?', \n",
|
349 |
+
"# index=legal_index, \n",
|
350 |
+
"# k=10\n",
|
351 |
+
"# )\n",
|
352 |
+
"\n",
|
353 |
+
"# for rank, hit in enumerate(hits):\n",
|
354 |
+
"# print(f\"[Rank: {rank + 1}]\")\n",
|
355 |
+
"# print(f\"(Index: {hit['index']}Score: {hit['score']:.4f})\\n\")\n",
|
356 |
+
"# print(passages[hit['index']])\n",
|
357 |
+
"# print('-' * 100)\n",
|
358 |
+
"# print()"
|
359 |
+
]
|
360 |
+
}
|
361 |
+
],
|
362 |
+
"metadata": {
|
363 |
+
"kernelspec": {
|
364 |
+
"display_name": "legal_doc_retrieval",
|
365 |
+
"language": "python",
|
366 |
+
"name": "python3"
|
367 |
+
},
|
368 |
+
"language_info": {
|
369 |
+
"codemirror_mode": {
|
370 |
+
"name": "ipython",
|
371 |
+
"version": 3
|
372 |
+
},
|
373 |
+
"file_extension": ".py",
|
374 |
+
"mimetype": "text/x-python",
|
375 |
+
"name": "python",
|
376 |
+
"nbconvert_exporter": "python",
|
377 |
+
"pygments_lexer": "ipython3",
|
378 |
+
"version": "3.10.16"
|
379 |
+
}
|
380 |
+
},
|
381 |
+
"nbformat": 4,
|
382 |
+
"nbformat_minor": 5
|
383 |
+
}
|