Spaces:
Configuration error
Configuration error
Update cross_encoder_reranking_train.py
Browse files
cross_encoder_reranking_train.py
CHANGED
@@ -70,6 +70,29 @@ def process_single_patent(patent_dict):
|
|
70 |
"features": rank_by_centrality(top_features),
|
71 |
}
|
72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
73 |
def load_json_file(file_path):
|
74 |
"""Load JSON data from a file"""
|
75 |
with open(file_path, 'r') as f:
|
@@ -145,10 +168,8 @@ def extract_text(content_dict, text_type="full"):
|
|
145 |
filtered_dict = process_single_patent(content_dict)
|
146 |
all_text = []
|
147 |
# Start with abstract for better context at the beginning
|
148 |
-
if "
|
149 |
-
all_text.append(content_dict["
|
150 |
-
# if "pa01" in content_dict:
|
151 |
-
# all_text.append(content_dict["pa01"])
|
152 |
|
153 |
# For claims, paragraphs and features, we take only the top-10 most relevant
|
154 |
# Add claims
|
@@ -162,6 +183,26 @@ def extract_text(content_dict, text_type="full"):
|
|
162 |
all_text.append(paragraph)
|
163 |
|
164 |
return " ".join(all_text)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
165 |
|
166 |
|
167 |
return ""
|
|
|
70 |
"features": rank_by_centrality(top_features),
|
71 |
}
|
72 |
|
73 |
+
def process_single_patent2(patent_dict):
    """Collect, filter and rank the text sections of a single patent.

    Claims and features are only re-ranked with ``rank_by_centrality``.
    Paragraphs are first passed through ``cluster_and_rank`` and the result
    is then ranked with ``rank_by_centrality``.

    Args:
        patent_dict: Mapping of section keys to texts. Values under keys
            starting with ``"c-en"`` are used as claims and values under
            keys starting with ``"p"`` as paragraphs. An optional
            ``"features"`` entry is expected to be a mapping whose values
            are the feature texts.

    Returns:
        A dict with the keys ``"claims"``, ``"paragraphs"`` and
        ``"features"`` holding the ranked texts.
    """

    def keep_long_enough(candidates, min_tokens=5):
        # Drop texts with fewer than `min_tokens` whitespace-separated tokens.
        kept = []
        for candidate in candidates:
            if len(candidate.split()) >= min_tokens:
                kept.append(candidate)
        return kept

    # Gather the raw texts per section, then discard the very short ones.
    # NOTE(review): any key starting with "p" is treated as a paragraph, which
    # would also match a key such as "pa01" -- confirm this is intended.
    claim_texts = keep_long_enough(
        [value for key, value in patent_dict.items() if key.startswith("c-en")]
    )
    paragraph_texts = keep_long_enough(
        [value for key, value in patent_dict.items() if key.startswith("p")]
    )
    feature_texts = keep_long_enough(
        list(patent_dict.get("features", {}).values())
    )

    # Claims and features are re-ranked directly.
    claims_by_centrality = rank_by_centrality(claim_texts)
    features_by_centrality = rank_by_centrality(feature_texts)

    # Only paragraphs go through the cluster-and-rank filter before ranking.
    paragraphs_by_centrality = rank_by_centrality(cluster_and_rank(paragraph_texts))

    return {
        "claims": claims_by_centrality,
        "paragraphs": paragraphs_by_centrality,
        "features": features_by_centrality,
    }
|
95 |
+
|
96 |
def load_json_file(file_path):
|
97 |
"""Load JSON data from a file"""
|
98 |
with open(file_path, 'r') as f:
|
|
|
168 |
filtered_dict = process_single_patent(content_dict)
|
169 |
all_text = []
|
170 |
# Start with abstract for better context at the beginning
|
171 |
+
if "pa01" in content_dict:
|
172 |
+
all_text.append(content_dict["pa01"])
|
|
|
|
|
173 |
|
174 |
# For claims, paragraphs and features, we take only the top-10 most relevant
|
175 |
# Add claims
|
|
|
183 |
all_text.append(paragraph)
|
184 |
|
185 |
return " ".join(all_text)
|
186 |
+
|
187 |
+
elif text_type == "smart2":
|
188 |
+
filtered_dict = process_single_patent2(content_dict)
|
189 |
+
all_text = []
|
190 |
+
# Start with abstract for better context at the beginning
|
191 |
+
if "pa01" in content_dict:
|
192 |
+
all_text.append(content_dict["pa01"])
|
193 |
+
|
194 |
+
# For claims, paragraphs and features, we take only the top-10 most relevant
|
195 |
+
# Add claims
|
196 |
+
for claim in filtered_dict["claims"][:10]:
|
197 |
+
all_text.append(claim)
|
198 |
+
# Add paragraphs
|
199 |
+
for paragraph in filtered_dict["paragraphs"][:10]:
|
200 |
+
all_text.append(paragraph)
|
201 |
+
# Add features
|
202 |
+
for feature in filtered_dict["features"][:10]:
|
203 |
+
all_text.append(feature)
|
204 |
+
|
205 |
+
return " ".join(all_text)
|
206 |
|
207 |
|
208 |
return ""
|