# ABSA_Test_Space / app.py
import psycopg2
import os
import traceback
import json
import base64
# Assuming gspread and SentenceTransformer are installed
try:
import gspread
from oauth2client.service_account import ServiceAccountCredentials
from sentence_transformers import SentenceTransformer
print("gspread and SentenceTransformer imported successfully.")
except ImportError:
print("Error: Required libraries (gspread, oauth2client, sentence_transformers) not found.")
print("Please install them: pip install psycopg2-binary gspread oauth2client sentence-transformers numpy")
# Exit or handle the error appropriately if libraries are missing
exit() # Exiting for demonstration if imports fail
# PostgreSQL connection settings.
# Prefer environment variables; the literals below are hardcoded fallbacks
# for this test only. Never commit real credentials to source control.
DB_HOST = os.getenv("DB_HOST", "https://wziqfkzaqorzthpoxhjh.supabase.co")
DB_NAME = os.getenv("DB_NAME", "postgres")
DB_USER = os.getenv("DB_USER", "postgres")
DB_PASSWORD = os.getenv("DB_PASSWORD", "Me21322972..........")  # Replace with your actual password
DB_PORT = os.getenv("DB_PORT", "5432")  # Default PostgreSQL port
# Define environment variables for Google Sheets authentication
GOOGLE_BASE64_CREDENTIALS = os.getenv("GOOGLE_BASE64_CREDENTIALS")
SHEET_ID = "19ipxC2vHYhpXCefpxpIkpeYdI43a1Ku2kYwecgUULIw" # Replace with your actual Sheet ID
# Define table names
BUSINESS_DATA_TABLE = "business_data"
CONVERSATION_HISTORY_TABLE = "conversation_history"
# Define Embedding Dimension (must match your chosen Sentence Transformer model)
EMBEDDING_DIM = 384 # Dimension for paraphrase-MiniLM-L6-v2
# --- Database Functions ---
def connect_db():
"""Establishes a connection to the PostgreSQL database."""
print("Attempting to connect to the database...")
# Retrieve credentials inside the function in case environment variables are set after import
# Use the hardcoded global variables defined above for this test
db_host = DB_HOST
db_name = DB_NAME
db_user = DB_USER
db_password = DB_PASSWORD
db_port = DB_PORT
if not all([db_host, db_name, db_user, db_password]):
print("Error: Database credentials (DB_HOST, DB_NAME, DB_USER, DB_PASSWORD) are not fully set as environment variables.")
return None
# *** FIX: Remove http(s):// prefix from host if present ***
if db_host.startswith("https://"):
db_host = db_host.replace("https://", "")
elif db_host.startswith("http://"):
db_host = db_host.replace("http://", "")
# **********************************************************
try:
conn = psycopg2.connect(
host=db_host,
database=db_name,
user=db_user,
password=db_password,
port=db_port
)
print("Database connection successful.")
return conn
except Exception as e:
print(f"Error connecting to the database: {e}")
print(traceback.format_exc())
return None
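# Hypothetical usage (not executed by this script): a quick connectivity
# sanity check from a REPL.
#
#   conn = connect_db()
#   if conn:
#       with conn.cursor() as cur:
#           cur.execute("SELECT version();")
#           print(cur.fetchone()[0])
#       conn.close()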
def setup_db_schema(conn):
"""Sets up the necessary tables and pgvector extension."""
print("Setting up database schema...")
try:
with conn.cursor() as cur:
# Enable pgvector extension
cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
print("pgvector extension enabled (if not already).")
# Create business_data table
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {BUSINESS_DATA_TABLE} (
id SERIAL PRIMARY KEY,
service TEXT NOT NULL,
description TEXT NOT NULL,
embedding vector({EMBEDDING_DIM}) -- Assuming EMBEDDING_DIM is defined globally
);
""")
print(f"Table '{BUSINESS_DATA_TABLE}' created (if not already).")
# Create conversation_history table
cur.execute(f"""
CREATE TABLE IF NOT EXISTS {CONVERSATION_HISTORY_TABLE} (
id SERIAL PRIMARY KEY,
timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
user_id TEXT,
user_query TEXT,
model_response TEXT,
tool_details JSONB,
model_used TEXT
);
""")
print(f"Table '{CONVERSATION_HISTORY_TABLE}' created (if not already).")
conn.commit()
print("Database schema setup complete.")
return True
except Exception as e:
print(f"Error setting up database schema: {e}")
print(traceback.format_exc())
conn.rollback()
return False
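# Not part of the schema setup above: once the table is populated, an
# approximate-nearest-neighbour index can speed up similarity search.
# A minimal sketch, assuming pgvector with ivfflat support and cosine
# distance; the function and index names are illustrative only.
def create_ann_index(conn):
    """Optionally add an ivfflat index on the embedding column."""
    try:
        with conn.cursor() as cur:
            cur.execute(f"""
                CREATE INDEX IF NOT EXISTS business_data_embedding_idx
                ON {BUSINESS_DATA_TABLE}
                USING ivfflat (embedding vector_cosine_ops)
                WITH (lists = 100);
            """)
        conn.commit()
        return True
    except Exception as e:
        print(f"Error creating ANN index: {e}")
        conn.rollback()
        return False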
# --- Google Sheets Authentication and Data Retrieval ---
def authenticate_google_sheets():
"""Authenticates with Google Sheets using base64 encoded credentials."""
print("Authenticating Google Account for Sheets access...")
if not GOOGLE_BASE64_CREDENTIALS:
print("Error: GOOGLE_BASE64_CREDENTIALS environment variable not set. Google Sheets access will fail.")
return None
try:
credentials_json = base64.b64decode(GOOGLE_BASE64_CREDENTIALS).decode('utf-8')
credentials = json.loads(credentials_json)
# Use ServiceAccountCredentials.from_json_keyfile_dict for dictionary
scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
creds = ServiceAccountCredentials.from_json_keyfile_dict(credentials, scope)
gc = gspread.authorize(creds)
print("Google Sheets authentication successful.")
return gc
except Exception as e:
print(f"Google Sheets authentication failed: {e}")
print(traceback.format_exc())
print("Please ensure your GOOGLE_BASE64_CREDENTIALS environment variable is correctly set and contains valid service account credentials.")
return None
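# oauth2client is deprecated; an equivalent path using the maintained
# google-auth package would look like the sketch below (assumes
# `pip install google-auth` and a gspread version that accepts google-auth
# credentials; not called by this script).
def authenticate_google_sheets_google_auth():
    """Alternative Sheets auth via google-auth instead of oauth2client."""
    from google.oauth2.service_account import Credentials
    info = json.loads(base64.b64decode(GOOGLE_BASE64_CREDENTIALS).decode("utf-8"))
    scopes = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
    creds = Credentials.from_service_account_info(info, scopes=scopes)
    return gspread.authorize(creds)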
# --- Data Migration Function ---
def migrate_google_sheet_data_to_db(conn, gc_client, embedder_model):
"""Retrieves data from Google Sheet, generates embeddings, and inserts into DB."""
print("Migrating data from Google Sheet to database...")
if gc_client is None or SHEET_ID is None:
print("Skipping Google Sheet migration: Google Sheets client or Sheet ID not available.")
return False
if embedder_model is None:
print("Skipping Google Sheet migration: Embedder not available.")
return False
if EMBEDDING_DIM is None:
print("Skipping Google Sheet migration: EMBEDDING_DIM not defined.")
return False
try:
# Check if business_data table is already populated
with conn.cursor() as cur:
cur.execute(f"SELECT COUNT(*) FROM {BUSINESS_DATA_TABLE};")
count = cur.fetchone()[0]
if count > 0:
print(f"Table '{BUSINESS_DATA_TABLE}' already contains {count} records. Skipping migration.")
return True # Indicate success because data is already there
sheet = gc_client.open_by_key(SHEET_ID).sheet1
print(f"Successfully opened Google Sheet with ID: {SHEET_ID}")
data_records = sheet.get_all_records()
if not data_records:
print("No data records found in Google Sheet.")
return False
filtered_data = [row for row in data_records if row.get('Service') and row.get('Description')]
if not filtered_data:
print("Filtered data is empty after checking for 'Service' and 'Description'.")
return False
print(f"Processing {len(filtered_data)} records for migration.")
descriptions_for_embedding = [f"Service: {row['Service'].strip()}. Description: {row['Description'].strip()}" for row in filtered_data]
        # Generate embeddings in batches to bound memory use on large datasets
        batch_size = 64
        total_batches = (len(descriptions_for_embedding) + batch_size - 1) // batch_size
        embeddings_list = []
        for i in range(0, len(descriptions_for_embedding), batch_size):
            batch_descriptions = descriptions_for_embedding[i:i + batch_size]
            print(f"Encoding batch {i // batch_size + 1} of {total_batches}...")
            batch_embeddings = embedder_model.encode(batch_descriptions, convert_to_tensor=False)
            embeddings_list.extend(batch_embeddings.tolist())  # numpy array -> nested lists
insert_count = 0
with conn.cursor() as cur:
for i, row in enumerate(filtered_data):
service = row.get('Service', '').strip()
description = row.get('Description', '').strip()
                embedding = embeddings_list[i]
                # pgvector accepts the '[x1,x2,...]' literal form; serialize the
                # Python list to that format before casting with ::vector
                embedding_literal = "[" + ",".join(str(x) for x in embedding) + "]"
                # psycopg2.extras.execute_values would be faster for bulk inserts,
                # but a per-row execute is fine at this scale
                cur.execute(f"""
                    INSERT INTO {BUSINESS_DATA_TABLE} (service, description, embedding)
                    VALUES (%s, %s, %s::vector);
                """, (service, description, embedding_literal))
insert_count += 1
if insert_count % 100 == 0:
conn.commit() # Commit periodically
print(f"Inserted {insert_count} records...")
conn.commit() # Commit remaining records
print(f"Migration complete. Inserted {insert_count} records into '{BUSINESS_DATA_TABLE}'.")
return True
except Exception as e:
print(f"Error during Google Sheet data migration: {e}")
print(traceback.format_exc())
conn.rollback()
return False
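# The migrated rows are intended to back retrieval; a minimal sketch of the
# matching similarity query (not called by this script; assumes cosine
# distance via pgvector's <=> operator, and the function name is
# illustrative only).
def find_similar_services(conn, embedder_model, query_text, top_k=5):
    """Return the top_k (service, description, distance) rows for query_text."""
    query_embedding = embedder_model.encode(query_text, convert_to_tensor=False).tolist()
    embedding_literal = "[" + ",".join(str(x) for x in query_embedding) + "]"
    with conn.cursor() as cur:
        cur.execute(f"""
            SELECT service, description, embedding <=> %s::vector AS distance
            FROM {BUSINESS_DATA_TABLE}
            ORDER BY distance
            LIMIT %s;
        """, (embedding_literal, top_k))
        return cur.fetchall()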
# --- Main Migration Execution ---
if __name__ == "__main__":
print("Starting RAG data migration script...")
# 1. Authenticate Google Sheets
gc = authenticate_google_sheets()
if gc is None:
print("Google Sheets authentication failed. Cannot migrate data from Sheets.")
# Exit or handle the error if Sheets auth fails
exit()
# 2. Initialize Embedder Model
try:
print(f"Loading Sentence Transformer model for embeddings (dimension: {EMBEDDING_DIM})...")
# Make sure to use the correct model and check its dimension
embedder = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
# Verify the dimension matches EMBEDDING_DIM
if embedder.get_sentence_embedding_dimension() != EMBEDDING_DIM:
print(f"Error: Loaded embedder dimension ({embedder.get_sentence_embedding_dimension()}) does not match expected EMBEDDING_DIM ({EMBEDDING_DIM}).")
print("Please check the model or update EMBEDDING_DIM.")
embedder = None # Set to None to prevent migration with wrong dimension
else:
print("Embedder model loaded successfully.")
except Exception as e:
print(f"Error loading Sentence Transformer model: {e}")
print(traceback.format_exc())
embedder = None # Set to None if model loading fails
if embedder is None:
print("Embedder model not available. Cannot generate embeddings for migration.")
# Exit or handle the error if embedder fails to load
exit()
# 3. Connect to Database
db_conn = connect_db()
if db_conn is None:
print("Database connection failed. Cannot migrate data.")
# Exit or handle the error if DB connection fails
exit()
try:
# 4. Setup Database Schema (if not already done)
if setup_db_schema(db_conn):
# 5. Migrate Data
if migrate_google_sheet_data_to_db(db_conn, gc, embedder):
print("\nRAG Data Migration to PostgreSQL completed successfully.")
else:
print("\nRAG Data Migration to PostgreSQL failed.")
else:
print("\nDatabase schema setup failed. Data migration skipped.")
finally:
# 6. Close Database Connection
if db_conn:
db_conn.close()
print("Database connection closed.")
print("\nMigration script finished.")