# Hugging Face Spaces page header captured with the source: "Runtime error".
from flask import Flask, request, jsonify | |
import os | |
import pdfplumber | |
import pytesseract | |
from PIL import Image | |
from transformers import PegasusForConditionalGeneration, PegasusTokenizer | |
import torch | |
import logging | |
# Flask application object used by the rest of this module.
app = Flask(__name__)

# Set up logging: INFO level on the root logger, plus a module-level logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Load the Pegasus tokenizer and model once, at import time; the functions
# below use these module-level objects.
_PEGASUS_CHECKPOINT = "google/pegasus-xsum"
logger.info("Loading Pegasus model and tokenizer...")
tokenizer = PegasusTokenizer.from_pretrained(_PEGASUS_CHECKPOINT)
model = PegasusForConditionalGeneration.from_pretrained(_PEGASUS_CHECKPOINT)
logger.info("Model loaded successfully.")
# Extract text from PDF with a page limit
def extract_text_from_pdf(file_path, max_pages=10):
    """Return the text of the leading pages of a PDF.

    The file is opened with pdfplumber and at most ``max_pages`` pages,
    counted from the start of the document, are processed. A page whose
    extraction raises is logged and skipped. If the document as a whole
    cannot be processed, the error is logged and "" is returned.

    Args:
        file_path: Path of the PDF file to read.
        max_pages: Maximum number of leading pages to process. Values
            below zero are treated as zero.

    Returns:
        The extracted page texts joined with newlines and stripped of
        surrounding whitespace, or "" when nothing was extracted.
    """
    page_texts = []
    try:
        with pdfplumber.open(file_path) as pdf:
            total_pages = len(pdf.pages)
            # Clamp at zero: a negative count would make the slice below
            # drop pages from the end instead of limiting from the start.
            pages_to_process = max(0, min(total_pages, max_pages))
            logger.info(f"Extracting text from {pages_to_process} of {total_pages} pages in {file_path}")
            for page_number, page in enumerate(pdf.pages[:pages_to_process], start=1):
                try:
                    extracted = page.extract_text()
                except Exception as e:
                    # Best effort: one bad page must not discard the rest.
                    logger.warning(f"Error extracting text from page {page_number}: {e}")
                    continue
                if extracted:
                    page_texts.append(extracted)
    except Exception as e:
        logger.error(f"Failed to process PDF {file_path}: {e}")
        return ""
    return "\n".join(page_texts).strip()
# Extract text from image (OCR)
def extract_text_from_image(file_path):
    """Run OCR on an image file and return the recognised text.

    The image is opened with PIL and passed to
    ``pytesseract.image_to_string``. Any failure is logged and reported to
    the caller as an empty string.

    Args:
        file_path: Path of the image file to read.

    Returns:
        The OCR output stripped of surrounding whitespace, or "" on error.
    """
    try:
        logger.info(f"Extracting text from image {file_path} using OCR...")
        # The context manager closes the underlying file handle even when
        # OCR raises; a bare Image.open() leaves it open.
        with Image.open(file_path) as image:
            text = pytesseract.image_to_string(image)
        return text.strip()
    except Exception as e:
        logger.error(f"Failed to process image {file_path}: {e}")
        return ""
# Summarize text using Pegasus with truncation
def summarize_text(text, max_input_length=512, max_output_length=150):
    """Summarize ``text`` with the module-level Pegasus tokenizer and model.

    The input is tokenized with truncation to ``max_input_length`` tokens
    and a summary is generated with 4-beam search.

    Args:
        text: Text to summarize.
        max_input_length: Token limit applied when tokenizing the input.
        max_output_length: ``max_length`` passed to ``model.generate``.

    Returns:
        The decoded summary, or "" when ``text`` is empty or only
        whitespace, or when tokenization/generation raises (the error is
        logged).
    """
    # Nothing to summarize: return before touching the model.
    if not text or not text.strip():
        return ""
    try:
        logger.info("Summarizing text...")
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=max_input_length)
        summary_ids = model.generate(
            inputs["input_ids"],
            # Forward the mask produced by the tokenizer instead of
            # leaving generate() to infer one.
            attention_mask=inputs.get("attention_mask"),
            max_length=max_output_length,
            # Keep min_length satisfiable when a caller asks for a summary
            # shorter than the default minimum of 30.
            min_length=min(30, max_output_length),
            num_beams=4,
            early_stopping=True,
        )
        summary = tokenizer.decode(summary_ids[0], skip_special_tokens=True)
        logger.info("Summarization completed.")
        return summary
    except Exception as e:
        logger.error(f"Error during summarization: {e}")
        return ""
# NOTE(review): no @app.route decorator is visible on this view, so nothing in
# this file registers it with `app` — confirm the intended URL and methods.
def summarize_document():
    """Summarize an uploaded PDF or image and return the result as JSON.

    Reads the ``file`` part of the current request, writes it to a
    temporary file, extracts text (the first 5 pages of a PDF, or OCR for
    PNG/JPEG/JPG), summarizes the text and removes the temporary file.

    Returns:
        ``jsonify({"summary": ...})`` on success; otherwise a tuple of a
        JSON ``{"error": ...}`` body and a status code: 400 for a missing
        upload, an unsupported extension or no extractable text; 500 when
        no summary is produced or an unexpected exception occurs.
    """
    import tempfile  # local import: the only place in this file that needs it

    if 'file' not in request.files:
        logger.error("No file uploaded in request.")
        return jsonify({"error": "No file uploaded"}), 400

    upload = request.files['file']
    filename = upload.filename
    if not filename:
        logger.error("Empty filename in request.")
        return jsonify({"error": "No file uploaded"}), 400

    # Pick the extractor before writing anything, so unsupported uploads
    # never reach the disk.
    lowered = filename.lower()
    if lowered.endswith('.pdf'):
        suffix = '.pdf'
    elif lowered.endswith(('.png', '.jpeg', '.jpg')):
        suffix = os.path.splitext(lowered)[1]
    else:
        logger.error(f"Unsupported file format: {filename}")
        return jsonify({"error": "Unsupported file format. Use PDF, PNG, JPEG, or JPG"}), 400

    file_path = None
    try:
        # The client-supplied name is never used to build a path: it could
        # contain "../" or be absolute. mkstemp also gives each request its
        # own file, so same-named concurrent uploads cannot collide.
        fd, file_path = tempfile.mkstemp(suffix=suffix)
        os.close(fd)
        upload.save(file_path)
        logger.info(f"File saved to {file_path}")

        if suffix == '.pdf':
            text = extract_text_from_pdf(file_path, max_pages=5)
        else:
            text = extract_text_from_image(file_path)

        if not text:
            logger.warning(f"No text extracted from {filename}")
            return jsonify({"error": "No text extracted from the file"}), 400

        summary = summarize_text(text)
        if not summary:
            logger.warning("Summarization failed to produce output.")
            return jsonify({"error": "Failed to generate summary"}), 500

        logger.info(f"Summary generated for {filename}")
        return jsonify({"summary": summary})
    except Exception as e:
        logger.error(f"Unexpected error processing {filename}: {e}")
        return jsonify({"error": str(e)}), 500
    finally:
        # Always remove the temporary file, whichever branch returned.
        if file_path is not None:
            try:
                os.remove(file_path)
                logger.info(f"Cleaned up file: {file_path}")
            except FileNotFoundError:
                pass
            except Exception as e:
                logger.warning(f"Failed to delete {file_path}: {e}")
# Script entry point: start Flask's built-in server when run directly.
if __name__ == '__main__':
    logger.info("Starting Flask app...")
    # Listen on every interface, port 7860.
    server_options = {"host": '0.0.0.0', "port": 7860}
    app.run(**server_options)