darpanaswal commited on
Commit
ec30c35
·
verified ·
1 Parent(s): 74714ba

Update cross_encoder_reranking_train.py

Browse files
Files changed (1) hide show
  1. cross_encoder_reranking_train.py +45 -4
cross_encoder_reranking_train.py CHANGED
@@ -70,6 +70,29 @@ def process_single_patent(patent_dict):
70
  "features": rank_by_centrality(top_features),
71
  }
72
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
73
  def load_json_file(file_path):
74
  """Load JSON data from a file"""
75
  with open(file_path, 'r') as f:
@@ -145,10 +168,8 @@ def extract_text(content_dict, text_type="full"):
145
  filtered_dict = process_single_patent(content_dict)
146
  all_text = []
147
  # Start with abstract for better context at the beginning
148
- if "title" in content_dict:
149
- all_text.append(content_dict["title"])
150
- # if "pa01" in content_dict:
151
- # all_text.append(content_dict["pa01"])
152
 
153
  # For claims, paragraphs and features, we take only the top-10 most relevant
154
  # Add claims
@@ -162,6 +183,26 @@ def extract_text(content_dict, text_type="full"):
162
  all_text.append(paragraph)
163
 
164
  return " ".join(all_text)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
 
167
  return ""
 
70
  "features": rank_by_centrality(top_features),
71
  }
72
 
73
+ def process_single_patent2(patent_dict):
74
+ def filter_short_texts(texts, min_tokens=5):
75
+ return [text for text in texts if len(text.split()) >= min_tokens]
76
+
77
+ # Filter short texts
78
+ claims = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("c-en")])
79
+ paragraphs = filter_short_texts([v for k, v in patent_dict.items() if k.startswith("p")])
80
+ features = filter_short_texts([v for k, v in patent_dict.get("features", {}).items()])
81
+
82
+ # Re-rank claims and features directly
83
+ ranked_claims = rank_by_centrality(claims)
84
+ ranked_features = rank_by_centrality(features)
85
+
86
+ # Only filter (cluster + rank) for paragraphs
87
+ filtered_paragraphs = cluster_and_rank(paragraphs)
88
+ ranked_paragraphs = rank_by_centrality(filtered_paragraphs)
89
+
90
+ return {
91
+ "claims": ranked_claims,
92
+ "paragraphs": ranked_paragraphs,
93
+ "features": ranked_features,
94
+ }
95
+
96
  def load_json_file(file_path):
97
  """Load JSON data from a file"""
98
  with open(file_path, 'r') as f:
 
168
  filtered_dict = process_single_patent(content_dict)
169
  all_text = []
170
  # Start with abstract for better context at the beginning
171
+ if "pa01" in content_dict:
172
+ all_text.append(content_dict["pa01"])
 
 
173
 
174
  # For claims, paragraphs and features, we take only the top-10 most relevant
175
  # Add claims
 
183
  all_text.append(paragraph)
184
 
185
  return " ".join(all_text)
186
+
187
+ elif text_type == "smart2":
188
+ filtered_dict = process_single_patent2(content_dict)
189
+ all_text = []
190
+ # Start with abstract for better context at the beginning
191
+ if "pa01" in content_dict:
192
+ all_text.append(content_dict["pa01"])
193
+
194
+ # For claims, paragraphs and features, we take only the top-10 most relevant
195
+ # Add claims
196
+ for claim in filtered_dict["claims"][:10]:
197
+ all_text.append(claim)
198
+ # Add paragraphs
199
+ for paragraph in filtered_dict["paragraphs"][:10]:
200
+ all_text.append(paragraph)
201
+ # Add features
202
+ for feature in filtered_dict["features"][:10]:
203
+ all_text.append(feature)
204
+
205
+ return " ".join(all_text)
206
 
207
 
208
  return ""