Spaces:
Running
Running
import gptTokenizer from "gpt-tokenizer"; | |
import type { ChatMessage } from "gpt-tokenizer/GptEncoding"; | |
import prettyMilliseconds from "pretty-ms"; | |
import { addLogEntry } from "./logEntries"; | |
import { | |
getQuery, | |
getSettings, | |
getTextGenerationState, | |
listenToSettingsChanges, | |
updateImageSearchResults, | |
updateImageSearchState, | |
updateLlmTextSearchResults, | |
updateResponse, | |
updateSearchPromise, | |
updateTextGenerationState, | |
updateTextSearchResults, | |
updateTextSearchState, | |
} from "./pubSub"; | |
import { searchImages, searchText } from "./search"; | |
import { getSystemPrompt } from "./systemPrompt"; | |
import { | |
ChatGenerationError, | |
defaultContextSize, | |
getFormattedSearchResults, | |
} from "./textGenerationUtilities"; | |
import type { ImageSearchResults, TextSearchResults } from "./types"; | |
import { isWebGPUAvailable } from "./webGpu"; | |
/**
 * Kicks off a search for the current query and, when AI responses are
 * enabled, generates a text response with the configured inference backend.
 *
 * Returns early without doing anything when the query is empty. Otherwise the
 * document title is set to the query, the previous response and search
 * results are cleared, and the promise of the new text search is handed to
 * `updateSearchPromise` (the search itself is not awaited here).
 *
 * Errors raised while generating text are never rethrown: they are logged and
 * the text generation state becomes "failed", unless the state is already
 * "interrupted", in which case the error is ignored. The elapsed generation
 * time is logged after both successful and failed attempts.
 */
export async function searchAndRespond() {
  const query = getQuery();
  if (query === "") return;

  document.title = query;

  // Reset everything the previous query may have left behind.
  updateResponse("");
  updateTextSearchResults([]);
  updateImageSearchResults([]);

  updateSearchPromise(startTextSearch(query));

  const settings = getSettings();
  if (!settings.enableAiResponse) return;

  const responseGenerationStartTime = Date.now();

  try {
    // Each backend lives in its own module and is loaded only when selected.
    switch (settings.inferenceType) {
      case "openai": {
        const { generateTextWithOpenAi } = await import(
          "./textGenerationWithOpenAi"
        );
        await generateTextWithOpenAi();
        break;
      }
      case "internal": {
        const { generateTextWithInternalApi } = await import(
          "./textGenerationWithInternalApi"
        );
        await generateTextWithInternalApi();
        break;
      }
      case "horde": {
        const { generateTextWithHorde } = await import(
          "./textGenerationWithHorde"
        );
        await generateTextWithHorde();
        break;
      }
      default: {
        // Any other inference type first waits for the model download
        // allowance, then picks WebLLM or Wllama.
        await canDownloadModels();
        updateTextGenerationState("loadingModel");

        if (isWebGPUAvailable && settings.enableWebGpu) {
          const { generateTextWithWebLlm } = await import(
            "./textGenerationWithWebLlm"
          );
          await generateTextWithWebLlm();
        } else {
          const { generateTextWithWllama } = await import(
            "./textGenerationWithWllama"
          );
          await generateTextWithWllama();
        }
      }
    }

    updateTextGenerationState("completed");
  } catch (error) {
    // An "interrupted" state is left as-is and its error is not reported.
    const wasInterrupted = getTextGenerationState() === "interrupted";
    if (!wasInterrupted) {
      addLogEntry(`Error generating text: ${error}`);
      updateTextGenerationState("failed");
    }
  }

  const elapsedMilliseconds = Date.now() - responseGenerationStartTime;
  addLogEntry(
    `Response generation took ${prettyMilliseconds(elapsedMilliseconds, {
      verbose: true,
    })}`,
  );
}
/**
 * Generates a chat reply for the given conversation using the configured
 * inference backend.
 *
 * The conversation sent to the backend always starts with the system prompt
 * (sent with the "user" role) followed by a fixed assistant reply "Ok!".
 * After those two, as many of the most recent `newMessages` as fit in the
 * token budget are appended in their original order. The budget is 85% of
 * `defaultContextSize` minus the tokens of the two leading messages, with
 * token counts measured by `gptTokenizer`.
 *
 * @param newMessages Conversation messages, oldest first.
 * @param onUpdate Callback passed through to the backend; judging by its
 *   parameter name it receives partial responses (confirm in the backend
 *   modules).
 * @returns The response returned by the selected backend.
 * @throws Rethrows any error raised during generation after logging it.
 */
export async function generateChatResponse(
  newMessages: ChatMessage[],
  onUpdate: (partialResponse: string) => void,
) {
  const settings = getSettings();
  const countTokens = (text: string) => gptTokenizer.encode(text).length;

  let response = "";

  try {
    const systemPrompt: ChatMessage = {
      role: "user",
      content: getSystemPrompt(getFormattedSearchResults(true)),
    };
    const initialResponse: ChatMessage = { role: "assistant", content: "Ok!" };

    const reservedTokens =
      countTokens(systemPrompt.content) + countTokens(initialResponse.content);
    const availableTokenBudget = defaultContextSize * 0.85 - reservedTokens;

    // Walk from the newest message backwards and stop at the first message
    // that would exceed the budget; everything older is dropped as well.
    const keptNewestFirst: ChatMessage[] = [];
    let usedTokens = 0;
    for (const candidate of [...newMessages].reverse()) {
      const candidateTokens = countTokens(candidate.content);
      if (usedTokens + candidateTokens > availableTokenBudget) break;
      keptNewestFirst.push(candidate);
      usedTokens += candidateTokens;
    }
    const processedMessages = keptNewestFirst.reverse();

    // The kept history must begin with a "user" message; a single leading
    // message with any other role is discarded.
    const [oldestKept] = processedMessages;
    if (oldestKept !== undefined && oldestKept.role !== "user") {
      processedMessages.shift();
    }

    const lastMessages = [systemPrompt, initialResponse, ...processedMessages];

    switch (settings.inferenceType) {
      case "openai": {
        const { generateChatWithOpenAi } = await import(
          "./textGenerationWithOpenAi"
        );
        response = await generateChatWithOpenAi(lastMessages, onUpdate);
        break;
      }
      case "internal": {
        const { generateChatWithInternalApi } = await import(
          "./textGenerationWithInternalApi"
        );
        response = await generateChatWithInternalApi(lastMessages, onUpdate);
        break;
      }
      case "horde": {
        const { generateChatWithHorde } = await import(
          "./textGenerationWithHorde"
        );
        response = await generateChatWithHorde(lastMessages, onUpdate);
        break;
      }
      default: {
        if (isWebGPUAvailable && settings.enableWebGpu) {
          const { generateChatWithWebLlm } = await import(
            "./textGenerationWithWebLlm"
          );
          response = await generateChatWithWebLlm(lastMessages, onUpdate);
        } else {
          const { generateChatWithWllama } = await import(
            "./textGenerationWithWllama"
          );
          response = await generateChatWithWllama(lastMessages, onUpdate);
        }
      }
    }
  } catch (error) {
    // Both kinds of failure are logged (with different wording) and rethrown.
    const logMessage =
      error instanceof ChatGenerationError
        ? `Chat generation interrupted: ${error.message}`
        : `Error generating chat response: ${error}`;
    addLogEntry(logMessage);
    throw error;
  }

  return response;
}
/**
 * Extracts keywords from `text` with the lazily imported `keyword-extractor`
 * package, using its "english" language setting.
 *
 * @param text Text to extract keywords from.
 * @param limit Maximum number of keywords to return; when omitted, every
 *   extracted keyword is returned.
 * @returns The first `limit` extracted keywords.
 */
async function getKeywords(text: string, limit?: number) {
  const { default: keywordExtractor } = await import("keyword-extractor");
  const keywords = keywordExtractor.extract(text, { language: "english" });
  return keywords.slice(0, limit);
}
/**
 * Runs the text search for `query` and then starts the image search, each
 * only when enabled in the settings.
 *
 * Queries longer than 2000 characters are replaced by their first 20
 * keywords. If the text search finds nothing, it is retried once with the
 * first 10 keywords of the original query. Text results and the text search
 * state are published as soon as the text search finishes.
 *
 * @param query The query to search for.
 * @returns An object holding the text results. Its `imageResults` array is
 *   still empty when this function returns; `startImageSearch` assigns it
 *   later on the same object.
 */
async function startTextSearch(query: string) {
  const results: {
    textResults: TextSearchResults;
    imageResults: ImageSearchResults;
  } = { textResults: [], imageResults: [] };

  let searchQuery = query;
  if (query.length > 2000) {
    searchQuery = (await getKeywords(query, 20)).join(" ");
  }

  // The image search is flagged as running up front, although it is only
  // started after the text search below has finished.
  if (getSettings().enableImageSearch) {
    updateImageSearchState("running");
  }

  if (getSettings().enableTextSearch) {
    updateTextSearchState("running");

    let foundTextResults = await searchText(
      searchQuery,
      getSettings().searchResultsLimit,
    );

    const foundNothing = foundTextResults.length === 0;
    if (foundNothing) {
      // Second attempt: search for the query's keywords only.
      const fallbackKeywords = await getKeywords(query, 10);
      foundTextResults = await searchText(
        fallbackKeywords.join(" "),
        getSettings().searchResultsLimit,
      );
    }

    results.textResults = foundTextResults;

    if (results.textResults.length === 0) {
      updateTextSearchState("failed");
    } else {
      updateTextSearchState("completed");
    }

    updateTextSearchResults(foundTextResults);
    updateLlmTextSearchResults(
      foundTextResults.slice(0, getSettings().searchResultsToConsider),
    );
  }

  if (getSettings().enableImageSearch) {
    // Deliberately not awaited: the results object is returned right away.
    void startImageSearch(searchQuery, results);
  }

  return results;
}
/**
 * Runs the image search for `searchQuery` and publishes its outcome.
 *
 * On success the found images are stored on `results.imageResults`, the image
 * search state becomes "completed" (or "failed" when nothing was found), and
 * the images are published through `updateImageSearchResults`.
 *
 * A rejection from `searchImages` is handled here instead of propagating:
 * `startTextSearch` calls this function without awaiting it, so a rejected
 * promise would go unhandled and the image search state would stay at
 * "running". The error is logged and the state is set to "failed".
 *
 * @param searchQuery The query passed to the image search.
 * @param results Shared results object; its `imageResults` property is
 *   replaced when the search succeeds and left untouched when it fails.
 */
async function startImageSearch(
  searchQuery: string,
  results: { textResults: TextSearchResults; imageResults: ImageSearchResults },
) {
  let imageResults: ImageSearchResults;

  // Only the search call is guarded, so errors thrown by the publishing
  // functions below are not mistaken for a failed search.
  try {
    imageResults = await searchImages(
      searchQuery,
      getSettings().searchResultsLimit,
    );
  } catch (error) {
    addLogEntry(`Error searching images: ${error}`);
    updateImageSearchState("failed");
    return;
  }

  results.imageResults = imageResults;
  updateImageSearchState(
    results.imageResults.length === 0 ? "failed" : "completed",
  );
  updateImageSearchResults(imageResults);
}
/**
 * Returns a promise that resolves once the `allowAiModelDownload` setting is
 * enabled.
 *
 * When the setting is already enabled, the promise resolves right away.
 * Otherwise the text generation state is set to
 * "awaitingModelDownloadAllowance" and the promise resolves on the first
 * settings change that has the setting enabled. The promise never resolves if
 * that change never comes.
 *
 * NOTE(review): the settings listener registered here is never removed. If
 * `listenToSettingsChanges` returns an unsubscribe handle, consider calling it
 * once the promise has resolved (to be confirmed against `./pubSub`).
 */
function canDownloadModels(): Promise<void> {
  return new Promise<void>((allowDownload) => {
    if (getSettings().allowAiModelDownload) {
      allowDownload();
      return;
    }

    updateTextGenerationState("awaitingModelDownloadAllowance");

    listenToSettingsChanges(({ allowAiModelDownload }) => {
      if (allowAiModelDownload) allowDownload();
    });
  });
}