import gradio as gr
import torch
import re
from transformers import AutoTokenizer, T5EncoderModel
import torch.nn as nn

# Class definition from training
class FlanT5Classifier(nn.Module):
    def __init__(self, base_model_name="google/flan-t5-base", num_labels=4):
        super().__init__()
        self.encoder = T5EncoderModel.from_pretrained(base_model_name)
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.encoder.config.d_model, num_labels)

    def forward(self, input_ids, attention_mask=None):
        encoder_outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        # Pool by taking the first token's hidden state (T5 has no [CLS] token,
        # so this is simply the representation of the first input token)
        pooled = encoder_outputs.last_hidden_state[:, 0]
        logits = self.classifier(self.dropout(pooled))
        return {"logits": logits}

# Load the tokenizer
tokenizer = AutoTokenizer.from_pretrained("pepegiallo/flan-t5-base_ner")

# Instantiate the model and resize the token embeddings to match the tokenizer vocabulary
model = FlanT5Classifier()
model.encoder.resize_token_embeddings(len(tokenizer))

# Load the trained weights
state_dict = torch.load("pytorch_model.bin", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()
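
# Note (not part of the original setup): on recent PyTorch versions,
# torch.load(..., weights_only=True) is a safer way to load a plain state dict.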

# ID-to-label mapping
id2label = {0: "LOC", 1: "ORG", 2: "PER", 3: "O"}

# Tokenizer helper functions
def custom_tokenize(text):
    # Split text into word tokens and single punctuation tokens
    return re.findall(r"\w+|[^\w\s]", text, re.UNICODE)

def custom_detokenize(tokens):
    # Rebuild a string from tokens, inserting a space only before word
    # tokens so punctuation stays attached to the preceding word
    text = ""
    for i, token in enumerate(tokens):
        if i > 0 and re.match(r"\w", token):
            text += " "
        text += token
    return text
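
# Round-trip example for the two helpers above (deterministic):
#   custom_tokenize("Hello, world!")                 -> ["Hello", ",", "world", "!"]
#   custom_detokenize(["Hello", ",", "world", "!"])  -> "Hello, world!"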

# Classification function
def classify_tokens(text):
    tokens = custom_tokenize(text)
    results = []

    # Classify one token at a time: wrap the target token in the special
    # <TSTART>/<TEND> markers (added to the tokenizer during training) and
    # ask the classifier for a label for that position
    for i in range(len(tokens)):
        wrapped = tokens[:i] + ["<TSTART>", tokens[i], "<TEND>"] + tokens[i+1:]
        prompt = "classify token in: " + custom_detokenize(wrapped)

        inputs = tokenizer(prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=128)
        with torch.no_grad():
            logits = model(**inputs)["logits"]
            pred_id = torch.argmax(logits, dim=-1).item()
            label = id2label[pred_id]

        results.append((tokens[i], label))
    return results
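
# Illustrative example (the actual labels depend on the trained weights):
#   classify_tokens("Angela Merkel lives in Berlin")
#   -> [("Angela", "PER"), ("Merkel", "PER"), ("lives", "O"),
#       ("in", "O"), ("Berlin", "LOC")]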

# Gradio UI
demo = gr.Interface(
    fn=classify_tokens,
    inputs=gr.Textbox(lines=3, placeholder="Enter a sentence..."),
    outputs=gr.HighlightedText(label="Token Classification Output"),
    title="Flan-T5 Token Classification (NER)",
    description="Classifies each token in the input text as LOC, ORG, PER, or O."
)

demo.launch()
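
# When running outside Hugging Face Spaces, demo.launch(share=True)
# can be used instead to expose a temporary public URL.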