umangshikarvar committed on
Commit 284cfdb · verified · 1 Parent(s): 3ecefb7
Files changed (1)
  1. app.py +67 -0
app.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, LogitsProcessor, LogitsProcessorList
+ from peft import PeftModel, PeftConfig
+ import gradio as gr
+
+ # === Load Tokenizer ===
+ tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B")
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # === Load Model + QLoRA Adapter ===
+ checkpoint_dir = "umangshikarvar/sentiment-gpt-neo-qlora"  # Update if needed
+ peft_config = PeftConfig.from_pretrained(checkpoint_dir)
+ base_model = AutoModelForCausalLM.from_pretrained(peft_config.base_model_name_or_path, torch_dtype=torch.float16)
+ model = PeftModel.from_pretrained(base_model, checkpoint_dir)
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model.eval().to(device)
+
+ # === Define Custom LogitsProcessor ===
+ class RestrictVocabLogitsProcessor(LogitsProcessor):
+     def __init__(self, allowed_token_ids):
+         self.allowed_token_ids = allowed_token_ids
+
+     def __call__(self, input_ids, scores):
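+         # Replace every logit with -inf except the allowed token ids,
+         # so generation can only choose among the whitelisted tokens.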
+         mask = torch.full_like(scores, float("-inf"))
+         mask[:, self.allowed_token_ids] = scores[:, self.allowed_token_ids]
+         return mask
+
+ # === Set Allowed Sentiment Tokens ===
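+ # Only the first BPE token of each label is whitelisted; combined with
+ # max_new_tokens=1 below, generation emits exactly one of these tokens.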
+ sentiment_words = ["Positive", "Negative", "Neutral"]
+ allowed_ids = [
+     tokenizer(word, add_special_tokens=False)["input_ids"][0]
+     for word in sentiment_words
+ ]
+ logits_processor = LogitsProcessorList([
+     RestrictVocabLogitsProcessor(allowed_ids)
+ ])
+
+ # === Inference Function ===
+ def predict_sentiment(tweet):
+     prompt = f"Tweet: {tweet}\nSentiment:"
+     inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
+
+     outputs = model.generate(
+         **inputs,
+         max_new_tokens=1,
+         do_sample=False,
+         logits_processor=logits_processor
+     )
+
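+     # The decoded output still contains the prompt; strip it and map the
+     # first generated word onto a canonical label.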
+     response = tokenizer.decode(outputs[0], skip_special_tokens=True)
+     prediction = response.replace(prompt, "").strip().split()[0]
+
+     if prediction.lower().startswith("pos"):
+         return "Positive"
+     elif prediction.lower().startswith("neg"):
+         return "Negative"
+     else:
+         return "Neutral"
+
+ # === Gradio Interface ===
+ gr.Interface(
+     fn=predict_sentiment,
+     inputs=gr.Textbox(lines=2, placeholder="Enter the text", label="Statement"),
+     outputs="text",
+     title="Sentiment Classifier",
+     description="Classifies the sentiment of a statement as Positive, Negative, or Neutral."
+ ).launch(share=True)