Dan Walsh committed
Commit 9e707a5 · 1 Parent(s): 124b5b5

Adding additional improvements & refinements

__pycache__/main.cpython-311.pyc CHANGED
Binary files a/__pycache__/main.cpython-311.pyc and b/__pycache__/main.cpython-311.pyc differ
 
app/api/__pycache__/routes.cpython-311.pyc CHANGED
Binary files a/app/api/__pycache__/routes.cpython-311.pyc and b/app/api/__pycache__/routes.cpython-311.pyc differ
 
app/api/routes.py CHANGED
@@ -5,7 +5,8 @@ from app.services.summariser import SummariserService
 from app.services.url_extractor import URLExtractorService
 from app.services.cache import hash_text, get_cached_summary, cache_summary
 
-router = APIRouter()
+router = APIRouter(prefix="/api")
+summariser_service = SummariserService()
 
 class TextSummaryRequest(BaseModel):
     text: str = Field(..., min_length=10, description="The text to summarise")
@@ -45,8 +46,7 @@ async def summarise_text(request: TextSummaryRequest):
             return cached_summary
 
         # If not in cache, generate summary
-        summariser = SummariserService()
-        summary = summariser.summarise(
+        result = summariser_service.summarise(
            text=request.text,
            max_length=request.max_length,
            min_length=request.min_length,
@@ -54,23 +54,6 @@ async def summarise_text(request: TextSummaryRequest):
            temperature=request.temperature
        )
 
-        result = {
-            "original_text_length": len(request.text),
-            "summary": summary,
-            "summary_length": len(summary),
-            "source_type": "text"
-        }
-
-        # Cache the result
-        cache_summary(
-            text_hash,
-            request.max_length,
-            request.min_length,
-            request.do_sample,
-            request.temperature,
-            result
-        )
-
        return result
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
@@ -86,8 +69,7 @@ async def summarise_url(request: URLSummaryRequest):
            raise HTTPException(status_code=422, detail="Could not extract sufficient content from the URL")
 
        # Summarise the extracted content
-        summariser = SummariserService()
-        summary = summariser.summarise(
+        result = summariser_service.summarise(
            text=content,
            max_length=request.max_length,
            min_length=request.min_length,
@@ -97,8 +79,8 @@ async def summarise_url(request: URLSummaryRequest):
 
        return {
            "original_text_length": len(content),
-            "summary": summary,
-            "summary_length": len(summary),
+            "summary": result["summary"],
+            "summary_length": len(result["summary"]),
            "source_type": "url",
            "source_url": str(request.url)
        }
@@ -106,3 +88,8 @@ async def summarise_url(request: URLSummaryRequest):
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
+
+@router.get("/status")
+async def get_status():
+    """Get the current status of the summariser service"""
+    return summariser_service.get_status()
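Note on usage: with the prefix now on the router itself and a single module-level SummariserService, the status endpoint is served at /api/status, and the text endpoint returns the dict produced by summariser_service.summarise() unchanged. A minimal client sketch follows; the localhost:8000 base URL and the /api/summarise path are assumptions (the summarise route decorators sit outside these hunks), and the sample payload is illustrative only.

    # Hypothetical client sketch; base URL and summarise path are assumed, not shown in this diff.
    import httpx

    BASE_URL = "http://localhost:8000"  # assumed local dev server

    # New endpoint added in this commit
    status = httpx.get(f"{BASE_URL}/api/status").json()
    print(status["device"], status["model_loading"])

    # Fields match TextSummaryRequest; the response is the dict returned by summariser_service.summarise()
    payload = {"text": "Some long article text. " * 40, "max_length": 250, "min_length": 100}
    result = httpx.post(f"{BASE_URL}/api/summarise", json=payload, timeout=120).json()
    print(result["summary"])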
app/services/__pycache__/summariser.cpython-311.pyc CHANGED
Binary files a/app/services/__pycache__/summariser.cpython-311.pyc and b/app/services/__pycache__/summariser.cpython-311.pyc differ
 
app/services/summariser.py CHANGED
@@ -1,19 +1,105 @@
 import numpy as np  # Import NumPy first
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+import time
+import re
 
 class SummariserService:
     def __init__(self):
-        # Initialize with a smaller model for faster loading
-        model_name = "facebook/bart-large-cnn"
+        # Status tracking
+        self.model_loading_status = {
+            "is_loading": False,
+            "step": "",
+            "progress": 0
+        }
+
+        # Consider these alternative models
+        model_options = {
+            "general": "facebook/bart-large-cnn", # Better for news articles
+            "news": "facebook/bart-large-xsum", # Better for short news summaries
+            "long_form": "google/pegasus-large", # Better for long documents
+            "literary": "t5-large" # Better for literary text
+        }
+
+        # Choose the most appropriate model - let's switch to BART for better literary summaries
+        model_name = model_options["general"] # Changed from "literary" to "general"
+
+        # Update loading status
+        self.model_loading_status["is_loading"] = True
+        self.model_loading_status["step"] = "Initializing tokenizer"
+
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
+
+        self.model_loading_status["step"] = "Loading model"
+        self.model = AutoModelForSeq2SeqLM.from_pretrained(
+            model_name,
+            force_download=False,
+            local_files_only=False
+        )
 
         # Move to GPU if available
         self.device = "cuda" if torch.cuda.is_available() else "cpu"
         self.model.to(self.device)
 
-    def summarise(self, text, max_length=150, min_length=50, do_sample=False, temperature=1.0):
+        self.model_loading_status["is_loading"] = False
+        self.model_loading_status["progress"] = 100
+
+        # Track current processing job
+        self.current_job = {
+            "in_progress": False,
+            "start_time": None,
+            "input_word_count": 0,
+            "estimated_time": 0,
+            "stage": "",
+            "progress": 0
+        }
+
+    def get_status(self):
+        """Return the current status of the summarizer service"""
+        status = {
+            "model_loading": self.model_loading_status,
+            "device": self.device,
+            "current_job": self.current_job
+        }
+
+        # Update estimated time remaining if job in progress
+        if self.current_job["in_progress"] and self.current_job["start_time"]:
+            elapsed = time.time() - self.current_job["start_time"]
+            estimated = self.current_job["estimated_time"]
+            remaining = max(0, estimated - elapsed)
+            status["current_job"]["time_remaining"] = round(remaining, 1)
+
+            # Update progress based on time
+            if estimated > 0:
+                progress = min(95, (elapsed / estimated) * 100)
+                status["current_job"]["progress"] = round(progress, 0)
+
+        return status
+
+    def clean_summary(self, summary):
+        """Clean and format the summary text"""
+        # Remove any leading punctuation or spaces
+        summary = re.sub(r'^[,.\s]+', '', summary)
+
+        # Ensure the first letter is capitalized
+        if summary and len(summary) > 0:
+            summary = summary[0].upper() + summary[1:]
+
+        # Ensure proper ending punctuation
+        if summary and not any(summary.endswith(end) for end in ['.', '!', '?']):
+            last_sentence_end = max(
+                summary.rfind('.'),
+                summary.rfind('!'),
+                summary.rfind('?')
+            )
+            if last_sentence_end > 0:
+                summary = summary[:last_sentence_end + 1]
+            else:
+                summary = summary + '.'
+
+        return summary
+
+    def summarise(self, text, max_length=250, min_length=100, do_sample=True, temperature=1.2):
         """
         Summarise the given text using the loaded model.
 
@@ -25,78 +111,71 @@ class SummariserService:
             temperature (float): Sampling temperature (higher = more random)
 
         Returns:
-            str: The generated summary
+            dict: The generated summary and processing metadata
         """
-        # Convert character lengths to approximate token counts
-        # A rough estimate is that 1 token ≈ 4 characters in English
-        max_tokens = max(1, max_length // 4)
-        min_tokens = max(1, min_length // 4)
-
-        # Ensure text is within model's max token limit
-        inputs = self.tokenizer(text, return_tensors="pt", truncation=True, max_length=1024)
-        inputs = inputs.to(self.device)
-
-        # Set generation parameters
-        generation_params = {
-            "max_length": max_tokens,
-            "min_length": min_tokens,
-            "num_beams": 4,
-            "length_penalty": 2.0,
-            "early_stopping": True,
+        # Reset and start job tracking
+        self.current_job = {
+            "in_progress": True,
+            "start_time": time.time(),
+            "input_word_count": len(text.split()),
+            "estimated_time": max(1, min(30, len(text.split()) / 500)), # Rough estimate
+            "stage": "Tokenizing input text",
+            "progress": 5
        }
 
-        # Handle sampling and temperature
-        if do_sample:
-            try:
-                # First attempt: try with the requested temperature
-                generation_params["do_sample"] = True
-                generation_params["temperature"] = temperature
-                summary_ids = self.model.generate(
-                    inputs["input_ids"],
-                    **generation_params
-                )
-            except Exception as e:
-                # If that fails, try with default temperature (1.0)
-                print(f"Error with temperature {temperature}, falling back to default: {str(e)}")
-                generation_params["temperature"] = 1.0
-                try:
-                    summary_ids = self.model.generate(
-                        inputs["input_ids"],
-                        **generation_params
-                    )
-                except Exception:
-                    # If sampling still fails, fall back to beam search without sampling
-                    print("Sampling failed, falling back to beam search")
-                    generation_params.pop("do_sample", None)
-                    generation_params.pop("temperature", None)
-                    summary_ids = self.model.generate(
-                        inputs["input_ids"],
-                        **generation_params
-                    )
-        else:
-            # Standard beam search without sampling
-            summary_ids = self.model.generate(
-                inputs["input_ids"],
-                **generation_params
-            )
+        result = {
+            "summary": "",
+            "metadata": {
+                "input_word_count": self.current_job["input_word_count"],
+                "estimated_time_seconds": self.current_job["estimated_time"],
+                "model_used": "facebook/bart-large-cnn", # Update to match the model we're using
+                "processing_device": self.device
+            }
+        }
+
+        # Tokenization step
+        inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
+        input_ids = inputs.input_ids.to(self.device)
+
+        # Update metadata with token info
+        result["metadata"]["input_token_count"] = len(input_ids[0])
+        result["metadata"]["truncated"] = len(input_ids[0]) == 1024
+
+        # Update job status
+        self.current_job["stage"] = "Generating summary"
+        self.current_job["progress"] = 30
+
+        # Enhanced generation parameters - adjusted for better literary summaries
+        summary_ids = self.model.generate(
+            input_ids,
+            max_length=max_length,
+            min_length=min_length,
+            do_sample=do_sample,
+            temperature=temperature,
+            num_beams=5, # Increased from 4 to 5 for better quality
+            early_stopping=True,
+            no_repeat_ngram_size=3,
+            length_penalty=2.0,
+            top_k=50, # Added top_k parameter
+            top_p=0.95, # Added top_p parameter for nucleus sampling
+        )
+
+        # Update job status
+        self.current_job["stage"] = "Post-processing summary"
+        self.current_job["progress"] = 90
 
        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
 
-        # If the summary is still too long, truncate it
-        if len(summary) > max_length:
-            # Try to truncate at a sentence boundary
-            sentences = summary.split('. ')
-            truncated_summary = ''
-            for sentence in sentences:
-                if len(truncated_summary) + len(sentence) + 2 <= max_length: # +2 for '. '
-                    truncated_summary += sentence + '. '
-                else:
-                    break
+        # Clean and format the summary
+        summary = self.clean_summary(summary)
 
-            # If we couldn't even fit one sentence, just truncate at max_length
-            if not truncated_summary:
-                truncated_summary = summary[:max_length]
+        result["summary"] = summary
+        result["metadata"]["output_word_count"] = len(summary.split())
+        result["metadata"]["compression_ratio"] = round(len(summary.split()) / self.current_job["input_word_count"] * 100, 1)
 
-            summary = truncated_summary.strip()
+        # Complete job
+        self.current_job["in_progress"] = False
+        self.current_job["stage"] = "Complete"
+        self.current_job["progress"] = 100
 
-        return summary
+        return result
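Note on the new return shape: summarise() now returns a dict with "summary" plus a "metadata" block instead of a bare string, and get_status() reports model-loading and per-job progress. A small sketch of driving the service directly, assuming the model can be downloaded on first use; the input text is illustrative only.

    # Illustrative direct use of the updated service (input text is made up).
    from app.services.summariser import SummariserService

    service = SummariserService()                 # loads facebook/bart-large-cnn in __init__
    print(service.get_status()["model_loading"])  # progress reaches 100 once loading finishes

    result = service.summarise("Long input text to condense. " * 200, max_length=250, min_length=100)
    print(result["summary"])
    print(result["metadata"]["input_token_count"], result["metadata"]["compression_ratio"])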
main.py CHANGED
@@ -1,39 +1,24 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-import os
+from app.api.routes import router as api_router
 
-app = FastAPI(
-    title="AI Content Summariser API",
-    description="API for summarising text content using NLP models",
-    version="0.1.0"
-)
+app = FastAPI(title="AI Content Summariser API")
 
 # Configure CORS
-origins = os.getenv("CORS_ORIGINS", "http://localhost:3000").split(",")
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=origins,
+    allow_origins=["http://localhost:3000"], # Add your frontend URL
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
 )
 
-@app.get("/")
-async def root():
-    return {"message": "Welcome to the AI Content Summariser API"}
+app.include_router(api_router)
 
 @app.get("/health")
 async def health_check():
     return {"status": "healthy"}
 
-# Import and include API routes
-from app.api.routes import router as api_router
-app.include_router(api_router, prefix="/api")
-
-# Import and include async API routes
-from app.api.async_routes import router as async_router
-app.include_router(async_router, prefix="/api/async")
-
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True)
+    uvicorn.run(app, host="0.0.0.0", port=8000)
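Note: main.py now mounts the /api-prefixed router directly and keeps only the /health probe at the root. A quick in-process smoke-test sketch with FastAPI's TestClient (it relies on the httpx pin already in requirements.txt); the expected values come straight from the handlers in this commit.

    # In-process smoke test sketch using FastAPI's TestClient.
    from fastapi.testclient import TestClient
    from main import app

    client = TestClient(app)
    assert client.get("/health").json() == {"status": "healthy"}
    print(client.get("/api/status").json()["device"])  # "cuda" or "cpu" depending on the host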
requirements.txt CHANGED
@@ -1,12 +1,15 @@
-fastapi==0.100.1
-uvicorn==0.23.2
-pydantic==2.1.1
-transformers==4.31.0
+numpy==1.24.3
 torch==2.0.1
+transformers==4.30.2
+huggingface_hub==0.16.4
+fastapi==0.100.0
+uvicorn==0.22.0
+pydantic==1.10.8
+beautifulsoup4==4.12.2
+requests==2.31.0
 sentencepiece==0.1.99
 python-dotenv==1.0.0
 httpx==0.24.1
 accelerate==0.21.0
-beautifulsoup4==4.12.2
 pytest==7.3.1
 pytest-cov==4.1.0