IsrakML commited on
Commit
e6d6007
·
verified ·
1 Parent(s): 8020013

Upload 2 files

Browse files
Files changed (2) hide show
  1. app.py +119 -0
  2. best_model_state1.bin +3 -0
app.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from transformers import BertTokenizer, BertModel
5
+ import requests
6
+ from bs4 import BeautifulSoup
7
+ import pandas as pd
8
+
9
+ # Define the model class (matching the saved architecture)
10
+ class HeadlineClassifier(torch.nn.Module):
11
+ def __init__(self, num_aspect_classes, num_polarity_classes):
12
+ super(HeadlineClassifier, self).__init__()
13
+ self.bert = BertModel.from_pretrained("sagorsarker/bangla-bert-base", return_dict=False)
14
+ self.drop = torch.nn.Dropout(0.5)
15
+ self.aspect_out = torch.nn.Linear(self.bert.config.hidden_size, num_aspect_classes)
16
+ self.polarity_out = torch.nn.Linear(self.bert.config.hidden_size, num_polarity_classes)
17
+
18
+ def forward(self, input_ids, attention_mask):
19
+ _, pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)
20
+ output = self.drop(pooled_output)
21
+ aspect_output = self.aspect_out(output)
22
+ polarity_output = self.polarity_out(output)
23
+ return aspect_output, polarity_output
24
+
25
+ # Load tokenizer and model
26
+ tokenizer = BertTokenizer.from_pretrained("sagorsarker/bangla-bert-base")
27
+ model = HeadlineClassifier(num_aspect_classes=4, num_polarity_classes=3)
28
+ model.load_state_dict(torch.load('best_model_state1.bin', map_location=torch.device('cpu')))
29
+ model.eval()
30
+
31
+ # Class labels
32
+ aspect_class_names = ["others", "politics", "religion", "sports"]
33
+ polarity_class_names = ["negative", "neutral", "positive"]
34
+
35
+ # Function for single text prediction
36
+ def predict_text(text):
37
+ encoded = tokenizer.encode_plus(
38
+ text,
39
+ max_length=40,
40
+ add_special_tokens=True,
41
+ return_token_type_ids=False,
42
+ pad_to_max_length=True,
43
+ return_attention_mask=True,
44
+ return_tensors='pt'
45
+ )
46
+ input_ids = encoded['input_ids']
47
+ attention_mask = encoded['attention_mask']
48
+
49
+ with torch.no_grad():
50
+ aspect_output, polarity_output = model(input_ids, attention_mask)
51
+ aspect_prediction = torch.argmax(aspect_output, dim=1).item()
52
+ polarity_prediction = torch.argmax(polarity_output, dim=1).item()
53
+
54
+ return aspect_class_names[aspect_prediction], polarity_class_names[polarity_prediction]
55
+
56
+ # Function to scrape headlines with multiple classes
57
+ def scrape_headlines(url):
58
+ response = requests.get(url)
59
+ soup = BeautifulSoup(response.content, "html.parser")
60
+
61
+ # Extract headlines with the specified classes
62
+ headlines = [h.get_text(strip=True) for h in soup.find_all("a", class_=["title-link", "stretched-link", "Title"])[:50]]
63
+ return headlines
64
+
65
+ # Streamlit App Interface
66
+ st.title("Bangla Headline Aspect and Polarity Predictor")
67
+
68
+ # Radio button for functionality selection
69
+ option = st.radio("Choose Analysis Type:", ("Particular", "Overall"))
70
+
71
+ if option == "Particular":
72
+ # Input for single text prediction
73
+ text_input = st.text_area("Enter your Bangla text:")
74
+ if st.button("Predict"):
75
+ if text_input.strip():
76
+ aspect, polarity = predict_text(text_input)
77
+ st.write("### Original Text:")
78
+ st.write(f"{text_input}")
79
+ st.write(f"**Predicted Aspect Class:** {aspect}")
80
+ st.write(f"**Predicted Polarity Class:** {polarity}")
81
+ else:
82
+ st.warning("Please enter some text to predict.")
83
+
84
+ elif option == "Overall":
85
+ # Input for URL and headline analysis
86
+ url_input = st.text_input("Enter the URL:")
87
+ if st.button("Analyze Headlines"):
88
+ if url_input.strip():
89
+ headlines = scrape_headlines(url_input)
90
+ if not headlines:
91
+ st.warning("No headlines found. Please check the URL or structure of the site.")
92
+ else:
93
+ # Initialize counters
94
+ aspect_counts = {cls: 0 for cls in aspect_class_names}
95
+ polarity_counts = {cls: 0 for cls in polarity_class_names}
96
+
97
+ # Process each headline
98
+ for headline in headlines:
99
+ aspect, polarity = predict_text(headline)
100
+ aspect_counts[aspect] += 1
101
+ polarity_counts[polarity] += 1
102
+
103
+ # Display counts
104
+ st.write("### Aspect Class Counts")
105
+ for cls in aspect_class_names:
106
+ st.write(f"{cls}: {aspect_counts[cls]}")
107
+
108
+ st.write("### Polarity Class Counts")
109
+ for cls in polarity_class_names:
110
+ st.write(f"{cls}: {polarity_counts[cls]}")
111
+
112
+ # Display bar charts
113
+ st.write("### Aspect Distribution")
114
+ st.bar_chart(pd.DataFrame(list(aspect_counts.items()), columns=['Aspect', 'Count']).set_index('Aspect'))
115
+
116
+ st.write("### Polarity Distribution")
117
+ st.bar_chart(pd.DataFrame(list(polarity_counts.items()), columns=['Polarity', 'Count']).set_index('Polarity'))
118
+ else:
119
+ st.warning("Please enter a valid URL.")
best_model_state1.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:359a527621945876ad41c0acffc3ba4980bec1839cdd52ca883790d9675543e4
3
+ size 657692018