|
import torch |
|
import tensorflow as tf |
|
from tf_keras import models, layers |
|
from transformers import AutoTokenizer, AutoModelForSequenceClassification, AutoModelForTokenClassification, TFAutoModelForQuestionAnswering |
|
import gradio as gr |
|
import re |
|
|
|
|
|
# Run the PyTorch models on the GPU when CUDA is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# --- Question answering (TensorFlow checkpoint, so it is not moved to `device`) ---
qa_model_name = "salsarra/ConfliBERT-QA"
qa_tokenizer = AutoTokenizer.from_pretrained(qa_model_name)
qa_model = TFAutoModelForQuestionAnswering.from_pretrained(qa_model_name)

# --- Named entity recognition (PyTorch, token classification) ---
ner_model_name = "eventdata-utd/conflibert-named-entity-recognition"
ner_tokenizer = AutoTokenizer.from_pretrained(ner_model_name)
ner_model = AutoModelForTokenClassification.from_pretrained(ner_model_name).to(device)

# --- Binary text classification (PyTorch, sequence classification) ---
clf_model_name = "eventdata-utd/conflibert-binary-classification"
clf_tokenizer = AutoTokenizer.from_pretrained(clf_model_name)
clf_model = AutoModelForSequenceClassification.from_pretrained(clf_model_name).to(device)

# --- Multilabel text classification (PyTorch, sequence classification) ---
multi_clf_model_name = "eventdata-utd/conflibert-satp-relevant-multilabel"
multi_clf_tokenizer = AutoTokenizer.from_pretrained(multi_clf_model_name)
multi_clf_model = AutoModelForSequenceClassification.from_pretrained(multi_clf_model_name).to(device)

# Display names for the classifier outputs, indexed by output position.
class_names = ["Negative", "Positive"]
multi_class_names = ["Armed Assault", "Bombing or Explosion", "Kidnapping", "Other"]
|
|
|
|
|
# Entity label -> CSS colour name used when highlighting NER output.
# Labels missing from this table are rendered in black by the caller.
ner_labels = dict(
    Organisation='blue',
    Person='red',
    Location='green',
    Quantity='orange',
    Weapon='purple',
    Nationality='cyan',
    Temporal='magenta',
    DocumentReference='brown',
    MilitaryPlatform='yellow',
    Money='pink',
)
|
|
|
def handle_error_message(e, default_limit=512):
    """Turn an exception into a red, bold HTML "input over limit" message.

    The exception text is checked against two known message shapes, in order:
    a tensor size-mismatch message and an out-of-range index message. When one
    of them matches, the two captured numbers are reported as the inserted
    text size and the model limit. Anything else falls back to a generic
    over-limit message that quotes ``default_limit``.

    Args:
        e: The exception (or any object) whose ``str()`` is inspected.
        default_limit: Limit quoted when no number can be extracted.

    Returns:
        An HTML ``<span>`` string describing the error.
    """
    details = str(e)
    known_shapes = (
        r"The size of tensor a \((\d+)\) must match the size of tensor b \((\d+)\)",
        r"indices\[0,(\d+)\] = \d+ is not in \[0, (\d+)\)",
    )
    for shape in known_shapes:
        found = re.search(shape, details)
        if found is None:
            continue
        number_1, number_2 = found.groups()
        return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size {number_1} is larger than model limits of {number_2}</span>"
    return f"<span style='color: red; font-weight: bold;'>Error: Text Input is over limit where inserted text size is larger than model limits of {default_limit}</span>"
|
|
|
|
|
def question_answering(context, question):
    """Extract the answer to *question* from *context* and return it as HTML.

    The question/context pair is tokenized (with truncation) and fed to
    ``qa_model``. The answer span runs from the highest-scoring start position
    through the highest-scoring end position, inclusive. Any exception is
    converted to an HTML error message by ``handle_error_message``.

    Returns:
        A green, bold HTML ``<span>`` with the answer text, or an error span.
    """
    try:
        encoded = qa_tokenizer(question, context, return_tensors='tf', truncation=True)
        prediction = qa_model(encoded)

        # Best start index, and best end index made exclusive for slicing.
        span_start = tf.argmax(prediction.start_logits, axis=1).numpy()[0]
        span_stop = tf.argmax(prediction.end_logits, axis=1).numpy()[0] + 1

        span_ids = encoded['input_ids'].numpy()[0][span_start:span_stop]
        span_tokens = qa_tokenizer.convert_ids_to_tokens(span_ids)
        answer = qa_tokenizer.convert_tokens_to_string(span_tokens)
    except Exception as e:
        return handle_error_message(e)
    return f"<span style='color: green; font-weight: bold;'>{answer}</span>"
|
|
|
def replace_unk(tokens):
    """Return a new list where each '[UNK]' inside a token becomes an apostrophe."""
    unk_marker, stand_in = '[UNK]', "'"
    return [piece.replace(unk_marker, stand_in) for piece in tokens]
|
|
|
def named_entity_recognition(text):
    """Tag *text* with ``ner_model`` and return colour-highlighted HTML plus a legend.

    The text is tokenized (with truncation), each token gets its highest-scoring
    label, and tokens starting with ``##`` are glued onto the preceding token.
    Only the part of a label after its last ``-`` is used, and ``'O'`` counts as
    "no entity". Tokens with an entity label are wrapped in a bold span coloured
    via ``ner_labels`` (black when the label is not in that table). The legend
    lists every entity label found, in order of first appearance.

    Any exception is converted to an HTML error message by
    ``handle_error_message``.

    Returns:
        An HTML string: a ``<div>`` with the highlighted text followed by a
        ``<div>`` containing the legend.
    """
    try:
        # The encoded inputs must be on the same device as ner_model, which is
        # moved to `device` when it is loaded; otherwise the forward pass fails
        # on CUDA hosts.
        inputs = ner_tokenizer(text, return_tensors='pt', truncation=True).to(device)
        with torch.no_grad():
            outputs = ner_model(**inputs)

        # The batch holds exactly one sequence; index it explicitly instead of
        # relying on squeeze() to drop the right dimension.
        label_ids = outputs.logits.argmax(dim=2)[0].tolist()
        tokens = ner_tokenizer.convert_ids_to_tokens(inputs['input_ids'][0].tolist())
        tokens = replace_unk(tokens)

        id2label = ner_model.config.id2label
        entities = []
        seen_labels = {}  # dict used as an insertion-ordered set
        for token, label_id in zip(tokens, label_ids):
            if token.startswith('##'):
                # Continuation piece: extend the previous token, keep its label.
                if entities:
                    entities[-1][0] += token[2:]
                continue
            label = id2label[label_id].split('-')[-1]
            entities.append([token, label])
            if label != 'O':
                seen_labels[label] = None

        pieces = []
        for token, label in entities:
            if label != 'O':
                color = ner_labels.get(label, 'black')
                pieces.append(f"<span style='color: {color}; font-weight: bold;'>{token}</span> ")
            else:
                pieces.append(f"{token} ")
        highlighted_text = "".join(pieces)

        legend_items = "".join(
            f"<li style='color: {ner_labels.get(label, 'black')}; font-weight: bold;'>{label}</li>"
            for label in seen_labels
        )
        legend = (
            "<div><strong>NER Tags Found:</strong><ul style='list-style-type: disc; padding-left: 20px;'>"
            + legend_items
            + "</ul></div>"
        )

        return f"<div>{highlighted_text}</div>{legend}"
    except Exception as e:
        return handle_error_message(e)
|
|
|
def text_classification(text):
    """Classify *text* with the binary classifier and return the verdict as HTML.

    The text is tokenized (truncated and padded), moved to ``device`` and passed
    through ``clf_model``. Class index 1 is reported as "Positive" in green and
    any other index as "Negative" in red. The confidence shown is the largest
    softmax probability, as a percentage. Any exception is converted to an HTML
    error message by ``handle_error_message``.

    Returns:
        An HTML ``<span>`` with the verdict and confidence, or an error span.
    """
    try:
        inputs = clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
        with torch.no_grad():
            outputs = clf_model(**inputs)

        predicted_class = torch.argmax(outputs.logits, dim=1).item()
        confidence = torch.softmax(outputs.logits, dim=1).max().item() * 100

        if predicted_class == 1:
            return f"<span style='color: green; font-weight: bold;'>Positive: The text is related to conflict, violence, or politics. (Confidence: {confidence:.2f}%)</span>"
        return f"<span style='color: red; font-weight: bold;'>Negative: The text is not related to conflict, violence, or politics. (Confidence: {confidence:.2f}%)</span>"
    except Exception as e:
        return handle_error_message(e)
|
|
|
def multilabel_classification(text):
    """Score *text* against every entry of ``multi_class_names`` and return HTML.

    The text is tokenized (truncated and padded), moved to ``device`` and passed
    through ``multi_clf_model``; a sigmoid turns each logit into an independent
    score. Labels scoring at least 0.5 are shown in green and the rest in red,
    each with its score as a percentage, joined by " / ". If the number of
    scores differs from the number of class names, a plain error string is
    returned. Any exception is converted to an HTML error message by
    ``handle_error_message``.
    """
    try:
        encoded = multi_clf_tokenizer(text, return_tensors='pt', truncation=True, padding=True).to(device)
        with torch.no_grad():
            model_output = multi_clf_model(**encoded)

        scores = torch.sigmoid(model_output.logits).squeeze().tolist()
        if len(scores) != len(multi_class_names):
            return f"Error: Number of predicted classes ({len(scores)}) does not match number of class names ({len(multi_class_names)})."

        rendered = []
        for name, score in zip(multi_class_names, scores):
            colour = 'green' if score >= 0.5 else 'red'
            rendered.append(
                f"<span style='color: {colour}; font-weight: bold;'>{name} (Confidence: {score * 100:.2f}%)</span>"
            )
        return " / ".join(rendered)
    except Exception as e:
        return handle_error_message(e)
|
|
|
|
|
def chatbot(task, text=None, context=None, question=None):
    """Route a request to the handler for *task*, or return a prompt for missing input.

    "Question Answering" needs both *context* and *question*. The other three
    tasks need *text*. If a required input is empty or missing, a plain-text
    reminder is returned instead of running a model. An unrecognised task
    yields "Please select a valid task.".
    """
    if task == "Question Answering":
        if not (context and question):
            return "Please provide both context and question for the Question Answering task."
        return question_answering(context, question)

    if task == "Named Entity Recognition":
        return named_entity_recognition(text) if text else "Please provide text for the Named Entity Recognition task."

    if task == "Text Classification":
        return text_classification(text) if text else "Please provide text for the Text Classification task."

    if task == "Multilabel Classification":
        return multilabel_classification(text) if text else "Please provide text for the Multilabel Classification task."

    return "Please select a valid task."
|
|
|
css = """ |
|
body { |
|
background-color: #f0f8ff; |
|
font-family: 'Helvetica Neue', Helvetica, Arial, sans-serif; |
|
} |
|
|
|
h1 { |
|
color: #2e8b57; |
|
text-align: center; |
|
font-size: 2em; |
|
} |
|
|
|
h2 { |
|
color: #ff8c00; |
|
text-align: center; |
|
font-size: 1.5em; |
|
} |
|
|
|
.gradio-container { |
|
max-width: 100%; |
|
margin: 10px auto; |
|
padding: 10px; |
|
background-color: #ffffff; |
|
border-radius: 10px; |
|
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.1); |
|
} |
|
|
|
.gr-input, .gr-output { |
|
background-color: #ffffff; |
|
border: 1px solid #ddd; |
|
border-radius: 5px; |
|
padding: 10px; |
|
font-size: 1em; |
|
} |
|
|
|
.gr-title { |
|
font-size: 1.5em; |
|
font-weight: bold; |
|
color: #2e8b57; |
|
margin-bottom: 10px; |
|
text-align: center; |
|
} |
|
|
|
.gr-description { |
|
font-size: 1.2em; |
|
color: #ff8c00; |
|
margin-bottom: 10px; |
|
text-align: center; |
|
} |
|
|
|
.header { |
|
display: flex; |
|
justify-content: center; |
|
align-items: center; |
|
padding: 10px; |
|
flex-wrap: wrap; |
|
} |
|
|
|
.header-title-center a { |
|
font-size: 4em; /* Increased font size */ |
|
font-weight: bold; /* Made text bold */ |
|
color: darkorange; /* Darker orange color */ |
|
text-align: center; |
|
display: block; |
|
} |
|
|
|
.gr-button { |
|
background-color: #ff8c00; |
|
color: white; |
|
border: none; |
|
padding: 10px 20px; |
|
font-size: 1em; |
|
border-radius: 5px; |
|
cursor: pointer; |
|
} |
|
|
|
.gr-button:hover { |
|
background-color: #ff4500; |
|
} |
|
|
|
.footer { |
|
text-align: center; |
|
margin-top: 10px; |
|
font-size: 0.9em; /* Updated font size */ |
|
color: #666; |
|
width: 100%; |
|
} |
|
|
|
.footer a { |
|
color: #2e8b57; |
|
font-weight: bold; |
|
text-decoration: none; |
|
} |
|
|
|
.footer a:hover { |
|
text-decoration: underline; |
|
} |
|
""" |
|
|
|
# Build the single-page UI: a task selector, task-dependent inputs, an HTML
# output area and a footer with project links.
with gr.Blocks(css=css) as demo:
    # Title banner linking to the ConfliBERT project page.
    with gr.Row(elem_id="header"):
        gr.Markdown(
            "<div class='header-title-center'><a href='https://eventdata.utdallas.edu/conflibert/'>ConfliBERT</a></div>",
            elem_id="header-title-center",
        )

    gr.Markdown("<span style='color: black;'>Select a task and provide the necessary inputs:</span>")

    task = gr.Dropdown(
        choices=[
            "Question Answering",
            "Named Entity Recognition",
            "Text Classification",
            "Multilabel Classification",
        ],
        label="Select Task",
    )

    # All three inputs exist up front; only the ones relevant to the selected
    # task are visible. The single text box is the initial default.
    with gr.Row():
        text_input = gr.Textbox(lines=5, placeholder="Enter the text here...", label="Text")
        context_input = gr.Textbox(
            lines=5, placeholder="Enter the context here...", label="Context", visible=False
        )
        question_input = gr.Textbox(
            lines=2, placeholder="Enter your question here...", label="Question", visible=False
        )

    output = gr.HTML(label="Output")

    def update_inputs(task):
        """Show context + question for Question Answering, otherwise the text box."""
        qa_selected = task == "Question Answering"
        return (
            gr.update(visible=not qa_selected),
            gr.update(visible=qa_selected),
            gr.update(visible=qa_selected),
        )

    task.change(
        fn=update_inputs,
        inputs=task,
        outputs=[text_input, context_input, question_input],
    )

    def chatbot_interface(task, text, context, question):
        """Forward the current UI values to ``chatbot`` and return its result."""
        return chatbot(task, text, context, question)

    submit_button = gr.Button("Submit", elem_id="gr-button")
    submit_button.click(
        fn=chatbot_interface,
        inputs=[task, text_input, context_input, question_input],
        outputs=output,
    )

    gr.Markdown("<div class='footer'><a href='https://eventdata.utdallas.edu/'>UTD Event Data</a> | <a href='https://www.utdallas.edu/'>University of Texas at Dallas</a></div>")
    gr.Markdown("<div class='footer'>Developed By: <a href='https://www.linkedin.com/in/sultan-alsarra-phd-56977a63/' target='_blank'>Sultan Alsarra</a></div>")

# Start the app and request a public share link.
demo.launch(share=True)
|
|