from datasets import load_dataset
from transformers import (
    BertTokenizer,
    BertForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from sklearn.metrics import classification_report

# Load a small few-shot slice of the corpus. This assumes the dataset exposes
# "train"/"validation" splits and a "context" text field; adjust the split and
# field names if the actual schema differs.
dataset = load_dataset("ai4bharat/sangraha")
train_data = dataset["train"].select(range(30))
test_data = dataset["validation"].select(range(10))

# Tokenizer and preprocessing
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")

def tokenize(example):
    return tokenizer(example["context"], padding="max_length", truncation=True)

def encode_label(example):
    # Toy heuristic labeling: 1 if the text mentions "bank", else 0.
    example["label"] = 1 if "bank" in example["context"].lower() else 0
    return example

train_data = train_data.map(tokenize).map(encode_label)
test_data = test_data.map(tokenize).map(encode_label)
train_data.set_format("torch", columns=["input_ids", "attention_mask", "label"])
test_data.set_format("torch", columns=["input_ids", "attention_mask", "label"])

# Model and trainer
model = BertForSequenceClassification.from_pretrained("bert-base-uncased", num_labels=2)

training_args = TrainingArguments(
    output_dir="./results",
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    evaluation_strategy="epoch",  # renamed to eval_strategy in newer transformers releases
    logging_dir="./logs",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_data,
    eval_dataset=test_data,
)

trainer.train()
metrics = trainer.evaluate()
print(metrics)

# Predict on the held-out slice; take the reference labels from the prediction
# output so they are guaranteed to line up with the predictions.
output = trainer.predict(test_data)
predictions = output.predictions.argmax(-1)
print(classification_report(output.label_ids, predictions))
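
# Optional extension (a sketch, not part of the original script): as written,
# the Trainer only reports eval loss at each epoch. A compute_metrics callback
# like the one below would surface accuracy and F1 during evaluation; wire it
# in with Trainer(..., compute_metrics=compute_metrics). The function name and
# metric choices here are illustrative assumptions.
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

def compute_metrics(eval_pred):
    # eval_pred carries the model's raw logits and the reference label ids.
    preds = np.argmax(eval_pred.predictions, axis=-1)
    return {
        "accuracy": accuracy_score(eval_pred.label_ids, preds),
        "f1": f1_score(eval_pred.label_ids, preds),
    }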