File size: 2,192 Bytes
52c6f5c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
/* eslint-disable @typescript-eslint/no-explicit-any -- Sorry */
import { omit } from "$lib/utils/object.svelte.js";
import { InferenceClient } from "@huggingface/inference";
import type { ChatCompletionInputMessage } from "@huggingface/tasks";
import OpenAI from "openai";
import type { Stream } from "openai/streaming.mjs";
import type { GenerateRequest, OpenAIFunctionSchema } from "./types.js";
import type { ChatCompletionMessage } from "openai/resources/index.mjs";

/**
 * Arguments accepted by both {@link Adapter} methods.
 *
 * Mirrors the OpenAI chat-completion request shape closely enough that the
 * custom adapter can spread it (minus `provider`) straight into
 * `chat.completions.create`, while the HF adapter forwards it to
 * `chatCompletion`/`chatCompletionStream`.
 */
export type GenerationArgs = {
	// Model identifier as understood by the target backend.
	model: string;
	// Conversation history; accepts both HF-tasks and OpenAI message shapes.
	messages: Array<ChatCompletionInputMessage | ChatCompletionMessage>;
	// Inference provider hint — only meaningful for the HF adapter; the
	// custom adapter strips it before calling OpenAI.
	provider?: string;
	// Extra generation parameters (temperature, max tokens, …) — spread
	// into the request as-is; NOTE(review): exact accepted keys depend on
	// the backend, not validated here.
	config?: Record<string, unknown>;
	// Tool/function definitions in OpenAI function-schema form.
	tools?: OpenAIFunctionSchema[];
	// Structured-output request; typed `unknown` because the two backends
	// disagree on its shape (cast at the call sites).
	response_format?: unknown;
};

/**
 * Uniform generation interface over heterogeneous backends.
 *
 * Both implementations (OpenAI-compatible custom endpoints and Hugging Face
 * inference) return OpenAI-shaped results, so callers never branch on the
 * backend: `stream` yields chat-completion chunks, `generate` resolves to a
 * full completion.
 */
export interface Adapter {
	stream: (args: GenerationArgs) => Promise<Stream<OpenAI.Chat.Completions.ChatCompletionChunk>>;
	generate: (args: GenerationArgs) => Promise<OpenAI.Chat.Completions.ChatCompletion>;
}

function createCustomAdapter({ model }: GenerateRequest): Adapter {
	// Handle OpenAI-compatible custom models
	const openai = new OpenAI({
		apiKey: model.accessToken,
		baseURL: model.endpointUrl,
	});

	return {
		stream: async (args: GenerationArgs) => {
			return await openai.chat.completions.create({
				...omit(args, "provider"),
				stream: true,
			} as OpenAI.ChatCompletionCreateParamsStreaming);
		},
		generate: (args: GenerationArgs) => {
			return openai.chat.completions.create({
				...omit(args, "provider"),
				stream: false,
			} as OpenAI.ChatCompletionCreateParamsNonStreaming);
		},
	};
}

/**
 * Build an {@link Adapter} backed by the Hugging Face Inference client.
 *
 * The HF client's request/response types differ from the OpenAI shapes that
 * {@link Adapter} promises, so the fields whose types diverge (`provider`,
 * `response_format`, `tools`) are bridged with `any` casts; the casts are
 * compile-time only and do not alter the payload.
 */
function createHFAdapter({ accessToken }: GenerateRequest): Adapter {
	const hf = new InferenceClient(accessToken);

	return {
		stream: (args: GenerationArgs) => {
			const payload = {
				...args,
				provider: args.provider as any,
				response_format: args.response_format as any,
				tools: args.tools as any,
			};
			return hf.chatCompletionStream(payload as any) as any;
		},
		generate: (args: GenerationArgs) => hf.chatCompletion(args as any) as any,
	};
}

/**
 * Select the backend adapter for a generation request.
 *
 * Models flagged `isCustom` point at user-supplied OpenAI-compatible
 * endpoints; all others are served through Hugging Face inference.
 */
export function createAdapter(body: GenerateRequest): Adapter {
	return body.model.isCustom ? createCustomAdapter(body) : createHFAdapter(body);
}