ADK09 commited on
Commit
f54c8a8
·
1 Parent(s): 5c50bbf

Added clean_text function to remove filler words and extra spaces

Browse files
Files changed (2) hide show
  1. app.py +21 -6
  2. requirements.txt +2 -1
app.py CHANGED
@@ -2,6 +2,7 @@ from fastapi import FastAPI, HTTPException
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
 
5
 
6
  app = FastAPI()
7
 
@@ -11,6 +12,12 @@ model = AutoModelForSequenceClassification.from_pretrained("cardiffnlp/twitter-r
11
  class TextRequest(BaseModel):
12
  text: str
13
 
 
 
 
 
 
 
14
  @app.get("/")
15
  def home():
16
  return {"message": "Speak your mind emotion API is running"}
@@ -18,20 +25,28 @@ def home():
18
  @app.post("/classify-emotion")
19
  async def classify_emotion(request: TextRequest):
20
  try:
21
- text = request.text
 
 
 
 
22
 
23
- inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=512)
24
 
25
  with torch.no_grad():
26
  outputs = model(**inputs)
27
 
28
-
29
  logits = outputs.logits
30
  predicted_class_id = torch.argmax(logits, dim=-1).item()
31
- predicted_emotion = model.config.id2label[predicted_class_id]
32
 
33
- return {"predicted_emotion": predicted_emotion}
 
 
 
 
34
 
35
  except Exception as e:
36
- raise HTTPException(status_code=500, detail=str(e))
 
37
 
 
2
  from pydantic import BaseModel
3
  from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch
5
+ import re
6
 
7
  app = FastAPI()
8
 
 
12
  class TextRequest(BaseModel):
13
  text: str
14
 
15
+ def clean_text(text: str) -> str:
16
+ fillers = ["um", "uh", "like", "you know", "I mean", "sort of", "kind of", "hmm", "uhh"]
17
+ text = re.sub(r'\b(?:' + '|'.join(fillers) + r')\b', '', text, flags=re.IGNORECASE)
18
+ text = re.sub(r'\s+', ' ', text).strip()
19
+ return text
20
+
21
  @app.get("/")
22
  def home():
23
  return {"message": "Speak your mind emotion API is running"}
 
25
  @app.post("/classify-emotion")
26
  async def classify_emotion(request: TextRequest):
27
  try:
28
+ text = request.text.strip()
29
+
30
+ if not text:
31
+ raise HTTPException(status_code=400, detail="Text cannot be empty")
32
+ cleaned_text = clean_text(text)
33
 
34
+ inputs = tokenizer(cleaned_text, return_tensors="pt", truncation=True, padding=True, max_length=512)
35
 
36
  with torch.no_grad():
37
  outputs = model(**inputs)
38
 
 
39
  logits = outputs.logits
40
  predicted_class_id = torch.argmax(logits, dim=-1).item()
41
+ predicted_emotion = model.config.id2label[predicted_class_id]
42
 
43
+ return {
44
+ "original_text": text,
45
+ "cleaned_text": cleaned_text,
46
+ "predicted_emotion": predicted_emotion
47
+ }
48
 
49
  except Exception as e:
50
+ raise HTTPException(status_code=500, detail=f"Error processing text: {str(e)}")
51
+
52
 
requirements.txt CHANGED
@@ -3,4 +3,5 @@ uvicorn
3
  transformers
4
  torch
5
  httpx
6
- pytest
 
 
3
  transformers
4
  torch
5
  httpx
6
+ pytest
7
+ pydantic