kxshrx commited on
Commit
a6dcc9d
·
verified ·
1 Parent(s): c5ddea9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -13
app.py CHANGED
@@ -1,24 +1,48 @@
1
  import gradio as gr
2
- from transformers import pipeline
 
3
 
4
- # This is the key line:
5
- # It loads your private model from your *other* repository.
6
- pipe = pipeline(
7
- "text-classification",
8
- model="kxshrx/infrnce-bert-classifier"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  )
 
 
10
 
11
  def classify_log(log_text):
12
- # This function runs the classification.
13
- results = pipe(log_text, top_k=None)
14
- # We format the result into a simple dictionary.
15
- return {item['label']: item['score'] for item in results[0]}
 
 
 
 
 
 
 
16
 
17
- # This creates a simple web UI for testing and, more importantly,
18
- # an API endpoint that we can call.
19
  gr.Interface(
20
  fn=classify_log,
21
  inputs=gr.Textbox(lines=5, label="Log Entry"),
22
- outputs=gr.Label(num_top_classes=6),
23
  title="Infrnce Private Log Classifier API"
24
  ).launch()
 
1
  import gradio as gr
2
+ import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
 
5
+ # --- Configuration ---
6
+ MODEL_NAME = "distilbert-base-uncased"
7
+ NUM_LABELS = 6
8
+ MODEL_PATH = "controlled_bert_model.pth" # The name of the file you uploaded
9
+
10
+ # --- Load Tokenizer and Model ---
11
+ print("Loading tokenizer...")
12
+ tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
13
+
14
+ print("Loading model architecture...")
15
+ # First, create the model "shell"
16
+ model = AutoModelForSequenceClassification.from_pretrained(
17
+ MODEL_NAME,
18
+ num_labels=NUM_LABELS
19
+ )
20
+
21
+ print(f"Loading fine-tuned weights from {MODEL_PATH}...")
22
+ # Now, load your trained weights into the shell
23
+ model.load_state_dict(
24
+ torch.load(MODEL_PATH, map_location=torch.device("cpu"))
25
  )
26
+ model.eval() # Set model to evaluation mode
27
+ print("Model loaded successfully!")
28
 
29
  def classify_log(log_text):
30
+ """
31
+ This function runs the classification using your loaded .pth model.
32
+ """
33
+ inputs = tokenizer(log_text, return_tensors="pt", padding=True, truncation=True)
34
+ with torch.no_grad():
35
+ logits = model(**inputs).logits
36
+
37
+ scores = torch.softmax(logits, dim=1).squeeze().tolist()
38
+ # Create a dictionary of {label_name: score}
39
+ confidences = {model.config.id2label[i]: score for i, score in enumerate(scores)}
40
+ return confidences
41
 
42
+ # This creates the Gradio interface and API endpoint
 
43
  gr.Interface(
44
  fn=classify_log,
45
  inputs=gr.Textbox(lines=5, label="Log Entry"),
46
+ outputs=gr.Label(num_top_classes=6, label="Classification Results"),
47
  title="Infrnce Private Log Classifier API"
48
  ).launch()