Quintino Fernandes committed
Commit a2682b3 · Parent(s): e3bef92

All models and query

Dockerfile CHANGED
@@ -1,13 +1,32 @@
  FROM python:3.12-slim
 
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y \
+     build-essential \
+     libpq-dev \
+     curl \
+     && rm -rf /var/lib/apt/lists/*
+
+ # Add a non-root user
  RUN useradd -m -u 1000 user
  USER user
  ENV PATH="/home/user/.local/bin:$PATH"
 
  WORKDIR /app
 
+ # Copy and install Python dependencies
  COPY --chown=user ./requirements.txt requirements.txt
- RUN pip install --no-cache-dir --upgrade -r requirements.txt
+ RUN pip install --no-cache-dir --upgrade pip && \
+     pip install --no-cache-dir -r requirements.txt
+
+ # Download the spaCy model
+ RUN python -m spacy download pt_core_news_md
 
+ # Copy application code
  COPY --chown=user . /app
+
+ # Expose the application port
+ EXPOSE 7860
+
+ # Run the application
  CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "7860"]
app.py CHANGED
@@ -1,25 +1,42 @@
- # main.py
- import logging
  from fastapi import FastAPI, HTTPException, BackgroundTasks
  from fastapi.middleware.cors import CORSMiddleware
  from pydantic import BaseModel
- from typing import Dict, Optional
+ from typing import Dict, Optional, List
  import uuid
- from datetime import datetime, timedelta
- import asyncio
- import random
- # from sentence_transformers import SentenceTransformer
- # from transformers import T5Tokenizer, T5ForConditionalGeneration
- # from LexRank import degree_centrality_scores
- # import torch
- # import nltk
- # import spacy
- # from psycopg2 import sql
+ from datetime import datetime
+ from contextlib import asynccontextmanager
+
+ from models.embedding import EmbeddingModel
+ from models.summarization import SummarizationModel
+ from models.nlp import NLPModel
+ from database.query import DatabaseService
+ from database.query_processor import QueryProcessor
 
 
- app = FastAPI(title="Kairos News API", version="1.0")
-
- # Enable CORS
+ # Initialize models
+ embedding_model = None
+ summarization_model = None
+ nlp_model = None
+ db_service = None
+
+ @asynccontextmanager
+ async def lifespan(app: FastAPI):
+     # Load models when app starts
+     global embedding_model, summarization_model, nlp_model, db_service
+     embedding_model = EmbeddingModel()
+     summarization_model = SummarizationModel()
+     nlp_model = NLPModel()
+     db_service = DatabaseService()
+     yield
+     # Clean up when app stops
+     await db_service.close()
+
+ app = FastAPI(
+     title="Kairos News API",
+     version="1.0",
+     lifespan=lifespan
+ )
+
  app.add_middleware(
      CORSMiddleware,
      allow_origins=["*"],
@@ -27,13 +44,24 @@ app.add_middleware(
      allow_headers=["*"],
  )
 
- # In-memory database simulation
+ # In-memory job storage
  jobs_db: Dict[str, Dict] = {}
 
  class PostRequest(BaseModel):
      query: str
+     topic: Optional[str] = None
+     start_date: Optional[str] = None  # Format: "YYYY-MM-DD"
+     end_date: Optional[str] = None    # Format: "YYYY-MM-DD"
+
+ class ArticleResult(BaseModel):
+     url: str
+     content: str
+     distance: float
+     date: str
      topic: str
-     date: str  # Format: "YYYY/MM to YYYY/MM"
+
+ class SummaryResult(BaseModel):
+     summary: str
 
  class JobStatus(BaseModel):
      id: str
@@ -45,10 +73,8 @@ class JobStatus(BaseModel):
 
  @app.post("/index", response_model=JobStatus)
  async def create_job(request: PostRequest, background_tasks: BackgroundTasks):
-     """Create a new processing job"""
      job_id = str(uuid.uuid4())
 
-     # Store initial job data
      jobs_db[job_id] = {
          "status": "processing",
          "created_at": datetime.now(),
@@ -56,9 +82,16 @@ async def create_job(request: PostRequest, background_tasks: BackgroundTasks):
          "request": request.dict(),
          "result": None
      }
-     logging.info(f"Job {job_id} created with request: {request.query}")
-     # Simulate background processing
-     background_tasks.add_task(process_job, job_id)
+
+     background_tasks.add_task(
+         process_job,
+         job_id,
+         request,
+         embedding_model,
+         summarization_model,
+         nlp_model,
+         db_service
+     )
 
      return {
          "id": job_id,
@@ -71,48 +104,42 @@ async def create_job(request: PostRequest, background_tasks: BackgroundTasks):
 
  @app.get("/loading", response_model=JobStatus)
  async def get_job_status(id: str):
-     """Check job status with timeout simulation"""
      if id not in jobs_db:
          raise HTTPException(status_code=404, detail="Job not found")
 
-     job = jobs_db[id]
-
-     # Simulate random processing time (3-25 seconds)
-     elapsed = datetime.now() - job["created_at"]
-     if elapsed < timedelta(seconds=3):
-         await asyncio.sleep(1)  # Artificial delay
-
-     # 10% chance of failure for demonstration
-     if random.random() < 0.1 and job["status"] == "processing":
-         job["status"] = "failed"
-         job["result"] = {"error": "Random processing failure"}
-
-     return {
-         "id": id,
-         "status": job["status"],
-         "created_at": job["created_at"],
-         "completed_at": job["completed_at"],
-         "request": job["request"],
-         "result": job["result"]
-     }
+     return {"id": id, **jobs_db[id]}  # stored job dict has no "id" key; merge it in so JobStatus validates
 
- async def process_job(job_id: str):
-     """Background task to simulate processing"""
-     await asyncio.sleep(random.uniform(3, 10))  # Random processing time
-
-     if job_id in jobs_db:
-         jobs_db[job_id]["status"] = "completed"
-         jobs_db[job_id]["completed_at"] = datetime.now()
-         jobs_db[job_id]["result"] = {
-             "query": jobs_db[job_id]["request"]["query"],
-             "topic": jobs_db[job_id]["request"]["topic"],
-             "date_range": jobs_db[job_id]["request"]["date"],
-             "analysis": f"Processed results for {jobs_db[job_id]['request']['query']}",
-             "sources": ["Source A", "Source B", "Source C"],
-             "summary": "This is a generated summary based on your query."
-         }
-
- @app.get("/jobs")
- async def list_jobs():
-     """Debug endpoint to view all jobs"""
-     return jobs_db
+ async def process_job(
+     job_id: str,
+     request: PostRequest,
+     embedding_model: EmbeddingModel,
+     summarization_model: SummarizationModel,
+     nlp_model: NLPModel,
+     db_service: DatabaseService
+ ):
+     try:
+         processor = QueryProcessor(
+             embedding_model=embedding_model,
+             summarization_model=summarization_model,
+             nlp_model=nlp_model,
+             db_service=db_service
+         )
+
+         result = await processor.process(
+             query=request.query,
+             topic=request.topic,
+             start_date=request.start_date,
+             end_date=request.end_date
+         )
+
+         jobs_db[job_id].update({
+             "status": "completed",
+             "completed_at": datetime.now(),
+             "result": result
+         })
+     except Exception as e:
+         jobs_db[job_id].update({
+             "status": "failed",
+             "completed_at": datetime.now(),
+             "result": {"error": str(e)}
+         })
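
Note on usage: a minimal client-side sketch of the job lifecycle these endpoints implement. POST /index returns a job id immediately, and GET /loading is polled until the background task flips the status. Only the routes and field names come from the diff; the base URL, payload values, and the requests client are assumptions for illustration.

import time
import requests  # assumed HTTP client, not part of this commit

BASE = "http://localhost:7860"  # the port uvicorn serves on above

job = requests.post(f"{BASE}/index", json={
    "query": "eleições municipais",   # illustrative values only
    "topic": "politics",
    "start_date": "2024-01-01",
    "end_date": "2024-06-30",
}).json()

while job["status"] == "processing":
    time.sleep(2)  # poll until the background task completes or fails
    job = requests.get(f"{BASE}/loading", params={"id": job["id"]}).json()

print(job["status"], job.get("result"))
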
app_antiga.py ADDED
@@ -0,0 +1,138 @@
+ # main.py
+ import logging
+ from fastapi import FastAPI, HTTPException, BackgroundTasks
+ from fastapi.middleware.cors import CORSMiddleware
+ from pydantic import BaseModel
+ from typing import Dict, Optional
+ import uuid
+ from datetime import datetime, timedelta
+ import asyncio
+ import random
+ from sentence_transformers import SentenceTransformer
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+ from models.LexRank import degree_centrality_scores
+ import torch
+ import nltk
+ import spacy
+ from psycopg2 import sql
+ from supabase import create_client
+ from supabase.lib.client_options import ClientOptions  # provides the schema option used below
+
+
+ app = FastAPI(title="Kairos News API", version="1.0")
+
+ # Enable CORS
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_methods=["*"],
+     allow_headers=["*"],
+ )
+
+ # Database connection setup
+ url = "https://daxquaudqidyeirypexa.supabase.co"
+ key = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImRheHF1YXVkcWlkeWVpcnlwZXhhIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NDQzOTIzNzcsImV4cCI6MjA1OTk2ODM3N30.3qB-GfiCoqXEpbNfqV3iHiqOLr8Ex9nPVr6p9De5Hdc"
+ opts = ClientOptions().replace(schema="articles")
+ supabase = create_client(url, key, options=opts)
+
+ # Loading models
+ nlp = spacy.load("pt_core_news_md")
+ model_embedding = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
+ token_name = 'unicamp-dl/ptt5-base-portuguese-vocab'
+ model_name = 'recogna-nlp/ptt5-base-summ'
+ tokenizer = T5Tokenizer.from_pretrained(token_name)
+ model_summ = T5ForConditionalGeneration.from_pretrained(model_name).to('cuda')
+
+ # In-memory database simulation
+ jobs_db: Dict[str, Dict] = {}
+
+ class PostRequest(BaseModel):
+     query: str
+     topic: str
+     start_date: str  # Format: "YYYY/MM"
+     end_date: str    # Format: "YYYY/MM"
+
+ class JobStatus(BaseModel):
+     id: str
+     status: str  # "processing", "completed", "failed"
+     created_at: datetime
+     completed_at: Optional[datetime]
+     request: PostRequest
+     result: Optional[Dict]
+
+ @app.post("/index", response_model=JobStatus)
+ async def create_job(request: PostRequest, background_tasks: BackgroundTasks):
+     """Create a new processing job"""
+     job_id = str(uuid.uuid4())
+
+     # Store initial job data
+     jobs_db[job_id] = {
+         "status": "processing",
+         "created_at": datetime.now(),
+         "completed_at": None,
+         "request": request.dict(),
+         "result": None
+     }
+
+     logging.info(f"Job {job_id} created with request: {request.query}")
+     # Simulate background processing
+     background_tasks.add_task(process_job, job_id)
+
+     return {
+         "id": job_id,
+         "status": "processing",
+         "created_at": jobs_db[job_id]["created_at"],
+         "completed_at": None,
+         "request": request,
+         "result": None
+     }
+
+
+ @app.get("/loading", response_model=JobStatus)
+ async def get_job_status(id: str):
+     """Check job status with timeout simulation"""
+     if id not in jobs_db:
+         raise HTTPException(status_code=404, detail="Job not found")
+
+     job = jobs_db[id]
+
+     # Simulate random processing time (3-25 seconds)
+     elapsed = datetime.now() - job["created_at"]
+     if elapsed < timedelta(seconds=3):
+         await asyncio.sleep(1)  # Artificial delay
+
+     # 10% chance of failure for demonstration
+     if random.random() < 0.1 and job["status"] == "processing":
+         job["status"] = "failed"
+         job["result"] = {"error": "Random processing failure"}
+
+     return {
+         "id": id,
+         "status": job["status"],
+         "created_at": job["created_at"],
+         "completed_at": job["completed_at"],
+         "request": job["request"],
+         "result": job["result"]
+     }
+
+ async def process_job(job_id: str):
+     """Background task to simulate processing"""
+     await asyncio.sleep(random.uniform(3, 10))  # Random processing time
+
+     if job_id in jobs_db:
+         jobs_db[job_id]["status"] = "completed"
+         jobs_db[job_id]["completed_at"] = datetime.now()
+         jobs_db[job_id]["result"] = {
+             "query": jobs_db[job_id]["request"]["query"],
+             "topic": jobs_db[job_id]["request"]["topic"],
+             "date_range": f"{jobs_db[job_id]['request']['start_date']} to {jobs_db[job_id]['request']['end_date']}",
+             "analysis": f"Processed results for {jobs_db[job_id]['request']['query']}",
+             "sources": ["Source A", "Source B", "Source C"],
+             "summary": "This is a generated summary based on your query."
+         }
+
+ @app.get("/jobs")
+ async def list_jobs():
+     """Debug endpoint to view all jobs"""
+     return jobs_db
database/query.py ADDED
@@ -0,0 +1,90 @@
+ import os
+ from typing import List, Dict, Optional, Any
+ import vecs
+ from datetime import datetime
+
+ class DatabaseService:
+     def __init__(self):
+         # Connection parameters
+         self.DB_HOST = os.getenv("SUPABASE_HOST", "db.daxquaudqidyeirypexa.supabase.co")
+         self.DB_PORT = os.getenv("DB_PORT", "5432")
+         self.DB_NAME = os.getenv("DB_NAME", "postgres")
+         self.DB_USER = os.getenv("DB_USER", "postgres")
+         self.DB_PASSWORD = os.getenv("DB_PASSWORD", "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJpc3MiOiJzdXBhYmFzZSIsInJlZiI6ImRheHF1YXVkcWlkeWVpcnlwZXhhIiwicm9sZSI6ImFub24iLCJpYXQiOjE3NDQzOTIzNzcsImV4cCI6MjA1OTk2ODM3N30.3qB-GfiCoqXEpbNfqV3iHiqOLr8Ex9nPVr6p9De5Hdc")
+
+         # Create vecs client
+         self.vx = vecs.create_client(
+             f"postgresql://{self.DB_USER}:{self.DB_PASSWORD}@{self.DB_HOST}:{self.DB_PORT}/{self.DB_NAME}"
+         )
+
+         # Get or create the collection
+         self.articles = self.vx.get_or_create_collection(
+             name="articles",
+             dimension=384  # Match the embedding model's output dimension
+         )
+
+     async def semantic_search(
+         self,
+         query_embedding: List[float],
+         start_date: Optional[datetime] = None,
+         end_date: Optional[datetime] = None,
+         topic: Optional[str] = None,
+         entities: Optional[List[str]] = None,
+         limit: int = 10
+     ) -> List[Dict[str, Any]]:
+         try:
+             # Base vector search filters
+             filters = self._build_filters(start_date, end_date, topic)
+
+             # Add entity filter if entities are provided
+             if entities:
+                 filters["entities"] = {"$in": entities}
+
+             results = self.articles.query(
+                 data=query_embedding,
+                 limit=limit,
+                 filters=filters,
+                 measure="cosine_distance",  # or "inner_product", "l2_distance"
+                 include_value=True, include_metadata=True  # yields (id, distance, metadata) tuples
+             )
+
+             # Format results with metadata
+             formatted_results = []
+             for article_id, distance, metadata in results:
+                 formatted_results.append({
+                     "id": article_id,
+                     "url": metadata.get("url"),
+                     "content": metadata.get("content"),
+                     "date": metadata.get("date"),
+                     "topic": metadata.get("topic"),
+                     "distance": float(distance),
+                     "similarity": 1 - float(distance)  # Convert to similarity score
+                 })
+
+             return formatted_results
+
+         except Exception as e:
+             print(f"Vector search error: {e}")
+             return []
+
+     def _build_filters(
+         self,
+         start_date: Optional[datetime],
+         end_date: Optional[datetime],
+         topic: Optional[str]
+     ) -> Dict[str, Any]:
+         filters = {}
+
+         if start_date and end_date:
+             filters["date"] = {
+                 "$gte": start_date.isoformat(),
+                 "$lte": end_date.isoformat()
+             }
+
+         if topic:
+             filters["topic"] = {"$eq": topic}
+
+         return filters
+
+     async def close(self):
+         self.vx.disconnect()
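
Note on usage: a hedged sketch of one end-to-end search with DatabaseService, showing the metadata filter shape _build_filters produces for vecs ($gte/$lte on date, $eq on topic, plus the $in entity clause added in semantic_search). Credentials are expected in the environment; the query text and metadata values are invented.

import asyncio
from datetime import datetime
from models.embedding import EmbeddingModel
from database.query import DatabaseService

async def main():
    db = DatabaseService()  # reads SUPABASE_HOST, DB_PASSWORD, etc.
    embedding = EmbeddingModel().encode("inflação e juros").tolist()  # sample query
    hits = await db.semantic_search(
        query_embedding=embedding,
        start_date=datetime(2024, 1, 1),
        end_date=datetime(2024, 3, 31),
        topic="economy",             # invented metadata values
        entities=["banco central"],
        limit=5,
    )
    for hit in hits:
        print(f'{hit["similarity"]:.3f}  {hit["url"]}')
    await db.close()

asyncio.run(main())
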
database/query_processor.py ADDED
@@ -0,0 +1,67 @@
+ from typing import List, Dict, Any, Optional
+ import numpy as np
+ from datetime import datetime
+ from models.LexRank import degree_centrality_scores
+
+ class QueryProcessor:
+     def __init__(self, embedding_model, summarization_model, nlp_model, db_service):
+         self.embedding_model = embedding_model
+         self.summarization_model = summarization_model
+         self.nlp_model = nlp_model
+         self.db_service = db_service
+
+     async def process(
+         self,
+         query: str,
+         topic: Optional[str] = None,
+         start_date: Optional[str] = None,
+         end_date: Optional[str] = None
+     ) -> Dict[str, Any]:
+         # Convert string dates to datetime objects
+         start_dt = datetime.strptime(start_date, "%Y-%m-%d") if start_date else None
+         end_dt = datetime.strptime(end_date, "%Y-%m-%d") if end_date else None
+
+         # Get query embedding
+         query_embedding = self.embedding_model.encode(query).tolist()
+
+         # Get entities from the query
+         entities = [text for text, _ in self.nlp_model.extract_entities(query)]
+
+         # Semantic search with entities
+         articles = await self.db_service.semantic_search(
+             query_embedding=query_embedding,
+             start_date=start_dt,
+             end_date=end_dt,
+             topic=topic,
+             entities=entities  # Pass entities to the search
+         )
+
+         if not articles:
+             return {"error": "No articles found matching the criteria"}
+
+         # Split the retrieved articles into sentences
+         contents = [article["content"] for article in articles]
+         sentences = []
+         for content in contents:
+             sentences.extend(self.nlp_model.tokenize_sentences(content))
+
+         # Generate summary
+         if sentences:
+             embeddings = self.embedding_model.encode(sentences)
+             similarity_matrix = np.inner(embeddings, embeddings)
+             centrality_scores = degree_centrality_scores(similarity_matrix, threshold=None)
+
+             top_indices = np.argsort(-centrality_scores)[0:10]
+             key_sentences = [sentences[idx].strip() for idx in top_indices]
+             combined_text = ' '.join(key_sentences)
+
+             summary = self.summarization_model.summarize(combined_text)
+         else:
+             key_sentences = []
+             summary = "No content available for summarization"
+
+         return {
+             "summary": summary,
+             "key_sentences": key_sentences,
+             "articles": articles
+         }
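
Note on the centrality step: a toy illustration of how process ranks sentences. They are embedded, a similarity matrix is built with np.inner, and degree_centrality_scores (from the renamed models/LexRank.py) picks the most central ones to feed the summarizer. The sample sentences are invented.

import numpy as np
from models.LexRank import degree_centrality_scores
from models.embedding import EmbeddingModel

sentences = [
    "O governo anunciou um novo pacote fiscal.",  # invented examples
    "O pacote fiscal prevê cortes de gastos.",
    "O time venceu a partida de ontem.",
]
embeddings = EmbeddingModel().encode(sentences)
scores = degree_centrality_scores(np.inner(embeddings, embeddings), threshold=None)
top = np.argsort(-scores)[:2]  # indices of the two most central sentences
print([sentences[i] for i in top])
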
LexRank.py → models/LexRank.py RENAMED
File without changes
models/embedding.py ADDED
@@ -0,0 +1,10 @@
+ from sentence_transformers import SentenceTransformer
+ import torch
+
+ class EmbeddingModel:
+     def __init__(self):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.model = SentenceTransformer("paraphrase-multilingual-MiniLM-L12-v2")
+
+     def encode(self, text: str):
+         return self.model.encode(text, device=self.device)
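
Note: for this SentenceTransformer the vectors are 384-dimensional, which is what dimension=384 on the articles collection in database/query.py expects. A quick check, with an invented input:

from models.embedding import EmbeddingModel

vec = EmbeddingModel().encode("um exemplo de consulta")  # invented text
print(vec.shape)  # (384,); must match the vecs collection dimension
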
models/nlp.py ADDED
@@ -0,0 +1,14 @@
+ import spacy
+ import nltk
+
+ class NLPModel:
+     def __init__(self):
+         self.nlp = spacy.load("pt_core_news_md")
+         nltk.download('punkt')
+
+     def extract_entities(self, text: str):
+         doc = self.nlp(text)
+         return [(ent.text.lower(), ent.label_) for ent in doc.ents]
+
+     def tokenize_sentences(self, text: str):
+         return nltk.sent_tokenize(text, language="portuguese")  # use the Portuguese Punkt model
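
Note: a quick smoke test of the two NLPModel helpers, assuming pt_core_news_md is installed (the Dockerfile above downloads it at build time). The sentence is invented.

from models.nlp import NLPModel

nlp = NLPModel()
text = "O presidente visitou Lisboa. A comitiva retorna amanhã."  # invented
print(nlp.extract_entities(text))    # lowercased (text, label) pairs
print(nlp.tokenize_sentences(text))  # two sentences
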
models/summarization.py ADDED
@@ -0,0 +1,27 @@
+ from transformers import T5Tokenizer, T5ForConditionalGeneration
+ import torch
+
+ class SummarizationModel:
+     def __init__(self):
+         self.device = "cuda" if torch.cuda.is_available() else "cpu"
+         self.tokenizer = T5Tokenizer.from_pretrained('unicamp-dl/ptt5-base-portuguese-vocab')
+         self.model = T5ForConditionalGeneration.from_pretrained('recogna-nlp/ptt5-base-summ').to(self.device)
+
+     def summarize(self, text: str, max_length: int = 256, min_length: int = 128) -> str:
+         inputs = self.tokenizer.encode(
+             text,
+             max_length=512,
+             truncation=True,
+             return_tensors='pt'
+         ).to(self.device)
+
+         summary_ids = self.model.generate(
+             inputs,
+             max_length=max_length,
+             min_length=min_length,
+             num_beams=5,
+             no_repeat_ngram_size=3,
+             early_stopping=False
+         )
+
+         return self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
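
Note: a minimal smoke test for SummarizationModel. The model names come from the diff; the input text is a placeholder, and a GPU is optional since the device is auto-detected.

from models.summarization import SummarizationModel

model = SummarizationModel()
texto = ("O Banco Central manteve a taxa básica de juros, citando a "
         "desaceleração da inflação e a incerteza fiscal.")  # placeholder text
print(model.summarize(texto, max_length=96, min_length=32))
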
requirements.txt CHANGED
@@ -1,12 +1,13 @@
  fastapi
  uvicorn[standard]
  logging
- # transformers
- # torch
- # sentence_transformers
- # nltk
- # spacy
- # numpy
- # pandas
- # scipy
- # psycopg2
+ transformers
+ torch
+ sentence_transformers
+ nltk
+ spacy
+ numpy
+ pandas
+ scipy
+ psycopg2
+ vecs