/** BUSINESS
 *
 * All utils that are bound to business logic
 * (and wouldn't be useful in another project)
 * should be here.
 *
 **/

import ctxLengthData from "$lib/data/context_length.json";
import { pricing } from "$lib/state/pricing.svelte.js";
import { snippets } from "@huggingface/inference";
import { ConversationClass, type ConversationEntityMembers } from "$lib/state/conversations.svelte";
import { token } from "$lib/state/token.svelte";
import { isMcpEnabled } from "$lib/constants.js";
import {
  isCustomModel,
  isHFModel,
  Provider,
  type Conversation,
  type ConversationMessage,
  type CustomModel,
  type Model,
} from "$lib/types.js";
import { safeParse } from "$lib/utils/json.js";
import { omit } from "$lib/utils/object.svelte.js";
import type { ChatCompletionInputMessage, InferenceSnippet } from "@huggingface/tasks";
import { type ChatCompletionOutputMessage } from "@huggingface/tasks";
import { AutoTokenizer, PreTrainedTokenizer } from "@huggingface/transformers";
import { images } from "$lib/state/images.svelte.js";
import { projects } from "$lib/state/projects.svelte.js";
import { mcpServers } from "$lib/state/mcps.svelte.js";
import { modifySnippet } from "$lib/utils/snippets.js";
import { models } from "$lib/state/models.svelte";
import { StreamReader } from "$lib/utils/stream.js";

type ChatCompletionInputMessageChunk =
  NonNullable<ChatCompletionInputMessage["content"]> extends string | (infer U)[] ? U : never;
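
/**
 * Converts a stored ConversationMessage into a ChatCompletionInputMessage,
 * resolving image keys into image_url content chunks when the message has images.
 */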
async function parseMessage(message: ConversationMessage): Promise<ChatCompletionInputMessage> {
  if (!message.images) return message;
  const urls = await Promise.all(message.images?.map(k => images.get(k)) ?? []);
  return {
    ...omit(message, "images"),
    content: [
      {
        type: "text",
        text: message.content ?? "",
      },
      ...message.images.map((_imgKey, i) => {
        return {
          type: "image_url",
          image_url: { url: urls[i] as string },
        } satisfies ChatCompletionInputMessageChunk;
      }),
    ],
  };
}
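
/**
 * Returns the maximum number of tokens allowed for a conversation's model/provider pair,
 * preferring router pricing data, then bundled context-length data, then customMaxTokens.
 */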
export function maxAllowedTokens(conversation: ConversationClass) {
  const model = conversation.model;
  const { provider } = conversation.data;
  if (!provider || !isHFModel(model)) {
    return customMaxTokens[conversation.model.id] ?? 100000;
  }

  // Try to get context length from pricing/router data first
  const ctxLength = pricing.getContextLength(model.id, provider);
  if (ctxLength) return ctxLength;

  // Fall back to local context length data if available
  const providerData = ctxLengthData[provider as keyof typeof ctxLengthData] as Record<string, number> | undefined;
  const localCtxLength = providerData?.[model.id];
  if (localCtxLength) return localCtxLength;

  // Final fallback to custom max tokens
  return customMaxTokens[conversation.model.id] ?? 100000;
}
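
/** Lists the enabled MCP servers to send along with a generation request (empty when MCP is disabled). */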
function getEnabledMCPs() {
  if (!isMcpEnabled()) return [];
  return mcpServers.enabled.map(server => ({
    id: server.id,
    name: server.name,
    url: server.url,
    protocol: server.protocol,
    headers: server.headers,
  }));
}
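
/**
 * Builds the provider-specific `response_format` object for structured output,
 * or returns undefined when structured output is disabled or unsupported.
 */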
function getResponseFormatObj(conversation: ConversationClass | Conversation) {
  const data = conversation instanceof ConversationClass ? conversation.data : conversation;
  const json = safeParse(data.structuredOutput?.schema ?? "");

  if (json && data.structuredOutput?.enabled && models.supportsStructuredOutput(conversation.model, data.provider)) {
    switch (data.provider) {
      case "cohere": {
        return {
          type: "json_object",
          ...json,
        };
      }
      case Provider.Cerebras: {
        return {
          type: "json_schema",
          json_schema: { ...json, name: "schema" },
        };
      }
      default: {
        return {
          type: "json_schema",
          json_schema: json,
        };
      }
    }
  }
}
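
/**
 * Sends a streaming request to /api/generate and calls `onChunk` with the
 * accumulated response text every time a new chunk arrives.
 */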
export async function handleStreamingResponse(
  conversation: ConversationClass | Conversation,
  onChunk: (content: string) => void,
  abortController: AbortController,
): Promise<void> {
  const data = conversation instanceof ConversationClass ? conversation.data : conversation;
  const model = conversation.model;
  const systemMessage = projects.current?.systemMessage;

  const messages: ConversationMessage[] = [
    ...(isSystemPromptSupported(model) && systemMessage?.length ? [{ role: "system", content: systemMessage }] : []),
    ...(data.messages || []),
  ];
  const parsed = await Promise.all(messages.map(parseMessage));

  const requestBody = {
    model: {
      id: model.id,
      isCustom: isCustomModel(model),
      accessToken: isCustomModel(model) ? model.accessToken : undefined,
      endpointUrl: isCustomModel(model) ? model.endpointUrl : undefined,
    },
    messages: parsed,
    config: data.config,
    provider: data.provider,
    streaming: true,
    response_format: getResponseFormatObj(conversation),
    accessToken: token.value,
    enabledMCPs: getEnabledMCPs(),
  };

  const reader = await StreamReader.fromFetch("/api/generate", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify(requestBody),
    signal: abortController.signal,
  });

  let out = "";
  for await (const chunk of reader.read()) {
    if (chunk.type === "chunk" && chunk.content) {
      out += chunk.content;
      onChunk(out);
    } else if (chunk.type === "error") {
      throw new Error(chunk.error || "Stream error");
    }
  }
}
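
/**
 * Sends a non-streaming request to /api/generate and returns the generated
 * message together with its completion token count.
 */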
export async function handleNonStreamingResponse(
  conversation: ConversationClass | Conversation,
): Promise<{ message: ChatCompletionOutputMessage; completion_tokens: number }> {
  const data = conversation instanceof ConversationClass ? conversation.data : conversation;
  const model = conversation.model;
  const systemMessage = projects.current?.systemMessage;

  const messages: ConversationMessage[] = [
    ...(isSystemPromptSupported(model) && systemMessage?.length ? [{ role: "system", content: systemMessage }] : []),
    ...(data.messages || []),
  ];
  const parsed = await Promise.all(messages.map(parseMessage));

  const requestBody = {
    model: {
      id: model.id,
      isCustom: isCustomModel(model),
      accessToken: isCustomModel(model) ? model.accessToken : undefined,
      endpointUrl: isCustomModel(model) ? model.endpointUrl : undefined,
    },
    messages: parsed,
    config: data.config,
    provider: data.provider,
    streaming: false,
    response_format: getResponseFormatObj(conversation),
    accessToken: token.value,
    enabledMCPs: getEnabledMCPs(),
  };

  const response = await fetch("/api/generate", {
    method: "POST",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify(requestBody),
  });

  if (!response.ok) {
    const error = await response.json();
    throw new Error(error.error || "Failed to generate response");
  }

  return await response.json();
}
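
/** A model supports a system prompt if it is a custom (OpenAI-compatible) model or its chat template mentions "system". */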
export function isSystemPromptSupported(model: Model | CustomModel) {
  if (isCustomModel(model)) return true; // OpenAI-compatible models support system messages
  const template = model?.config.tokenizer_config?.chat_template;
  if (typeof template !== "string") return false;
  return template.includes("system");
}
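
/** Default system messages injected for specific models. */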
export const defaultSystemMessage: { [key: string]: string } = {
  "Qwen/QwQ-32B-Preview":
    "You are a helpful and harmless assistant. You are Qwen developed by Alibaba. You should think step-by-step.",
} as const;
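
/** Hard-coded max-token fallbacks for models whose context length is not available elsewhere. */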
export const customMaxTokens: { [key: string]: number } = {
  "01-ai/Yi-1.5-34B-Chat": 2048,
  "HuggingFaceM4/idefics-9b-instruct": 2048,
  "deepseek-ai/DeepSeek-Coder-V2-Instruct": 16384,
  "bigcode/starcoder": 8192,
  "bigcode/starcoderplus": 8192,
  "HuggingFaceH4/starcoderbase-finetuned-oasst1": 8192,
  "google/gemma-7b": 8192,
  "google/gemma-1.1-7b-it": 8192,
  "google/gemma-2b": 8192,
  "google/gemma-1.1-2b-it": 8192,
  "google/gemma-2-27b-it": 8192,
  "google/gemma-2-9b-it": 4096,
  "google/gemma-2-2b-it": 8192,
  "tiiuae/falcon-7b": 8192,
  "tiiuae/falcon-7b-instruct": 8192,
  "timdettmers/guanaco-33b-merged": 2048,
  "mistralai/Mixtral-8x7B-Instruct-v0.1": 32768,
  "Qwen/Qwen2.5-72B-Instruct": 32768,
  "Qwen/Qwen2.5-Coder-32B-Instruct": 32768,
  "meta-llama/Meta-Llama-3-70B-Instruct": 8192,
  "CohereForAI/c4ai-command-r-plus-08-2024": 32768,
  "NousResearch/Nous-Hermes-2-Mixtral-8x7B-DPO": 32768,
  "meta-llama/Llama-2-70b-chat-hf": 8192,
  "HuggingFaceH4/zephyr-7b-alpha": 17432,
  "HuggingFaceH4/zephyr-7b-beta": 32768,
  "mistralai/Mistral-7B-Instruct-v0.1": 32768,
  "mistralai/Mistral-7B-Instruct-v0.2": 32768,
  "mistralai/Mistral-7B-Instruct-v0.3": 32768,
  "mistralai/Mistral-Nemo-Instruct-2407": 32768,
  "meta-llama/Meta-Llama-3-8B-Instruct": 8192,
  "mistralai/Mistral-7B-v0.1": 32768,
  "bigcode/starcoder2-3b": 16384,
  "bigcode/starcoder2-15b": 16384,
  "HuggingFaceH4/starchat2-15b-v0.1": 16384,
  "codellama/CodeLlama-7b-hf": 8192,
  "codellama/CodeLlama-13b-hf": 8192,
  "codellama/CodeLlama-34b-Instruct-hf": 8192,
  "meta-llama/Llama-2-7b-chat-hf": 8192,
  "meta-llama/Llama-2-13b-chat-hf": 8192,
  "OpenAssistant/oasst-sft-6-llama-30b": 2048,
  "TheBloke/vicuna-7B-v1.5-GPTQ": 2048,
  "HuggingFaceH4/starchat-beta": 8192,
  "bigcode/octocoder": 8192,
  "vwxyzjn/starcoderbase-triviaqa": 8192,
  "lvwerra/starcoderbase-gsm8k": 8192,
  "NousResearch/Hermes-3-Llama-3.1-8B": 16384,
  "microsoft/Phi-3.5-mini-instruct": 32768,
  "meta-llama/Llama-3.1-70B-Instruct": 32768,
  "meta-llama/Llama-3.1-8B-Instruct": 8192,
} as const;

// Order of the elements in InferenceModal.svelte is determined by this const
export const inferenceSnippetLanguages = ["python", "js", "sh"] as const;

export type InferenceSnippetLanguage = (typeof inferenceSnippetLanguages)[number];

export type GetInferenceSnippetReturn = InferenceSnippet[];
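
/**
 * Generates ready-to-run inference code snippets for the conversation's model and provider
 * in the requested language, injecting the response_format when structured output is enabled.
 */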
export function getInferenceSnippet(
  conversation: ConversationClass,
  language: InferenceSnippetLanguage,
  opts?: {
    accessToken?: string;
    messages?: ConversationEntityMembers["messages"];
    streaming?: ConversationEntityMembers["streaming"];
    max_tokens?: ConversationEntityMembers["config"]["max_tokens"];
    temperature?: ConversationEntityMembers["config"]["temperature"];
    top_p?: ConversationEntityMembers["config"]["top_p"];
    structured_output?: ConversationEntityMembers["structuredOutput"];
    billTo?: string;
  },
): GetInferenceSnippetReturn {
  const model = conversation.model;
  const data = conversation.data;
  const provider = (isCustomModel(model) ? "hf-inference" : data.provider) as Provider;

  // If it's a custom model, we don't generate inference snippets
  if (isCustomModel(model)) {
    return [];
  }

  const providerMapping = model.inferenceProviderMapping.find(p => p.provider === provider);
  if (!providerMapping && provider !== "auto") return [];

  const allSnippets = snippets.getInferenceSnippets(
    { ...model, inference: "" },
    provider,
    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    { ...providerMapping, hfModelId: model.id } as any,
    { ...opts, directRequest: false },
  );

  return allSnippets
    .filter(s => s.language === language)
    .map(s => {
      if (
        opts?.structured_output?.schema &&
        opts.structured_output.enabled &&
        models.supportsStructuredOutput(conversation.model, provider)
      ) {
        return {
          ...s,
          content: modifySnippet(s.content, {
            response_format: getResponseFormatObj(conversation),
          }),
        };
      }
      return s;
    });
}
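
/** Per-model tokenizer cache; a null entry marks a model whose tokenizer failed to load. */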
// eslint-disable-next-line svelte/prefer-svelte-reactivity
const tokenizers = new Map<string, PreTrainedTokenizer | null>();
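
/** Loads (and caches) the tokenizer for a model, returning null if it cannot be loaded. */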
export async function getTokenizer(model: Model) {
  if (tokenizers.has(model.id)) return tokenizers.get(model.id)!;
  try {
    const tokenizer = await AutoTokenizer.from_pretrained(model.id);
    tokenizers.set(model.id, tokenizer);
    return tokenizer;
  } catch {
    tokenizers.set(model.id, null);
    return null;
  }
}

// When you don't have access to a tokenizer, guesstimate
export function estimateTokens(conversation: ConversationClass) {
  if (!conversation.data.messages) return 0;
  const content = conversation.data.messages?.reduce((acc, curr) => {
    return acc + (curr?.content ?? "");
  }, "");
  return content.length / 4; // 1 token ~ 4 characters
}
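
/**
 * Counts tokens by rendering the conversation with a Llama-3-style chat template and
 * encoding it with the model's tokenizer; falls back to estimateTokens when unavailable.
 */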
export async function getTokens(conversation: ConversationClass): Promise<number> {
  const model = conversation.model;
  if (isCustomModel(model)) return estimateTokens(conversation);
  const tokenizer = await getTokenizer(model);
  if (tokenizer === null) return estimateTokens(conversation);

  // This is a simplified version - you might need to adjust based on your exact needs
  let formattedText = "";
  conversation.data.messages?.forEach((message, index) => {
    let content = `<|start_header_id|>${message.role}<|end_header_id|>\n\n${message.content?.trim()}<|eot_id|>`;
    // Add BOS token to the first message
    if (index === 0) {
      content = "<|begin_of_text|>" + content;
    }
    formattedText += content;
  });

  // Encode the text to get tokens
  const encodedInput = tokenizer.encode(formattedText);
  // Return the number of tokens
  return encodedInput.length;
}