nei-demo-backup / svm_predict.py
SmitaGautam's picture
Update svm_predict.py
25b106e verified
raw
history blame
1.1 kB
import nltk
from nltk import word_tokenize
from nltk import pos_tag
import joblib
from train import feature_vector, pos_tags
# Trained SVM loaded once at import time from the current working directory.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code -- only ever point this at a model file you produced/trust.
model = joblib.load('SVM_NEI_model.pkl')


def _ensure_nltk_resource(resource_path, package):
    """Download NLTK *package* only if *resource_path* is not installed yet.

    ``nltk.download`` contacts the NLTK index server on every call, even when
    the data is already present; probing locally with ``nltk.data.find`` first
    avoids that network round-trip on every import of this module.
    """
    try:
        nltk.data.find(resource_path)
    except LookupError:
        nltk.download(package)


# Data needed by pos_tag (English perceptron tagger) and word_tokenize (punkt_tab).
_ensure_nltk_resource('taggers/averaged_perceptron_tagger_eng', 'averaged_perceptron_tagger_eng')
_ensure_nltk_resource('tokenizers/punkt_tab', 'punkt_tab')
def predict(sentence):
    """Score every token of *sentence* with the loaded SVM model.

    The sentence is tokenized with ``word_tokenize`` and POS-tagged with
    ``pos_tag``. Each token becomes a feature vector via
    ``feature_vector(word, relative_position, pos_tag)``, where
    ``relative_position`` is the token's index divided by the number of
    tokens. All vectors are then scored in a single ``model.predict`` call.

    Args:
        sentence: Raw text to analyse.

    Returns:
        A ``(tokens, predictions)`` tuple: the token list, and a list of the
        same length holding each model output passed through ``round``
        (presumably per-token entity labels -- TODO confirm against training).
        Input that yields no tokens returns ``([], [])``.
    """
    tokens = word_tokenize(sentence)
    if not tokens:
        # Nothing to score. Also keeps an empty 2-D batch away from
        # model.predict, which (assuming a scikit-learn estimator) would
        # raise instead of returning an empty result.
        return tokens, []

    n_tokens = len(tokens)
    tagged = pos_tag(tokens)  # one (token, tag) pair per token, same order
    vectors = [
        feature_vector(word, idx / n_tokens, tag)
        for idx, (word, (_, tag)) in enumerate(zip(tokens, tagged))
    ]
    # One batched call instead of one predict() per token; row i of the
    # output corresponds to tokens[i].
    predictions = [round(y) for y in model.predict(vectors)]
    return tokens, predictions