# auto_causal/config.py
"""Central configuration for AutoCausal, including LLM client setup."""

import os
import logging
from typing import Optional

# Langchain imports
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI  # Default provider
from langchain_anthropic import ChatAnthropic
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain_deepseek import ChatDeepSeek
from langchain_together import ChatTogether
# Add other providers if needed, e.g.:
# from langchain_community.chat_models import ChatOllama

from dotenv import load_dotenv

logger = logging.getLogger(__name__)

# Load .env file when this module is loaded
load_dotenv()


def get_llm_client(provider: Optional[str] = None,
                   model_name: Optional[str] = None,
                   **kwargs) -> BaseChatModel:
    """Initializes and returns the chosen LLM client based on provider.

    Reads the provider, model, and API keys from environment variables when
    they are not passed directly. Defaults to OpenAI gpt-4o-mini if no
    provider or model is specified.
    """
    # Prioritize arguments, then environment variables, then defaults
    provider = (provider or os.getenv("LLM_PROVIDER", "openai")).lower()

    # Default model depends on provider
    default_models = {
        "openai": "gpt-4o-mini",
        "anthropic": "claude-3-5-sonnet-latest",
        "together": "deepseek-ai/DeepSeek-V3",  # Default Together model
        "gemini": "gemini-2.5-flash",
        "deepseek": "deepseek-chat",
    }
    model_name = model_name or os.getenv(
        "LLM_MODEL", default_models.get(provider, default_models["openai"])
    )

    # OpenAI reasoning models (o3/o4 family) do not accept a temperature override
    if model_name not in ("o3-mini", "o3", "o4-mini"):
        kwargs.setdefault("temperature", 0)  # Default temperature if not provided

    logger.info(f"Initializing LLM client: Provider='{provider}', Model='{model_name}'")

    try:
        if provider == "openai":
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise ValueError("OPENAI_API_KEY not found in environment.")
            return ChatOpenAI(model=model_name, api_key=api_key, **kwargs)
        elif provider == "anthropic":
            api_key = os.getenv("ANTHROPIC_API_KEY")
            if not api_key:
                raise ValueError("ANTHROPIC_API_KEY not found in environment.")
            return ChatAnthropic(model=model_name, api_key=api_key, streaming=False, **kwargs)
        elif provider == "together":
            api_key = os.getenv("TOGETHER_API_KEY")
            if not api_key:
                raise ValueError("TOGETHER_API_KEY not found in environment.")
            return ChatTogether(model=model_name, api_key=api_key, **kwargs)
        elif provider == "gemini":
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise ValueError("GEMINI_API_KEY not found in environment.")
            # Pass the key explicitly: ChatGoogleGenerativeAI otherwise looks
            # for GOOGLE_API_KEY, not GEMINI_API_KEY.
            return ChatGoogleGenerativeAI(model=model_name, google_api_key=api_key, **kwargs)
        elif provider == "deepseek":
            api_key = os.getenv("DEEPSEEK_API_KEY")
            if not api_key:
                raise ValueError("DEEPSEEK_API_KEY not found in environment.")
            return ChatDeepSeek(model=model_name, api_key=api_key, **kwargs)
        # Example for Ollama (ensure langchain_community is installed)
        # elif provider == "ollama":
        #     try:
        #         from langchain_community.chat_models import ChatOllama
        #         return ChatOllama(model=model_name, **kwargs)
        #     except ImportError:
        #         raise ValueError(
        #             "langchain_community needed for Ollama. Run `pip install langchain-community`"
        #         )
        else:
            raise ValueError(f"Unsupported LLM provider: {provider}")
    except Exception as e:
        logger.error(f"Failed to initialize LLM (Provider: {provider}, Model: {model_name}): {e}")
        raise  # Re-raise the exception
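

# --- Usage sketch (illustrative, not part of the module's API surface) ---
# A minimal smoke test for get_llm_client, assuming the relevant API key
# (e.g. OPENAI_API_KEY) is available in the environment or a local .env
# file. The prompt text and the explicit gemini override below are
# hypothetical examples, not values this module requires.
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Default path: resolves provider/model from LLM_PROVIDER/LLM_MODEL,
    # falling back to the OpenAI gpt-4o-mini default defined above.
    llm = get_llm_client()
    print(llm.invoke("Reply with the single word: ok").content)

    # Explicit provider/model selection; extra kwargs such as temperature
    # are forwarded to the underlying Langchain client.
    gemini = get_llm_client(provider="gemini",
                            model_name="gemini-2.5-flash",
                            temperature=0.2)
    print(gemini.invoke("Reply with the single word: ok").content)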