Mridul2003 committed on
Commit
fc66fa8
·
1 Parent(s): 2bf22ab

Copied all files from local

Browse files
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Base image: official CPython 3.10.
FROM python:3.10

# All following relative paths resolve against /app inside the image.
WORKDIR /app

# Copy the dependency manifest on its own first so the install layer below
# is only rebuilt when requirements.txt changes.
COPY requirements.txt .

# --no-cache-dir keeps pip's download cache out of the image layer.
RUN pip install --no-cache-dir -r requirements.txt

# Copy the application source.
COPY . .

# Must match the --port passed to uvicorn below.
EXPOSE 3000

CMD ["uvicorn", "api:app", "--host", "0.0.0.0", "--port", "3000"]
14
+
15
+
16
+
api.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI
2
+ from pydantic import BaseModel
3
+ from model_loader import ModelLoader
4
+ from services.text_filter import TextFilterService
5
+ from services.image_ocr import ImageOCRService
6
+ from typing import Optional
7
+ from fastapi.responses import JSONResponse
8
+ import logging
9
+
10
# Emit INFO-and-above records through a StreamHandler, prefixed with a
# timestamp and the level name.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s [%(levelname)s] %(message)s",
    handlers=[logging.StreamHandler()],
)

logger = logging.getLogger(__name__)

app = FastAPI()
logger.info("Starting FastAPI app...")

# The loader and both services are built once, at import time, and the same
# instances are reused by every request handled by this module.
model_loader = ModelLoader()
logger.info("ModelLoader initialized.")

text_filter_service = TextFilterService(model_loader)
logger.info("TextFilterService initialized.")

image_ocr_service = ImageOCRService()
# Message corrected to name the service, matching the two messages above.
logger.info("ImageOCRService initialized.")
31
+
32
+
33
# Request body for POST /filtercomment. Both fields are optional at the
# schema level; the endpoint itself rejects a request that supplies neither.
class InputData(BaseModel):
    # Raw comment text to classify.
    text: Optional[str] = None
    # URL of an image whose text should be extracted and classified.
    image_url: Optional[str] = None
36
+
37
@app.post("/filtercomment")
async def filter_comment(input_data: InputData):
    """Classify a comment supplied either as text or as an image URL.

    ``image_url`` takes precedence: when it is set, the text is extracted
    from the image and ``text`` is ignored. Otherwise ``text`` is used.

    Returns:
        The dict returned by ``text_filter_service.process_text`` with an
        added ``"extracted_text"`` key, or a ``JSONResponse`` carrying an
        ``"error"`` message with status 400 (no input given, or image
        processing failed) or 500 (text filtering failed).
    """
    # NOTE(review): extract_text and process_text are called synchronously
    # inside this coroutine; if they block, other requests on the same event
    # loop wait. Confirm whether offloading to a worker thread is wanted.
    logger.info("Received request: %s", input_data)

    if input_data.image_url:
        # Case 1: extract the text from the image.
        logger.info("Image URL provided: %s", input_data.image_url)
        try:
            logger.info("Fetching image from URL...")
            final_text = image_ocr_service.extract_text(input_data.image_url)
        except Exception as e:
            # logger.exception keeps the traceback, consistent with the
            # text-filtering failure branch below.
            logger.exception("Image processing failed: %s", e)
            return JSONResponse(
                status_code=400,
                content={"error": f"Image processing failed: {e}"},
            )
        else:
            logger.info("Generated text: %s", final_text)
    elif input_data.text:
        # Case 2: use the text as given.
        logger.info("Text input provided.")
        final_text = input_data.text
    else:
        logger.warning("No input provided.")
        return JSONResponse(
            status_code=400,
            content={"error": "Either 'text' or 'image_url' must be provided."},
        )

    try:
        logger.info("Processing text through TextFilterService...")
        results = text_filter_service.process_text(final_text)
        # Echo back the text that was actually scored (OCR output or input).
        results["extracted_text"] = final_text
        logger.info("Text filtering complete. Results: %s", results)
        return results
    except Exception as e:
        logger.exception("Text filtering failed.")
        return JSONResponse(
            status_code=500,
            content={"error": f"Text filtering failed: {e}"},
        )
71
+
72
# Script entry point: serve the app in-process instead of through the
# uvicorn command line.
if __name__ == "__main__":
    import uvicorn

    logger.info("Starting Uvicorn server...")
    # Bind on all interfaces, port 3000.
    uvicorn.run(app, port=3000, host="0.0.0.0")
model_loader.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
2
+ import torch
3
class ModelLoader:
    """Load and hold the two classifiers and their tokenizers.

    After construction the instance exposes ``device``, ``hf_model``,
    ``hf_tokenizer``, ``identity_model`` and ``identity_tokenizer``.
    """

    def __init__(self):
        # Use the GPU when torch reports one, otherwise fall back to CPU.
        gpu_available = torch.cuda.is_available()
        self.device = torch.device("cuda" if gpu_available else "cpu")
        self._load_models()

    def _load_models(self):
        """Fetch both models and tokenizers and move the models to ``self.device``."""
        toxic_repo = "unitary/toxic-bert"
        identity_repo = "Mridul2003/identity-hate-detector"

        toxic_model = AutoModelForSequenceClassification.from_pretrained(toxic_repo)
        self.hf_model = toxic_model.to(self.device)
        self.hf_tokenizer = AutoTokenizer.from_pretrained(toxic_repo)

        identity_model = AutoModelForSequenceClassification.from_pretrained(identity_repo)
        self.identity_model = identity_model.to(self.device)

        # Best effort: if the identity repo's tokenizer cannot be loaded,
        # reuse the toxic-bert tokenizer instead of failing.
        try:
            identity_tokenizer = AutoTokenizer.from_pretrained(identity_repo)
        except Exception:
            identity_tokenizer = self.hf_tokenizer
        self.identity_tokenizer = identity_tokenizer
services/__pycache__/text_filter.cpython-310.pyc ADDED
Binary file (2.24 kB). View file
 
services/image_ocr.py ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ from io import BytesIO
3
+ import requests
4
+ from transformers import pipeline
5
+
6
class ImageOCRService:
    """Extract text from an image URL with an image-text-to-text pipeline."""

    # Seconds to wait on the image host; without a timeout requests.get can
    # block indefinitely on an unresponsive server.
    DOWNLOAD_TIMEOUT = 10

    def __init__(self):
        self.pipe = pipeline("image-text-to-text", model="ds4sd/SmolDocling-256M-preview")

    def extract_text(self, image_url: str) -> str:
        """Download the image at ``image_url`` and return the text the pipeline generates.

        Args:
            image_url: URL of the image to read.

        Returns:
            ``generated_text`` of the first pipeline result, or ``""`` when
            the pipeline returns nothing.

        Raises:
            requests.RequestException: the download failed, timed out, or
                the server answered with an error status.
            OSError: the downloaded bytes could not be opened as an image.
        """
        # NOTE(review): the URL is fetched as given. If it can come from an
        # untrusted client, consider restricting schemes/hosts to avoid
        # server-side request forgery.
        response = requests.get(image_url, timeout=self.DOWNLOAD_TIMEOUT)
        # Fail on 4xx/5xx here rather than feeding an error page to PIL.
        response.raise_for_status()
        image = Image.open(BytesIO(response.content)).convert("RGB")

        # NOTE(review): the image is passed both as a decoded object and as a
        # URL inside the chat message — confirm both are needed.
        messages = [{
            "role": "user",
            "content": [
                {"type": "image", "url": image_url},
                {"type": "text", "text": "extract text from image"},
            ],
        }]
        result = self.pipe(image, text=messages)
        return result[0]['generated_text'] if result else ""
services/text_filter.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from model_loader import ModelLoader
3
+
4
class TextFilterService:
    """Score a text with the toxicity and identity-hate classifiers of a model loader."""

    # Toxicity-model labels that make a text unsafe. "threat" is included so
    # that a text scored as a threat is not reported as safe.
    TOXIC_LABELS = ("toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate")
    # A score at or above this value marks the text as unsafe.
    SAFE_THRESHOLD = 0.5

    def __init__(self, model_loader: "ModelLoader"):
        # The loader supplies hf_model, hf_tokenizer, identity_model,
        # identity_tokenizer and device.
        self.model_loader = model_loader

    def process_text(self, final_text: str) -> dict:
        """Run both classifiers on ``final_text`` and return their scores.

        Args:
            final_text: The text to classify.

        Returns:
            A dict mapping each toxicity label (from the model's
            ``config.id2label``) to its sigmoid score, plus
            ``"identity_hate_custom"``, ``"not_identity_hate_custom"`` and a
            boolean ``"safe"`` that is true only when every label in
            ``TOXIC_LABELS`` and the custom identity-hate score are below
            ``SAFE_THRESHOLD``.
        """
        loader = self.model_loader
        device = loader.device

        # Toxic-BERT inference: one independent sigmoid score per label.
        hf_inputs = loader.hf_tokenizer(final_text, return_tensors="pt", padding=True, truncation=True)
        hf_inputs = {name: tensor.to(device) for name, tensor in hf_inputs.items()}
        with torch.no_grad():
            hf_outputs = loader.hf_model(**hf_inputs)
        hf_probs = torch.sigmoid(hf_outputs.logits)[0]
        hf_labels = loader.hf_model.config.id2label
        results = {hf_labels.get(i, f"Label {i}"): float(prob) for i, prob in enumerate(hf_probs)}

        # Identity hate classifier. token_type_ids is dropped before the
        # forward pass, presumably because this model does not accept it.
        identity_inputs = loader.identity_tokenizer(final_text, return_tensors="pt", padding=True, truncation=True)
        identity_inputs.pop("token_type_ids", None)
        identity_inputs = {name: tensor.to(device) for name, tensor in identity_inputs.items()}
        with torch.no_grad():
            identity_outputs = loader.identity_model(**identity_inputs)
        # NOTE(review): sigmoid scores each of the two logits independently;
        # if the model was trained as a 2-class softmax classifier, softmax
        # would be the matching activation — confirm.
        identity_probs = torch.sigmoid(identity_outputs.logits)
        identity_prob = identity_probs[0][1].item()
        not_identity_prob = identity_probs[0][0].item()

        results["identity_hate_custom"] = identity_prob
        results["not_identity_hate_custom"] = not_identity_prob

        # A label missing from the model's output counts as 0 (not triggered).
        results["safe"] = (
            all(results.get(label, 0) < self.SAFE_THRESHOLD for label in self.TOXIC_LABELS)
            and identity_prob < self.SAFE_THRESHOLD
        )

        return results