talktorhutika committed on
Commit
0556c92
·
verified ·
1 Parent(s): 4325127

Upload finetune1.py

Browse files
Files changed (1) hide show
  1. finetune1.py +115 -0
finetune1.py ADDED
@@ -0,0 +1,115 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
4
+ from datasets import Dataset
5
+ from sklearn.model_selection import train_test_split
6
+ import requests
7
+ from io import BytesIO
8
+ import torch
9
+
10
+ # Load the dataset
11
@st.cache_data
def load_data():
    """Download the HUPD metadata feather file and return it as a DataFrame.

    The file is fetched over HTTP from the Hugging Face Hub and parsed from
    memory. The function is decorated with ``st.cache_data`` so Streamlit can
    reuse the result instead of downloading again on every script rerun.

    Returns:
        pandas.DataFrame: The table stored in the feather file.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.Timeout: If connecting or reading stalls past the timeout.
    """
    url = "https://huggingface.co/datasets/HUPD/hupd/resolve/main/hupd_metadata_2022-02-22.feather"
    # (connect, read) timeouts: without them a stalled connection would hang
    # the app indefinitely.
    response = requests.get(url, timeout=(10, 120))
    # Fail loudly on an error status rather than handing an HTML error page
    # to the feather parser.
    response.raise_for_status()
    return pd.read_feather(BytesIO(response.content))
18
+
19
+ # Tokenizer and model loading
20
@st.cache_resource
def load_tokenizer_and_model(model_name, num_labels=2):
    """Load the tokenizer and sequence-classification model for a checkpoint.

    Decorated with ``st.cache_resource`` so Streamlit can keep the loaded
    objects around between script reruns.

    Args:
        model_name: Checkpoint name or path passed to ``from_pretrained``.
        num_labels: Number of output classes for the classification head.
            Defaults to 2, the value that was previously hard-coded.

    Returns:
        tuple: ``(tokenizer, model)``.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=num_labels,
    )
    return tokenizer, model
25
+
26
+ # Tokenize and prepare the dataset
27
def prepare_data(df, tokenizer):
    """Build a tokenized dataset from the rows filed in January 2016.

    Args:
        df: DataFrame with ``filing_date``, ``invention_title`` and
            ``patent_number`` columns. It is not modified.
        tokenizer: Callable tokenizer; invoked once on the list of titles and
            expected to return ``input_ids`` and ``attention_mask``.

    Returns:
        datasets.Dataset: Columns ``input_ids``, ``attention_mask`` and
        ``labels`` (0 or 1). Empty when no row falls in January 2016.
    """
    # Parse dates into a local series so the caller's DataFrame is untouched.
    filing_dates = pd.to_datetime(df['filing_date'])
    jan_2016_df = df[filing_dates.dt.to_period('M') == '2016-01']

    # Missing titles would make the tokenizer fail; treat them as empty text.
    texts = jan_2016_df['invention_title'].fillna('').astype(str).tolist()

    # Raw patent numbers are not valid class indices for a 2-label head, so
    # derive a binary label instead: 1 when a patent number is present.
    # NOTE(review): assumes a missing/blank patent_number means the
    # application was not granted - confirm against the dataset schema.
    patent_numbers = jan_2016_df['patent_number']
    as_text = patent_numbers.astype(str).str.strip().str.lower()
    has_number = patent_numbers.notna() & ~as_text.isin(['', 'none', 'nan'])
    labels = has_number.astype(int).tolist()

    if not texts:
        # Nothing to tokenize; return an empty dataset with the same columns.
        return Dataset.from_dict(
            {'input_ids': [], 'attention_mask': [], 'labels': []}
        )

    # Without return_tensors the tokenizer yields plain Python lists, which is
    # what Dataset.from_dict needs - no tensor round-trip required.
    encoded = tokenizer(
        texts,
        padding="max_length",
        truncation=True,
        max_length=512,
    )

    return Dataset.from_dict({
        'input_ids': encoded['input_ids'],
        'attention_mask': encoded['attention_mask'],
        'labels': labels,
    })
48
+
49
+ # Define Streamlit app
50
def main():
    """Render the Streamlit page and fine-tune the model on request.

    Streamlit re-executes this function on every widget interaction, so the
    expensive steps (tokenization, train/eval split, Trainer construction and
    training) run only inside the 'Train Model' button branch.
    """
    st.title("Patent Classification with Fine-Tuned BERT")

    # Load data
    df = load_data()

    # Show sample data
    # NOTE(review): the header mentions January 2016, but the preview shows
    # the first rows of the full table.
    st.subheader("Some Data from January 2016")
    st.write(df.head())

    # Load tokenizer and model
    model_name = "bert-base-uncased"
    tokenizer, model = load_tokenizer_and_model(model_name)

    st.subheader("Training the Model")
    if st.button('Train Model'):
        with st.spinner('Training in progress...'):
            dataset = prepare_data(df, tokenizer)

            if len(dataset) < 2:
                # A train/eval split needs at least two rows.
                st.error("Not enough January 2016 rows to train on.")
                return

            # 80/20 split with a fixed seed for reproducibility.
            splits = dataset.train_test_split(test_size=0.2, seed=42)

            # NOTE(review): newer transformers releases rename
            # `evaluation_strategy` to `eval_strategy` and the Trainer
            # `tokenizer` argument to `processing_class` - confirm against
            # the installed version.
            training_args = TrainingArguments(
                output_dir='./results',
                evaluation_strategy="epoch",
                learning_rate=2e-5,
                per_device_train_batch_size=8,
                per_device_eval_batch_size=8,
                num_train_epochs=3,
                weight_decay=0.01,
            )

            trainer = Trainer(
                model=model,
                args=training_args,
                train_dataset=splits['train'],
                eval_dataset=splits['test'],
                tokenizer=tokenizer,  # Ensure tokenizer is passed
            )

            trainer.train()
            model.save_pretrained("./finetuned_model")
            tokenizer.save_pretrained("./finetuned_model")
        st.success("Model training complete and saved.")

    # Display pretrained model data
    st.subheader("Pretrained Model")
    if st.button('Show Pretrained Model'):
        st.write("Pretrained model is `bert-base-uncased`. Fine-tuned model is saved at './finetuned_model'.")


if __name__ == "__main__":
    main()