# auto_causal/config.py
"""Central configuration for AutoCausal, including LLM client setup."""
import os
import logging
from typing import Optional
# Langchain imports
from langchain_core.language_models import BaseChatModel
from langchain_openai import ChatOpenAI # Default
from langchain_anthropic import ChatAnthropic # Example
from langchain_google_genai import ChatGoogleGenerativeAI
# Add other providers if needed, e.g.:
# from langchain_community.chat_models import ChatOllama
from dotenv import load_dotenv
from langchain_deepseek import ChatDeepSeek

# Import Together provider
from langchain_together import ChatTogether
logger = logging.getLogger(__name__)
# Load .env file when this module is loaded
load_dotenv()

def get_llm_client(provider: Optional[str] = None, model_name: Optional[str] = None, **kwargs) -> BaseChatModel:
    """Initializes and returns the chosen LLM client based on provider.

    Reads provider, model, and API keys from environment variables if not passed directly.
    Defaults to OpenAI gpt-4o-mini if no provider/model is specified.
    """
    # Prioritize arguments, then environment variables, then defaults
    provider = provider or os.getenv("LLM_PROVIDER", "openai")
    provider = provider.lower()

    # Default model depends on provider
    default_models = {
        "openai": "gpt-4o-mini",
        "anthropic": "claude-3-5-sonnet-latest",
        "together": "deepseek-ai/DeepSeek-V3",  # Default Together model
        "gemini": "gemini-2.5-flash",
        "deepseek": "deepseek-chat",
    }
    model_name = model_name or os.getenv("LLM_MODEL", default_models.get(provider, default_models["openai"]))

    api_key = None
    # OpenAI reasoning models (o3/o4 family) do not accept a custom temperature
    if model_name not in ['o3-mini', 'o3', 'o4-mini']:
        kwargs.setdefault("temperature", 0)  # Default temperature if not provided

    logger.info(f"Initializing LLM client: Provider='{provider}', Model='{model_name}'")

    try:
        if provider == "openai":
            api_key = os.getenv("OPENAI_API_KEY")
            if not api_key:
                raise ValueError("OPENAI_API_KEY not found in environment.")
            return ChatOpenAI(model=model_name, api_key=api_key, **kwargs)
        elif provider == "anthropic":
            api_key = os.getenv("ANTHROPIC_API_KEY")
            if not api_key:
                raise ValueError("ANTHROPIC_API_KEY not found in environment.")
            return ChatAnthropic(model=model_name, api_key=api_key, streaming=False, **kwargs)
        elif provider == "together":
            api_key = os.getenv("TOGETHER_API_KEY")
            if not api_key:
                raise ValueError("TOGETHER_API_KEY not found in environment.")
            return ChatTogether(model=model_name, api_key=api_key, **kwargs)
        elif provider == "gemini":
            api_key = os.getenv("GEMINI_API_KEY")
            if not api_key:
                raise ValueError("GEMINI_API_KEY not found in environment.")
            # Pass the key explicitly; the client otherwise looks for GOOGLE_API_KEY
            return ChatGoogleGenerativeAI(model=model_name, google_api_key=api_key, function_calling="auto", **kwargs)
        elif provider == "deepseek":
            api_key = os.getenv("DEEPSEEK_API_KEY")
            if not api_key:
                raise ValueError("DEEPSEEK_API_KEY not found in environment.")
            return ChatDeepSeek(model=model_name, api_key=api_key, **kwargs)
        # Example for Ollama (ensure langchain_community is installed)
        # elif provider == "ollama":
        #     try:
        #         from langchain_community.chat_models import ChatOllama
        #         return ChatOllama(model=model_name, **kwargs)
        #     except ImportError:
        #         raise ValueError("langchain_community needed for Ollama. Run `pip install langchain-community`")
        else:
            raise ValueError(f"Unsupported LLM provider: {provider}")
    except Exception as e:
        logger.error(f"Failed to initialize LLM (Provider: {provider}, Model: {model_name}): {e}")
        raise  # Re-raise the exception
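

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module). It shows how a
# caller might obtain a client; the prompt and the commented-out provider
# choices are hypothetical, and the matching API key (e.g. OPENAI_API_KEY) is
# assumed to be available in the environment or the .env file loaded above.
if __name__ == "__main__":
    # Default client: OpenAI gpt-4o-mini with temperature 0
    llm = get_llm_client()
    # Explicit provider/model, with an extra kwarg forwarded to the client:
    # llm = get_llm_client(provider="anthropic", model_name="claude-3-5-sonnet-latest", max_tokens=1024)
    print(llm.invoke("Reply with the single word: ready").content)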