File size: 3,110 Bytes
15a5288
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import { BaseProviderFetcher } from './base';
import type { ProviderEntry, CohereModel } from './types';

/**
 * Provider fetcher for Cohere: lists models from the Cohere models API and
 * converts each one into a ProviderEntry.
 */
export class CohereFetcher extends BaseProviderFetcher {
  name = 'cohere';

  /** Page size requested from the list-models endpoint. */
  private static readonly PAGE_SIZE = 1000;

  /** Hard cap on pages followed, so a misbehaving `next_page_token` can never loop forever. */
  private static readonly MAX_PAGES = 50;

  /**
   * Cohere feature name -> ProviderEntry capability flags that feature turns on.
   * A Map (rather than an object literal) keeps lookups of names such as
   * "constructor" from resolving to Object.prototype members.
   */
  private static readonly FEATURE_MAP: ReadonlyMap<string, ReadonlyArray<keyof ProviderEntry>> = new Map<
    string,
    ReadonlyArray<keyof ProviderEntry>
  >([
    ['tools', ['supports_tools']],
    ['strict_tools', ['supports_function_calling']],
    ['json_mode', ['supports_structured_output']],
    ['json_schema', ['supports_structured_output', 'supports_response_format']],
    ['logprobs', ['supports_logprobs']],
  ]);

  constructor(apiKey?: string) {
    super('https://api.cohere.ai', apiKey, {
      requestsPerMinute: 60  // Conservative default
    });
  }

  /**
   * Fetches every Cohere model (following pagination) and returns entries for
   * those exposing a `chat` or `generate` endpoint.
   *
   * @returns Mapped entries, or an empty array (after logging) if any request fails.
   */
  async fetchModels(): Promise<ProviderEntry[]> {
    try {
      const models = await this.fetchAllModels();

      // Only text-generation models are kept; models without chat/generate endpoints are skipped.
      const chatModels = models.filter(model => CohereFetcher.isTextGeneration(model.endpoints));

      return chatModels.map(model => this.mapModelToProviderEntry(model));
    } catch (error) {
      console.error(`Failed to fetch Cohere models: ${error}`);
      return [];
    }
  }

  /**
   * Fetches a single model by name and maps it to a ProviderEntry.
   *
   * @param modelName - Cohere model identifier; URL-encoded before use.
   * @returns The mapped entry, or null (after logging) if the request fails.
   */
  async fetchModel(modelName: string): Promise<ProviderEntry | null> {
    try {
      const response = await this.fetchWithRetry<CohereModel>(
        `${this.baseUrl}/v1/models/${encodeURIComponent(modelName)}`
      );

      return this.mapModelToProviderEntry(response);
    } catch (error) {
      console.error(`Failed to fetch Cohere model ${modelName}: ${error}`);
      return null;
    }
  }

  /**
   * Collects all pages from the list-models endpoint.
   *
   * The endpoint is paginated, so `next_page_token` is followed until it is
   * absent; otherwise models beyond the first page would be silently dropped.
   * A repeated token or MAX_PAGES ends the loop defensively.
   */
  private async fetchAllModels(): Promise<CohereModel[]> {
    const models: CohereModel[] = [];
    const seenTokens = new Set<string>();
    let pageToken: string | undefined;

    for (let page = 0; page < CohereFetcher.MAX_PAGES; page++) {
      const query = pageToken
        ? `page_size=${CohereFetcher.PAGE_SIZE}&page_token=${encodeURIComponent(pageToken)}`
        : `page_size=${CohereFetcher.PAGE_SIZE}`;

      const response = await this.fetchWithRetry<{ models?: CohereModel[]; next_page_token?: string }>(
        `${this.baseUrl}/v1/models?${query}`
      );
      models.push(...(response.models ?? []));

      const nextToken = response.next_page_token;
      if (!nextToken || seenTokens.has(nextToken)) {
        break;
      }
      seenTokens.add(nextToken);
      pageToken = nextToken;
    }

    return models;
  }

  /**
   * Converts a Cohere model record into a ProviderEntry.
   * Missing `features`/`endpoints` arrays are tolerated so one incomplete
   * record cannot abort the whole listing.
   */
  private mapModelToProviderEntry(model: CohereModel): ProviderEntry {
    const entry: ProviderEntry = {
      provider: this.name,
      context_length: model.context_length,
      status: model.is_deprecated ? 'deprecated' : 'live',
      supports_image_input: model.supports_vision
    };

    // Capability flags derived from features first, then model_type from endpoints.
    Object.assign(entry, this.mapFeatures(model.features), this.mapEndpoints(model.endpoints));

    // Copy so the entry does not alias (and cannot mutate) the API response array.
    if (Array.isArray(model.features)) {
      entry.supported_parameters = [...model.features];
    }

    return entry;
  }

  /**
   * Translates Cohere feature names into `true` capability flags using FEATURE_MAP.
   * Unknown features are ignored; a missing list yields an empty object.
   */
  private mapFeatures(features: readonly string[] | null | undefined): Partial<ProviderEntry> {
    const result: Partial<ProviderEntry> = {};

    for (const feature of features ?? []) {
      for (const key of CohereFetcher.FEATURE_MAP.get(feature) ?? []) {
        (result[key] as any) = true;
      }
    }

    return result;
  }

  /**
   * Sets `model_type` to 'chat' when the model exposes a chat or generate endpoint.
   */
  private mapEndpoints(endpoints: readonly string[] | null | undefined): Partial<ProviderEntry> {
    const result: Partial<ProviderEntry> = {};

    if (CohereFetcher.isTextGeneration(endpoints)) {
      result.model_type = 'chat';
    }

    return result;
  }

  /** True when the endpoint list contains 'chat' or 'generate'; false for a missing list. */
  private static isTextGeneration(endpoints: readonly string[] | null | undefined): boolean {
    return !!endpoints && (endpoints.includes('chat') || endpoints.includes('generate'));
  }
}