Spaces:
Running
Running
File size: 5,611 Bytes
15a5288 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import { ProviderEntry, ProviderFetcher } from './types';
import { getStaticPricing } from './static-pricing';
import { NovitaFetcher } from './novita';
import { SambaNovaFetcher } from './sambanova';
import { GroqFetcher } from './groq';
import { FeatherlessFetcher } from './featherless';
import { TogetherFetcher } from './together';
import { CohereFetcher } from './cohere';
import { FireworksFetcher } from './fireworks';
import { NebiusFetcher } from './nebius';
import { HyperbolicFetcher } from './hyperbolic';
import { CerebrasFetcher } from './cerebras';
import { NScaleFetcher } from './nscale';
/**
 * Options accepted by `ProviderAggregator`. Every field is optional.
 */
export interface AggregatorConfig {
  /** Specific providers to fetch from. */
  providers?: string[];
  /** API keys, keyed by provider name. */
  apiKeys?: Record<string, string>;
  /** Number of concurrent fetches. */
  concurrent?: number;
  /** Toggles the static-pricing enrichment step. */
  includeStaticPricing?: boolean;
}
/**
 * Collects model listings from the registered provider fetchers and, when
 * enabled, fills in missing pricing from the static pricing table.
 */
export class ProviderAggregator {
  /** Batch size used when `concurrent` is not configured (or is falsy). */
  private static readonly DEFAULT_CONCURRENCY = 3;

  private fetchers: Map<string, ProviderFetcher>;
  private config: AggregatorConfig;

  /**
   * @param config Optional settings. `concurrent` defaults to 3 and
   *   `includeStaticPricing` defaults to `true`.
   */
  constructor(config: AggregatorConfig = {}) {
    const merged: AggregatorConfig = { ...config };
    // Object spread copies properties whose value is `undefined`, so defaults
    // are applied after the merge; otherwise `{ includeStaticPricing: undefined }`
    // would silently turn static pricing off.
    if (merged.concurrent === undefined) {
      merged.concurrent = ProviderAggregator.DEFAULT_CONCURRENCY;
    }
    if (merged.includeStaticPricing === undefined) {
      merged.includeStaticPricing = true;
    }
    this.config = merged;
    this.fetchers = new Map();
    this.initializeFetchers();
  }

  /** Registers one fetcher per supported provider, passing its API key if configured. */
  private initializeFetchers() {
    const apiKeys = this.config.apiKeys || {};
    // Initialize all available fetchers
    this.fetchers.set('novita', new NovitaFetcher(apiKeys.novita));
    this.fetchers.set('sambanova', new SambaNovaFetcher(apiKeys.sambanova));
    this.fetchers.set('groq', new GroqFetcher(apiKeys.groq));
    this.fetchers.set('featherless', new FeatherlessFetcher(apiKeys.featherless));
    this.fetchers.set('together', new TogetherFetcher(apiKeys.together));
    this.fetchers.set('cohere', new CohereFetcher(apiKeys.cohere));
    this.fetchers.set('fireworks', new FireworksFetcher(apiKeys.fireworks));
    this.fetchers.set('nebius', new NebiusFetcher(apiKeys.nebius));
    this.fetchers.set('hyperbolic', new HyperbolicFetcher(apiKeys.hyperbolic));
    this.fetchers.set('cerebras', new CerebrasFetcher(apiKeys.cerebras));
    this.fetchers.set('nscale', new NScaleFetcher(apiKeys.nscale));
  }

  /**
   * Fetches models from every configured provider (or from all registered
   * fetchers when `config.providers` is not set).
   *
   * Providers are processed in batches of `config.concurrent`; each batch runs
   * in parallel and the next batch starts only after the previous one settles.
   * A provider that is unknown or whose fetch fails is logged and mapped to an
   * empty array, so this method does not reject because of a single provider.
   *
   * @returns Map from provider name to its (possibly enriched) entries.
   */
  async fetchAllProviders(): Promise<Map<string, ProviderEntry[]>> {
    const results = new Map<string, ProviderEntry[]>();
    const providers = this.config.providers || Array.from(this.fetchers.keys());
    // Fetch in batches to respect rate limits
    const batches = this.createBatches(
      providers,
      this.config.concurrent || ProviderAggregator.DEFAULT_CONCURRENCY
    );
    for (const batch of batches) {
      const batchResults = await Promise.all(
        batch.map((provider) => this.fetchProviderSafely(provider))
      );
      for (const { provider, entries } of batchResults) {
        results.set(provider, entries);
      }
    }
    return results;
  }

  /**
   * Fetches a single provider without ever rejecting: an unknown provider or a
   * failed fetch is logged and reported as an empty entry list.
   */
  private async fetchProviderSafely(
    provider: string
  ): Promise<{ provider: string; entries: ProviderEntry[] }> {
    const fetcher = this.fetchers.get(provider);
    if (!fetcher) {
      console.warn(`No fetcher found for provider: ${provider}`);
      return { provider, entries: [] };
    }
    try {
      console.log(`Fetching models from ${provider}...`);
      const entries = await fetcher.fetchModels();
      // Enrich with static pricing if needed
      return { provider, entries: this.enrichWithStaticPricing(provider, entries) };
    } catch (error) {
      console.error(`Failed to fetch from ${provider}:`, error);
      return { provider, entries: [] };
    }
  }

  /**
   * Fetches and enriches the entries of one provider.
   *
   * @throws Error when no fetcher is registered under `provider`. Errors from
   *   the fetcher itself are not caught here and propagate to the caller.
   */
  async fetchProvider(provider: string): Promise<ProviderEntry[]> {
    const fetcher = this.fetchers.get(provider);
    if (!fetcher) {
      throw new Error(`No fetcher found for provider: ${provider}`);
    }
    const entries = await fetcher.fetchModels();
    return this.enrichWithStaticPricing(provider, entries);
  }

  /**
   * Returns `entries` with static pricing attached to those lacking a
   * `pricing` value. When `includeStaticPricing` is off the input array is
   * returned as-is; otherwise a new array is built and entries that need no
   * change are reused by reference.
   */
  private enrichWithStaticPricing(provider: string, entries: ProviderEntry[]): ProviderEntry[] {
    if (!this.config.includeStaticPricing) {
      return entries;
    }
    return entries.map((entry) => {
      // Only add static pricing if the entry doesn't already have pricing
      if (entry.pricing) {
        return entry;
      }
      const staticPrice = getStaticPricing(provider, this.extractModelId(entry));
      return staticPrice ? { ...entry, pricing: staticPrice } : entry;
    });
  }

  /**
   * Picks the model identifier used for the static pricing lookup: the first
   * truthy value of `id`, then `model_id`, else the literal `'unknown'`.
   */
  private extractModelId(entry: ProviderEntry): string {
    // Extract model ID from various possible fields
    // This is a simplified version - in production you'd need provider-specific logic
    const fields = entry as any;
    return fields.id || fields.model_id || 'unknown';
  }

  /**
   * Splits `items` into consecutive chunks of at most `batchSize` elements.
   *
   * `batchSize` is floored to a whole number and anything below 1 (including
   * `NaN`) is treated as 1: a non-positive step would never advance the loop
   * and a fractional step below 1 would emit empty or duplicated chunks.
   * `Infinity` yields a single chunk containing every item.
   */
  private createBatches<T>(items: T[], batchSize: number): T[][] {
    const size = batchSize >= 1 ? Math.floor(batchSize) : 1;
    const batches: T[][] = [];
    for (let i = 0; i < items.length; i += size) {
      batches.push(items.slice(i, i + size));
    }
    return batches;
  }

  // Aggregate all provider data into a single array
  async aggregateAll(): Promise<ProviderEntry[]> {
    const providerMap = await this.fetchAllProviders();
    const allEntries: ProviderEntry[] = [];
    for (const entries of providerMap.values()) {
      // Append element-wise: `push(...entries)` passes every element as a call
      // argument and can exceed the engine's argument limit on huge lists.
      for (const entry of entries) {
        allEntries.push(entry);
      }
    }
    return allEntries;
  }

  // Get a summary of available models per provider
  async getSummary(): Promise<{ [provider: string]: number }> {
    const providerMap = await this.fetchAllProviders();
    const summary: { [provider: string]: number } = {};
    for (const [provider, entries] of providerMap) {
      summary[provider] = entries.length;
    }
    return summary;
  }
}