jacq-is committed on
Commit
913c2f6
·
verified ·
1 Parent(s): 0ae53cb

Update tasks/text.py

Browse files
Files changed (1) hide show
  1. tasks/text.py +25 -7
tasks/text.py CHANGED
@@ -7,20 +7,31 @@ import random
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
 
 
 
 
 
 
10
  router = APIRouter()
11
 
12
- DESCRIPTION = "Random Baseline"
13
  ROUTE = "/text"
14
 
15
  @router.post(ROUTE, tags=["Text Task"],
16
  description=DESCRIPTION)
 
 
 
 
 
 
 
 
17
  async def evaluate_text(request: TextEvaluationRequest):
18
  """
19
  Evaluate text classification for climate disinformation detection.
20
-
21
- Current Model: Random Baseline
22
- - Makes random predictions from the label space (0-7)
23
- - Used as a baseline for comparison
24
  """
25
  # Get space info
26
  username, space_url = get_space_info()
@@ -56,9 +67,16 @@ async def evaluate_text(request: TextEvaluationRequest):
56
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
57
  #--------------------------------------------------------------------------------------------
58
 
59
- # Make random predictions (placeholder for actual model inference)
 
 
 
 
 
 
 
 
60
  true_labels = test_dataset["label"]
61
- predictions = [random.randint(0, 7) for _ in range(len(true_labels))]
62
 
63
  #--------------------------------------------------------------------------------------------
64
  # YOUR MODEL INFERENCE STOPS HERE
 
7
  from .utils.evaluation import TextEvaluationRequest
8
  from .utils.emissions import tracker, clean_emissions_data, get_space_info
9
 
10
+ #modified: additional lib
11
+ import tensorflow as tf
12
+ from huggingface_hub import hf_hub_download
13
+ from transformers import TFElectraForSequenceClassification, ElectraTokenizer, ElectraConfig
14
+ #
15
+
16
  router = APIRouter()
17
 
18
+ DESCRIPTION = "Finetuned ELECTRA"
19
  ROUTE = "/text"
20
 
21
  @router.post(ROUTE, tags=["Text Task"],
22
  description=DESCRIPTION)
23
+
24
+ #modified: retrieve model
25
+ model_repo = "jennasparks/electra_tf"
26
+ config = ElectraConfig.from_pretrained(model_repo)
27
+ model = TFElectraForSequenceClassification.from_pretrained(model_repo)
28
+ tokenizer = ElectraTokenizer.from_pretrained("google/electra-base-discriminator")
29
+ #
30
+
31
  async def evaluate_text(request: TextEvaluationRequest):
32
  """
33
  Evaluate text classification for climate disinformation detection.
34
+ Current Model: Finetuned ELECTRA
 
 
 
35
  """
36
  # Get space info
37
  username, space_url = get_space_info()
 
67
  # Update the code below to replace the random baseline by your model inference within the inference pass where the energy consumption and emissions are tracked.
68
  #--------------------------------------------------------------------------------------------
69
 
70
+ #make predictions
71
+ predictions = []
72
+
73
+ for i in range(len(test_dataset["quote"])):
74
+ encoded_input = tokenizer(test_dataset["quote"][i], truncation=True, padding=True, return_tensors="tf")
75
+ outputs = model(encoded_input["input_ids"], attention_mask=encoded_input["attention_mask"], training=False)
76
+ predictions.append(tf.argmax(outputs.logits, axis=1))
77
+
78
+ # Get true labels
79
  true_labels = test_dataset["label"]
 
80
 
81
  #--------------------------------------------------------------------------------------------
82
  # YOUR MODEL INFERENCE STOPS HERE