File size: 1,414 Bytes
3e19edc
f960c36
3e19edc
 
f960c36
3e19edc
 
ac4a3a2
 
 
 
 
3e19edc
ac4a3a2
 
 
 
3e19edc
ac4a3a2
 
 
 
 
 
 
 
 
3e19edc
 
ac4a3a2
 
 
 
3e19edc
 
 
ac4a3a2
3e19edc
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
### hf_client.py

from huggingface_hub import InferenceClient, HfApi
from tavily import TavilyClient
import os

# HF Inference Client

# Supported billing targets.
# A provider name outside this set is billed to "groq" by get_inference_client.
_VALID_BILL_TO = {"groq", "fairworksai", "huggingface"}

# Hub access token, read once at import time. A missing or empty value makes
# importing this module fail immediately with a RuntimeError.
HF_TOKEN = os.environ.get("HF_TOKEN")
if HF_TOKEN in (None, ""):
    raise RuntimeError(
        "HF_TOKEN environment variable is not set. "
        "Please set it to your Hugging Face API token."
    )

def get_inference_client(model_id: str, provider: str = "auto") -> InferenceClient:
    """Build an ``InferenceClient`` for *model_id*.

    Args:
        model_id: Hub model identifier. ``"moonshotai/Kimi-K2-Instruct"`` is
            always routed to the ``"groq"`` provider, overriding *provider*.
        provider: Inference provider name handed to ``InferenceClient``.
            Defaults to ``"auto"``.

    Returns:
        An ``InferenceClient`` authenticated with ``HF_TOKEN``. Its
        ``bill_to`` is the (possibly overridden) provider name when that name
        is in ``_VALID_BILL_TO``; otherwise, including for the default
        ``"auto"``, billing falls back to ``"groq"``.
    """
    # Kimi-K2 is pinned to Groq regardless of what the caller asked for.
    if model_id != "moonshotai/Kimi-K2-Instruct":
        effective_provider = provider
    else:
        effective_provider = "groq"

    # NOTE(review): huggingface_hub documents ``bill_to`` as an organization
    # the token's user belongs to; confirm these names are valid billing orgs.
    billing_target = (
        effective_provider if effective_provider in _VALID_BILL_TO else "groq"
    )

    return InferenceClient(
        provider=effective_provider,
        api_key=HF_TOKEN,
        bill_to=billing_target,
    )

# Tavily Search Client (optional).
# tavily_client stays None when TAVILY_API_KEY is unset or empty, or when
# constructing the client raises. Callers should check for None before use.
tavily_client = None
TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
if TAVILY_API_KEY:
    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as exc:
        # Best effort: report the failure and carry on without search.
        print(f"Failed to initialize Tavily client: {exc}")