import psycopg2
import os
import pickle
import traceback
import numpy as np
import json
import base64
import time

# Assuming gspread and SentenceTransformer are installed
try:
    import gspread
    from oauth2client.service_account import ServiceAccountCredentials
    from sentence_transformers import SentenceTransformer
    print("gspread and SentenceTransformer imported successfully.")
except ImportError:
    print("Error: Required libraries (gspread, oauth2client, sentence_transformers) not found.")
    print("Please install them: pip install psycopg2-binary gspread oauth2client sentence-transformers numpy")
    # Exit or handle the error appropriately if libraries are missing
    exit()  # Exiting for demonstration if imports fail
# PostgreSQL connection settings.
# These should normally come from environment variables; they are hardcoded here for this test.
#DB_HOST = os.getenv("DB_HOST")
DB_HOST = "https://wziqfkzaqorzthpoxhjh.supabase.co"
#DB_NAME = os.getenv("DB_NAME")
DB_NAME = "postgres"
#DB_USER = os.getenv("DB_USER")
DB_USER = "postgres"
#DB_PASSWORD = os.getenv("DB_PASSWORD")
DB_PASSWORD = "Me21322972.........."  # Replace with your actual password
#DB_PORT = os.getenv("DB_PORT", "5432")  # Default PostgreSQL port
DB_PORT = "5432"

# Google Sheets authentication (base64-encoded service-account JSON) and source sheet
GOOGLE_BASE64_CREDENTIALS = os.getenv("GOOGLE_BASE64_CREDENTIALS")
SHEET_ID = "19ipxC2vHYhpXCefpxpIkpeYdI43a1Ku2kYwecgUULIw"  # Replace with your actual Sheet ID

# Table names
BUSINESS_DATA_TABLE = "business_data"
CONVERSATION_HISTORY_TABLE = "conversation_history"

# Embedding dimension (must match your chosen Sentence Transformer model)
EMBEDDING_DIM = 384  # Dimension for paraphrase-MiniLM-L6-v2
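
# Sketch (assumption, not invoked by this script): one way to produce the value expected in
# GOOGLE_BASE64_CREDENTIALS from a downloaded service-account key file. The file path below
# is hypothetical; use the path to your own JSON key.
def encode_service_account_key(path="service_account.json"):
    """Return the base64-encoded contents of a service-account JSON key file."""
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")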

# --- Database Functions ---

def connect_db():
    """Establishes a connection to the PostgreSQL database."""
    print("Attempting to connect to the database...")
    # Use the hardcoded global values defined above for this test; in production these
    # would normally be read from environment variables at call time.
    db_host = DB_HOST
    db_name = DB_NAME
    db_user = DB_USER
    db_password = DB_PASSWORD
    db_port = DB_PORT

    if not all([db_host, db_name, db_user, db_password]):
        print("Error: Database credentials (DB_HOST, DB_NAME, DB_USER, DB_PASSWORD) are not fully set.")
        return None

    # Remove an http(s):// prefix from the host if present; psycopg2 expects a bare hostname.
    # (Note: for Supabase, the direct Postgres host is typically db.<project-ref>.supabase.co,
    # not the project's API URL.)
    if db_host.startswith("https://"):
        db_host = db_host.replace("https://", "", 1)
    elif db_host.startswith("http://"):
        db_host = db_host.replace("http://", "", 1)

    try:
        conn = psycopg2.connect(
            host=db_host,
            database=db_name,
            user=db_user,
            password=db_password,
            port=db_port,
        )
        print("Database connection successful.")
        return conn
    except Exception as e:
        print(f"Error connecting to the database: {e}")
        print(traceback.format_exc())
        return None

def setup_db_schema(conn):
    """Sets up the necessary tables and the pgvector extension."""
    print("Setting up database schema...")
    try:
        with conn.cursor() as cur:
            # Enable pgvector extension
            cur.execute("CREATE EXTENSION IF NOT EXISTS vector;")
            print("pgvector extension enabled (if not already).")

            # Create business_data table
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {BUSINESS_DATA_TABLE} (
                    id SERIAL PRIMARY KEY,
                    service TEXT NOT NULL,
                    description TEXT NOT NULL,
                    embedding vector({EMBEDDING_DIM})  -- EMBEDDING_DIM is defined globally
                );
            """)
            print(f"Table '{BUSINESS_DATA_TABLE}' created (if not already).")

            # Create conversation_history table
            cur.execute(f"""
                CREATE TABLE IF NOT EXISTS {CONVERSATION_HISTORY_TABLE} (
                    id SERIAL PRIMARY KEY,
                    timestamp TIMESTAMP WITH TIME ZONE NOT NULL,
                    user_id TEXT,
                    user_query TEXT,
                    model_response TEXT,
                    tool_details JSONB,
                    model_used TEXT
                );
            """)
            print(f"Table '{CONVERSATION_HISTORY_TABLE}' created (if not already).")

        conn.commit()
        print("Database schema setup complete.")
        return True
    except Exception as e:
        print(f"Error setting up database schema: {e}")
        print(traceback.format_exc())
        conn.rollback()
        return False
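
# Optional sketch (assumption, not called by the migration flow below): once business_data is
# populated, an approximate-nearest-neighbour index can speed up similarity search. This uses
# pgvector's ivfflat index with cosine distance; the `lists` value is an arbitrary starting point.
def create_vector_index(conn):
    """Create an ivfflat index on the embedding column (safe to re-run)."""
    with conn.cursor() as cur:
        cur.execute(f"""
            CREATE INDEX IF NOT EXISTS business_data_embedding_idx
            ON {BUSINESS_DATA_TABLE}
            USING ivfflat (embedding vector_cosine_ops)
            WITH (lists = 100);
        """)
    conn.commit()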

# --- Google Sheets Authentication and Data Retrieval ---

def authenticate_google_sheets():
    """Authenticates with Google Sheets using base64-encoded service-account credentials."""
    print("Authenticating Google Account for Sheets access...")
    if not GOOGLE_BASE64_CREDENTIALS:
        print("Error: GOOGLE_BASE64_CREDENTIALS environment variable not set. Google Sheets access will fail.")
        return None
    try:
        credentials_json = base64.b64decode(GOOGLE_BASE64_CREDENTIALS).decode('utf-8')
        credentials = json.loads(credentials_json)
        # ServiceAccountCredentials.from_json_keyfile_dict accepts the decoded dictionary
        scope = ['https://spreadsheets.google.com/feeds', 'https://www.googleapis.com/auth/drive']
        creds = ServiceAccountCredentials.from_json_keyfile_dict(credentials, scope)
        gc = gspread.authorize(creds)
        print("Google Sheets authentication successful.")
        return gc
    except Exception as e:
        print(f"Google Sheets authentication failed: {e}")
        print(traceback.format_exc())
        print("Please ensure GOOGLE_BASE64_CREDENTIALS is correctly set and contains valid service account credentials.")
        return None

# --- Data Migration Function ---

def migrate_google_sheet_data_to_db(conn, gc_client, embedder_model):
    """Retrieves data from the Google Sheet, generates embeddings, and inserts them into the DB."""
    print("Migrating data from Google Sheet to database...")
    if gc_client is None or SHEET_ID is None:
        print("Skipping Google Sheet migration: Google Sheets client or Sheet ID not available.")
        return False
    if embedder_model is None:
        print("Skipping Google Sheet migration: Embedder not available.")
        return False
    if EMBEDDING_DIM is None:
        print("Skipping Google Sheet migration: EMBEDDING_DIM not defined.")
        return False
    try:
        # Check whether the business_data table is already populated
        with conn.cursor() as cur:
            cur.execute(f"SELECT COUNT(*) FROM {BUSINESS_DATA_TABLE};")
            count = cur.fetchone()[0]
            if count > 0:
                print(f"Table '{BUSINESS_DATA_TABLE}' already contains {count} records. Skipping migration.")
                return True  # Data is already there, so treat this as success

        sheet = gc_client.open_by_key(SHEET_ID).sheet1
        print(f"Successfully opened Google Sheet with ID: {SHEET_ID}")
        data_records = sheet.get_all_records()
        if not data_records:
            print("No data records found in Google Sheet.")
            return False

        filtered_data = [row for row in data_records if row.get('Service') and row.get('Description')]
        if not filtered_data:
            print("Filtered data is empty after checking for 'Service' and 'Description'.")
            return False
        print(f"Processing {len(filtered_data)} records for migration.")

        descriptions_for_embedding = [
            f"Service: {row['Service'].strip()}. Description: {row['Description'].strip()}"
            for row in filtered_data
        ]

        # Generate embeddings in batches to keep memory usage bounded for large datasets
        batch_size = 64
        num_batches = (len(descriptions_for_embedding) + batch_size - 1) // batch_size
        embeddings_list = []
        for i in range(0, len(descriptions_for_embedding), batch_size):
            batch_descriptions = descriptions_for_embedding[i:i + batch_size]
            print(f"Encoding batch {i // batch_size + 1} of {num_batches}...")
            batch_embeddings = embedder_model.encode(batch_descriptions, convert_to_tensor=False)
            embeddings_list.extend(batch_embeddings.tolist())  # Convert numpy array to list

        insert_count = 0
        with conn.cursor() as cur:
            for i, row in enumerate(filtered_data):
                service = row.get('Service', '').strip()
                description = row.get('Description', '').strip()
                # pgvector accepts the '[x1, x2, ...]' literal format, so pass the embedding as a string
                embedding_literal = str(embeddings_list[i])
                # execute_values could speed up bulk inserts, but a simple execute is fine here
                cur.execute(f"""
                    INSERT INTO {BUSINESS_DATA_TABLE} (service, description, embedding)
                    VALUES (%s, %s, %s::vector);
                """, (service, description, embedding_literal))
                insert_count += 1
                if insert_count % 100 == 0:
                    conn.commit()  # Commit periodically
                    print(f"Inserted {insert_count} records...")
        conn.commit()  # Commit remaining records
        print(f"Migration complete. Inserted {insert_count} records into '{BUSINESS_DATA_TABLE}'.")
        return True
    except Exception as e:
        print(f"Error during Google Sheet data migration: {e}")
        print(traceback.format_exc())
        conn.rollback()
        return False
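
# Sketch (assumption, for illustration only): how the migrated rows can later be retrieved by
# semantic similarity in the RAG step. `<->` is pgvector's distance operator; top_k is arbitrary.
def find_similar_services(conn, embedder_model, query_text, top_k=3):
    """Return the top_k (service, description) rows closest to the query embedding."""
    query_embedding = embedder_model.encode(query_text, convert_to_tensor=False).tolist()
    with conn.cursor() as cur:
        cur.execute(f"""
            SELECT service, description
            FROM {BUSINESS_DATA_TABLE}
            ORDER BY embedding <-> %s::vector
            LIMIT %s;
        """, (str(query_embedding), top_k))
        return cur.fetchall()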

# --- Main Migration Execution ---

if __name__ == "__main__":
    print("Starting RAG data migration script...")

    # 1. Authenticate Google Sheets
    gc = authenticate_google_sheets()
    if gc is None:
        print("Google Sheets authentication failed. Cannot migrate data from Sheets.")
        # Exit or handle the error if Sheets auth fails
        exit()

    # 2. Initialize Embedder Model
    try:
        print(f"Loading Sentence Transformer model for embeddings (dimension: {EMBEDDING_DIM})...")
        # Make sure to use the correct model and check its dimension
        embedder = SentenceTransformer("sentence-transformers/paraphrase-MiniLM-L6-v2")
        # Verify that the model's dimension matches EMBEDDING_DIM
        if embedder.get_sentence_embedding_dimension() != EMBEDDING_DIM:
            print(f"Error: Loaded embedder dimension ({embedder.get_sentence_embedding_dimension()}) "
                  f"does not match expected EMBEDDING_DIM ({EMBEDDING_DIM}).")
            print("Please check the model or update EMBEDDING_DIM.")
            embedder = None  # Prevent migration with the wrong dimension
        else:
            print("Embedder model loaded successfully.")
    except Exception as e:
        print(f"Error loading Sentence Transformer model: {e}")
        print(traceback.format_exc())
        embedder = None  # Model loading failed

    if embedder is None:
        print("Embedder model not available. Cannot generate embeddings for migration.")
        # Exit or handle the error if embedder fails to load
        exit()

    # 3. Connect to Database
    db_conn = connect_db()
    if db_conn is None:
        print("Database connection failed. Cannot migrate data.")
        # Exit or handle the error if DB connection fails
        exit()

    try:
        # 4. Setup Database Schema (if not already done)
        if setup_db_schema(db_conn):
            # 5. Migrate Data
            if migrate_google_sheet_data_to_db(db_conn, gc, embedder):
                print("\nRAG Data Migration to PostgreSQL completed successfully.")
            else:
                print("\nRAG Data Migration to PostgreSQL failed.")
        else:
            print("\nDatabase schema setup failed. Data migration skipped.")
    finally:
        # 6. Close Database Connection
        if db_conn:
            db_conn.close()
            print("Database connection closed.")

    print("\nMigration script finished.")