Usernameblue2 commited on
Commit
256edf9
Β·
verified Β·
1 Parent(s): e4f0c31

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +51 -0
  2. requirements.txt +5 -0
app.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from datasets import load_dataset
3
+ from transformers import BertTokenizer, BertForSequenceClassification, Trainer, TrainingArguments
4
+ from sklearn.metrics import classification_report
5
+ import torch
6
+
7
+ # Load few-shot dataset
8
+ dataset = load_dataset("ai4bharat/sangraha")
9
+ train_data = dataset["train"].select(range(30))
10
+ test_data = dataset["validation"].select(range(10))
11
+
12
+ # Tokenizer and preprocessing
13
+ tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
14
+
15
+ def tokenize(example):
16
+ return tokenizer(example["context"], padding="max_length", truncation=True)
17
+
18
+ def encode_label(example):
19
+ example["label"] = 1 if "bank" in example["context"].lower() else 0
20
+ return example
21
+
22
+ train_data = train_data.map(tokenize).map(encode_label)
23
+ test_data = test_data.map(tokenize).map(encode_label)
24
+
25
+ train_data.set_format("torch", columns=["input_ids", "attention_mask", "label"])
26
+ test_data.set_format("torch", columns=["input_ids", "attention_mask", "label"])
27
+
28
+ # Model and trainer
29
+ model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)
30
+
31
+ training_args = TrainingArguments(
32
+ output_dir="./results",
33
+ num_train_epochs=3,
34
+ per_device_train_batch_size=4,
35
+ per_device_eval_batch_size=4,
36
+ evaluation_strategy="epoch",
37
+ logging_dir="./logs"
38
+ )
39
+
40
+ trainer = Trainer(
41
+ model=model,
42
+ args=training_args,
43
+ train_dataset=train_data,
44
+ eval_dataset=test_data,
45
+ )
46
+
47
+ trainer.train()
48
+ metrics = trainer.evaluate()
49
+
50
+ predictions = trainer.predict(test_data).predictions.argmax(-1)
51
+ print(classification_report(test_data["label"], predictions))
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+
2
+ transformers
3
+ datasets
4
+ scikit-learn
5
+ torch