talktorhutika committed on
Commit
e816d24
·
verified ·
1 Parent(s): 863be65

Update finetune3.py

Browse files
Files changed (1) hide show
  1. finetune3.py +8 -8
finetune3.py CHANGED
@@ -61,6 +61,9 @@ def prepare_data(df, tokenizer):
61
  def main():
62
  st.title("Patent Classification with Fine-Tuned BERT")
63
 
 
 
 
64
  # Load data
65
  df = load_data()
66
 
@@ -70,12 +73,11 @@ def main():
70
 
71
  # Prepare data
72
  model_name = "bert-base-uncased"
73
- dummy_num_labels = 5
74
- tokenizer, model = load_tokenizer_and_model(model_name, dummy_num_labels)
75
  dataset, num_labels = prepare_data(df, tokenizer)
76
 
77
  # Update the model with the correct number of labels based on the data
78
- if num_labels != dummy_num_labels:
79
  model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
80
 
81
  # Split the dataset
@@ -102,7 +104,7 @@ def main():
102
 
103
  # Fine-tune model
104
  training_args = TrainingArguments(
105
- output_dir='./results',
106
  evaluation_strategy="epoch",
107
  learning_rate=2e-5,
108
  per_device_train_batch_size=8,
@@ -123,15 +125,13 @@ def main():
123
  if st.button('Train Model'):
124
  with st.spinner('Training in progress...'):
125
  trainer.train()
126
- model.save_pretrained("./finetuned_model")
127
- tokenizer.save_pretrained("./finetuned_model")
128
  st.success("Model training complete and saved.")
129
 
130
  # Display pretrained model data
131
  st.subheader("Pretrained Model")
132
  if st.button('Show Pretrained Model'):
133
- model_dir = './finetuned_model'
134
-
135
  # List files in the directory
136
  if os.path.exists(model_dir):
137
  files = os.listdir(model_dir)
 
61
  def main():
62
  st.title("Patent Classification with Fine-Tuned BERT")
63
 
64
+ # Initialize model directory path
65
+ model_dir = './finetuned_model'
66
+
67
  # Load data
68
  df = load_data()
69
 
 
73
 
74
  # Prepare data
75
  model_name = "bert-base-uncased"
76
+ tokenizer, model = load_tokenizer_and_model(model_name, num_labels=5)
 
77
  dataset, num_labels = prepare_data(df, tokenizer)
78
 
79
  # Update the model with the correct number of labels based on the data
80
+ if num_labels != 5:
81
  model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=num_labels)
82
 
83
  # Split the dataset
 
104
 
105
  # Fine-tune model
106
  training_args = TrainingArguments(
107
+ output_dir=model_dir,
108
  evaluation_strategy="epoch",
109
  learning_rate=2e-5,
110
  per_device_train_batch_size=8,
 
125
  if st.button('Train Model'):
126
  with st.spinner('Training in progress...'):
127
  trainer.train()
128
+ model.save_pretrained(model_dir)
129
+ tokenizer.save_pretrained(model_dir)
130
  st.success("Model training complete and saved.")
131
 
132
  # Display pretrained model data
133
  st.subheader("Pretrained Model")
134
  if st.button('Show Pretrained Model'):
 
 
135
  # List files in the directory
136
  if os.path.exists(model_dir):
137
  files = os.listdir(model_dir)