Dan Walsh committed
Commit b089011 · 1 Parent(s): 9e707a5

Updates to Hugging Face Spaces config

Files changed (5):
  1. .env +3 -0
  2. Dockerfile +9 -0
  3. README.md +10 -5
  4. app/services/summariser.py +128 -83
  5. main.py +14 -4
.env ADDED
@@ -0,0 +1,3 @@
+TRANSFORMERS_CACHE=/tmp/huggingface_cache
+HF_HOME=/tmp/huggingface_cache
+HUGGINGFACE_HUB_CACHE=/tmp/huggingface_cache
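
Note: all three variables point the Hugging Face libraries at the same writable path; on Spaces, /tmp is generally the safe place to write. A minimal sketch of how the backend resolves this directory at startup (mirroring the logic added to summariser.py below):

import os

# Resolve the shared cache directory; the /tmp fallback mirrors the value
# configured in .env and the Dockerfile.
cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
os.makedirs(cache_dir, exist_ok=True)  # create it on a fresh container start
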
Dockerfile CHANGED
@@ -2,6 +2,15 @@ FROM python:3.9-slim
 
 WORKDIR /app
 
+# Create a writable cache directory
+RUN mkdir -p /tmp/huggingface_cache && \
+    chmod 777 /tmp/huggingface_cache
+
+# Set environment variables for model caching
+ENV TRANSFORMERS_CACHE=/tmp/huggingface_cache
+ENV HF_HOME=/tmp/huggingface_cache
+ENV HUGGINGFACE_HUB_CACHE=/tmp/huggingface_cache
+
 # Copy requirements first for better caching
 COPY requirements.txt .
 RUN pip install --no-cache-dir -r requirements.txt
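
Note: the same three variables are baked into the image so they hold even if the Space settings omit them. A hedged startup check (not part of this commit) could warn when they are missing:

import os

# Warn if any of the cache variables set in the Dockerfile/.env is absent
# or points somewhere unexpected.
EXPECTED = "/tmp/huggingface_cache"
for var in ("TRANSFORMERS_CACHE", "HF_HOME", "HUGGINGFACE_HUB_CACHE"):
    if os.environ.get(var) != EXPECTED:
        print(f"Warning: {var} is not set to {EXPECTED}")
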
README.md CHANGED
@@ -136,11 +136,16 @@ See the deployment guide in the frontend repository for detailed instructions on
 
 ### Deploying to Hugging Face Spaces
 
-1. Create a new Space on Hugging Face
-2. Choose Docker as the SDK
-3. Upload your backend code
-4. Configure the environment variables:
-   - `CORS_ORIGINS`: Your frontend URL
+When deploying to Hugging Face Spaces, make sure to:
+
+1. Set the following environment variables in the Space settings:
+   - `TRANSFORMERS_CACHE=/tmp/huggingface_cache`
+   - `HF_HOME=/tmp/huggingface_cache`
+   - `HUGGINGFACE_HUB_CACHE=/tmp/huggingface_cache`
+
+2. Use the Docker SDK in your Space settings
+
+3. If you encounter memory issues, consider using a smaller model by changing the `model_name` in `summariser.py`
 
 ## Performance Optimizations
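Note: for step 3 above, the change is a single line in SummariserService.__init__; the distilbart checkpoint below is the same fallback this commit already uses, so it is a safe smaller choice:

# Hypothetical one-line swap for a memory-constrained Space:
model_name = "sshleifer/distilbart-cnn-6-6"  # instead of model_options["literary"]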
 
app/services/summariser.py CHANGED
@@ -2,6 +2,7 @@ import numpy as np  # Import NumPy first
 import torch
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 import time
+import os
 import re
 
 class SummariserService:
@@ -15,35 +16,74 @@ class SummariserService:
 
         # Consider these alternative models
         model_options = {
-            "general": "facebook/bart-large-cnn",  # Better for news articles
-            "news": "facebook/bart-large-xsum",  # Better for short news summaries
-            "long_form": "google/pegasus-large",  # Better for long documents
-            "literary": "t5-large"  # Better for literary text
+            "general": "facebook/bart-large-cnn",
+            "news": "facebook/bart-large-xsum",
+            "long_form": "google/pegasus-large",
+            "literary": "t5-large"
         }
 
-        # Choose the most appropriate model - let's switch to BART for better literary summaries
-        model_name = model_options["general"]  # Changed from "literary" to "general"
+        # Choose the most appropriate model
+        model_name = model_options["literary"]  # Better for literary text
 
         # Update loading status
         self.model_loading_status["is_loading"] = True
         self.model_loading_status["step"] = "Initializing tokenizer"
 
-        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+        # Ensure cache directory exists and is writable
+        cache_dir = os.environ.get("TRANSFORMERS_CACHE", "/tmp/huggingface_cache")
+        os.makedirs(cache_dir, exist_ok=True)
 
-        self.model_loading_status["step"] = "Loading model"
-        self.model = AutoModelForSeq2SeqLM.from_pretrained(
-            model_name,
-            force_download=False,
-            local_files_only=False
-        )
+        try:
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                model_name,
+                cache_dir=cache_dir,
+                local_files_only=False
+            )
+
+            self.model_loading_status["step"] = "Loading model"
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(
+                model_name,
+                cache_dir=cache_dir,
+                force_download=False,
+                local_files_only=False
+            )
+
+            # Move to GPU if available
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.model.to(self.device)
 
-        # Move to GPU if available
-        self.device = "cuda" if torch.cuda.is_available() else "cpu"
-        self.model.to(self.device)
+        except Exception as e:
+            # Fall back to a smaller model if the main one fails
+            print(f"Error loading model {model_name}: {str(e)}")
+            print("Falling back to smaller model...")
+
+            fallback_model = "sshleifer/distilbart-cnn-6-6"  # Much smaller model
+
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                fallback_model,
+                cache_dir=cache_dir,
+                local_files_only=False
+            )
+
+            self.model = AutoModelForSeq2SeqLM.from_pretrained(
+                fallback_model,
+                cache_dir=cache_dir,
+                force_download=False,
+                local_files_only=False
+            )
+
+            self.device = "cuda" if torch.cuda.is_available() else "cpu"
+            self.model.to(self.device)
+
+            # Update model name for metadata
+            model_name = fallback_model
 
         self.model_loading_status["is_loading"] = False
         self.model_loading_status["progress"] = 100
 
+        # Store the actual model name used
+        self.model_name = model_name
+
         # Track current processing job
         self.current_job = {
             "in_progress": False,
@@ -54,28 +94,6 @@
             "progress": 0
         }
 
-    def get_status(self):
-        """Return the current status of the summarizer service"""
-        status = {
-            "model_loading": self.model_loading_status,
-            "device": self.device,
-            "current_job": self.current_job
-        }
-
-        # Update estimated time remaining if job in progress
-        if self.current_job["in_progress"] and self.current_job["start_time"]:
-            elapsed = time.time() - self.current_job["start_time"]
-            estimated = self.current_job["estimated_time"]
-            remaining = max(0, estimated - elapsed)
-            status["current_job"]["time_remaining"] = round(remaining, 1)
-
-            # Update progress based on time
-            if estimated > 0:
-                progress = min(95, (elapsed / estimated) * 100)
-                status["current_job"]["progress"] = round(progress, 0)
-
-        return status
-
     def clean_summary(self, summary):
         """Clean and format the summary text"""
         # Remove any leading punctuation or spaces
@@ -99,6 +117,28 @@
 
         return summary
 
+    def get_status(self):
+        """Return the current status of the summarizer service"""
+        status = {
+            "model_loading": self.model_loading_status,
+            "device": self.device,
+            "current_job": self.current_job
+        }
+
+        # Update estimated time remaining if job in progress
+        if self.current_job["in_progress"] and self.current_job["start_time"]:
+            elapsed = time.time() - self.current_job["start_time"]
+            estimated = self.current_job["estimated_time"]
+            remaining = max(0, estimated - elapsed)
+            status["current_job"]["time_remaining"] = round(remaining, 1)
+
+            # Update progress based on time
+            if estimated > 0:
+                progress = min(95, (elapsed / estimated) * 100)
+                status["current_job"]["progress"] = round(progress, 0)
+
+        return status
+
     def summarise(self, text, max_length=250, min_length=100, do_sample=True, temperature=1.2):
         """
         Summarise the given text using the loaded model.
@@ -128,54 +168,59 @@
             "metadata": {
                 "input_word_count": self.current_job["input_word_count"],
                 "estimated_time_seconds": self.current_job["estimated_time"],
-                "model_used": "facebook/bart-large-cnn",  # Update to match the model we're using
+                "model_used": self.model_name,
                 "processing_device": self.device
             }
         }
 
-        # Tokenization step
-        inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
-        input_ids = inputs.input_ids.to(self.device)
-
-        # Update metadata with token info
-        result["metadata"]["input_token_count"] = len(input_ids[0])
-        result["metadata"]["truncated"] = len(input_ids[0]) == 1024
-
-        # Update job status
-        self.current_job["stage"] = "Generating summary"
-        self.current_job["progress"] = 30
-
-        # Enhanced generation parameters - adjusted for better literary summaries
-        summary_ids = self.model.generate(
-            input_ids,
-            max_length=max_length,
-            min_length=min_length,
-            do_sample=do_sample,
-            temperature=temperature,
-            num_beams=5,  # Increased from 4 to 5 for better quality
-            early_stopping=True,
-            no_repeat_ngram_size=3,
-            length_penalty=2.0,
-            top_k=50,  # Added top_k parameter
-            top_p=0.95,  # Added top_p parameter for nucleus sampling
-        )
-
-        # Update job status
-        self.current_job["stage"] = "Post-processing summary"
-        self.current_job["progress"] = 90
-
-        summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
-
-        # Clean and format the summary
-        summary = self.clean_summary(summary)
-
-        result["summary"] = summary
-        result["metadata"]["output_word_count"] = len(summary.split())
-        result["metadata"]["compression_ratio"] = round(len(summary.split()) / self.current_job["input_word_count"] * 100, 1)
-
-        # Complete job
-        self.current_job["in_progress"] = False
-        self.current_job["stage"] = "Complete"
-        self.current_job["progress"] = 100
+        try:
+            # Tokenization step
+            inputs = self.tokenizer(text, return_tensors="pt", max_length=1024, truncation=True)
+            input_ids = inputs.input_ids.to(self.device)
+
+            # Update metadata with token info
+            result["metadata"]["input_token_count"] = len(input_ids[0])
+            result["metadata"]["truncated"] = len(input_ids[0]) == 1024
+
+            # Update job status
+            self.current_job["stage"] = "Generating summary"
+            self.current_job["progress"] = 30
+
+            # Enhanced generation parameters
+            summary_ids = self.model.generate(
+                input_ids,
+                max_length=max_length,
+                min_length=min_length,
+                do_sample=do_sample,
+                temperature=temperature,
+                num_beams=4,
+                early_stopping=True,
+                no_repeat_ngram_size=3,
+                length_penalty=2.0,
+            )
+
+            # Update job status
+            self.current_job["stage"] = "Post-processing summary"
+            self.current_job["progress"] = 90
+
+            summary = self.tokenizer.decode(summary_ids[0], skip_special_tokens=True)
+
+            # Clean and format the summary
+            summary = self.clean_summary(summary)
+
+            result["summary"] = summary
+            result["metadata"]["output_word_count"] = len(summary.split())
+            result["metadata"]["compression_ratio"] = round(len(summary.split()) / self.current_job["input_word_count"] * 100, 1)
+
+        except Exception as e:
+            # Handle errors gracefully
+            print(f"Error during summarization: {str(e)}")
+            result["summary"] = "An error occurred during summarization. Please try again with a shorter text or different parameters."
+            result["error"] = str(e)
+        finally:
+            # Complete job
+            self.current_job["in_progress"] = False
+            self.current_job["stage"] = "Complete"
+            self.current_job["progress"] = 100
 
         return result
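
Note: a minimal usage sketch of the service after these changes (the input text is illustrative; the result keys follow the dictionary built in summarise()):

from app.services.summariser import SummariserService

# Instantiating the service loads t5-large (cached under /tmp), falling back
# to distilbart-cnn-6-6 if the larger model cannot be loaded.
service = SummariserService()

result = service.summarise("Some long article text ...", max_length=250, min_length=100)
print(result["summary"])
print(result["metadata"]["model_used"], result["metadata"]["processing_device"])

# get_status() reports model-loading state and progress of the current job.
print(service.get_status())
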
main.py CHANGED
@@ -1,18 +1,28 @@
 from fastapi import FastAPI
 from fastapi.middleware.cors import CORSMiddleware
-from app.api.routes import router as api_router
+import os
+
+# Import the router from the correct location
+# Check which router file exists and use that one
+if os.path.exists("app/api/routes.py"):
+    from app.api.routes import router as api_router
+elif os.path.exists("app/routers/api.py"):
+    from app.routers.api import router as api_router
+else:
+    raise ImportError("Could not find router file")
 
 app = FastAPI(title="AI Content Summariser API")
 
-# Configure CORS
+# Configure CORS - allow requests from the frontend
 app.add_middleware(
     CORSMiddleware,
-    allow_origins=["http://localhost:3000"],  # Add your frontend URL
+    allow_origins=["*"],  # For development - restrict this in production
     allow_credentials=True,
     allow_methods=["*"],
     allow_headers=["*"],
 )
 
+# Include the router
 app.include_router(api_router)
 
 @app.get("/health")
@@ -21,4 +31,4 @@ async def health_check():
 
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 8000)))
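
Note: a quick smoke test for the deployed API (stdlib only; assumes the server is reachable on the default port, and that health_check(), which is outside this diff, returns JSON):

import json
import urllib.request

# Hit the /health endpoint defined in main.py.
with urllib.request.urlopen("http://localhost:8000/health") as resp:
    print(json.load(resp))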