// auto-image-generator / extract-weird-reasoning-middleware.ts
// T1ckbase
// first commit
// 90989cc
import { LanguageModelV1StreamPart } from '@ai-sdk/provider';
import { LanguageModelV1Middleware } from 'ai';
/**
 * Finds where `searchedText` starts (or may start) in `text`.
 *
 * Returns the index of the first full occurrence of `searchedText` in `text`.
 * If there is none, returns the start of the longest suffix of `text` that is a
 * prefix of `searchedText`. This is a potential match that may be completed by
 * text arriving later (used to hold back a tag split across stream chunks).
 * Returns null when neither exists or when `searchedText` is empty.
 *
 * @param text - The text to search in.
 * @param searchedText - The text to look for.
 * @returns The (potential) start index, or null.
 */
export function getPotentialStartIndex(
  text: string,
  searchedText: string,
): number | null {
  if (searchedText.length === 0) {
    return null;
  }

  const directIndex = text.indexOf(searchedText);
  if (directIndex !== -1) {
    return directIndex;
  }

  // A partial match is a suffix of `text` that is a *proper* prefix of
  // `searchedText`, so it is at most `searchedText.length - 1` characters long.
  // Scan from the longest candidate to the shortest so the earliest start wins.
  // Otherwise, for self-overlapping prefixes (e.g. "<<<"), the shorter
  // suffix would be reported and real tag characters would be published early.
  const firstCandidate = Math.max(0, text.length - searchedText.length + 1);
  for (let i = firstCandidate; i < text.length; i++) {
    if (searchedText.startsWith(text.slice(i))) {
      return i;
    }
  }

  return null;
}
/**
 * Escapes RegExp metacharacters so `value` is matched literally when it is
 * embedded in a `new RegExp(...)` pattern.
 */
function escapeRegExp(value: string): string {
  return value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&');
}

/**
 * Extracts an XML-tagged reasoning section from the generated text and exposes
 * it as reasoning:
 * - for `generate`, as a `reasoning` property on the result;
 * - for `stream`, as `reasoning` stream parts.
 *
 * Besides the usual `<tag>…</tag>` form, it supports models that emit only the
 * closing tag, i.e. `…reasoning…</tag>answer`.
 *
 * NOTE: with `onlyClosingTag`, a generated text that contains no closing tag is
 * returned unchanged as text. A stream that contains no closing tag is emitted
 * entirely as reasoning, because a stream cannot know in advance whether the
 * tag will arrive.
 *
 * @param tagName - The name of the XML tag to extract reasoning from. It is
 *   matched literally; RegExp metacharacters are escaped.
 * @param separator - Inserted between reasoning sections, and between the text
 *   fragments that a removed reasoning section used to separate.
 * @param onlyClosingTag - When true, no opening tag is expected: output starts
 *   as reasoning and switches to text after a closing tag (defaults to false).
 * @returns A middleware that wraps both `doGenerate` and `doStream`.
 */
export function extractWeirdReasoningMiddleware({
  tagName,
  separator = '\n',
  onlyClosingTag = false,
}: {
  tagName: string;
  separator?: string;
  onlyClosingTag?: boolean;
}): LanguageModelV1Middleware {
  const openingTag = `<${tagName}>`;
  const closingTag = `</${tagName}>`;
  // Built once per middleware. `matchAll` iterates over an internal copy of the
  // regex, so sharing this global regex between calls is safe.
  const taggedSection = new RegExp(
    `${escapeRegExp(openingTag)}(.*?)${escapeRegExp(closingTag)}`,
    'gs',
  );

  return {
    middlewareVersion: 'v1',

    wrapGenerate: async ({ doGenerate }) => {
      const { text, ...rest } = await doGenerate();
      if (text == null) {
        return { text, ...rest };
      }

      if (onlyClosingTag) {
        const parts = text.split(closingTag);
        if (parts.length <= 1) {
          return { text, ...rest };
        }
        // Everything before the last closing tag is considered reasoning.
        const answer = parts.pop() ?? '';
        return {
          ...rest,
          text: answer.trim(),
          reasoning: parts.join(separator).trim(),
        };
      }

      const matches = Array.from(text.matchAll(taggedSection));
      if (matches.length === 0) {
        return { text, ...rest };
      }

      const reasoning = matches.map((match) => match[1]).join(separator);

      // Remove the sections back-to-front so earlier match indices stay valid.
      let textWithoutReasoning = text;
      for (const match of matches.slice().reverse()) {
        const start = match.index ?? 0;
        const before = textWithoutReasoning.slice(0, start);
        const after = textWithoutReasoning.slice(start + match[0].length);
        // Only join with the separator when text remains on both sides.
        const joiner = before.length > 0 && after.length > 0 ? separator : '';
        textWithoutReasoning = before + joiner + after;
      }

      return {
        ...rest,
        text: textWithoutReasoning,
        reasoning,
      };
    },

    wrapStream: async ({ doStream }) => {
      const { stream, ...rest } = await doStream();

      // Structural type for the transform controller; only `enqueue` is needed.
      type PartSink = { enqueue(part: LanguageModelV1StreamPart): void };

      let isFirstReasoning = true;
      let isFirstText = true;
      // True right after a tag was consumed; the next published fragment may
      // need a leading separator.
      let afterSwitch = false;
      // Closing-tag-only output starts directly inside the reasoning section.
      let isReasoning = onlyClosingTag;
      // Unpublished text; holds back a possible tag split across chunks.
      let buffer = '';

      /** Emits `text` as a reasoning or text part, depending on the current mode. */
      const publish = (sink: PartSink, text: string): void => {
        if (text.length === 0) {
          return;
        }
        const needsSeparator =
          afterSwitch && (isReasoning ? !isFirstReasoning : !isFirstText);
        sink.enqueue({
          type: isReasoning ? 'reasoning' : 'text-delta',
          textDelta: (needsSeparator ? separator : '') + text,
        });
        afterSwitch = false;
        if (isReasoning) {
          isFirstReasoning = false;
        } else {
          isFirstText = false;
        }
      };

      /**
       * Publishes as much of `buffer` as possible, switching between reasoning
       * and text on every complete tag. Keeps only a trailing partial tag.
       */
      const drainBuffer = (sink: PartSink): void => {
        for (;;) {
          // In closing-tag-only mode the closing tag is still searched for
          // (and removed) after the switch to text.
          const nextTag =
            onlyClosingTag || isReasoning ? closingTag : openingTag;
          const startIndex = getPotentialStartIndex(buffer, nextTag);

          if (startIndex == null) {
            publish(sink, buffer);
            buffer = '';
            return;
          }

          publish(sink, buffer.slice(0, startIndex));

          const tagEnd = startIndex + nextTag.length;
          if (tagEnd > buffer.length) {
            // Only a prefix of the tag has arrived; wait for more text.
            buffer = buffer.slice(startIndex);
            return;
          }

          buffer = buffer.slice(tagEnd);
          isReasoning = onlyClosingTag ? false : !isReasoning;
          afterSwitch = true;
        }
      };

      /** Emits whatever is still held back (e.g. a trailing partial tag). */
      const flushBuffer = (sink: PartSink): void => {
        publish(sink, buffer);
        buffer = '';
      };

      return {
        stream: stream.pipeThrough(
          new TransformStream<
            LanguageModelV1StreamPart,
            LanguageModelV1StreamPart
          >({
            transform: (chunk, controller) => {
              if (chunk.type === 'text-delta') {
                buffer += chunk.textDelta;
                drainBuffer(controller);
                return;
              }
              if (chunk.type === 'finish') {
                // Held-back text must be emitted before completion is reported.
                flushBuffer(controller);
              }
              controller.enqueue(chunk);
            },
            // Without this, a partial tag still buffered when the source ends
            // would be silently dropped.
            flush: (controller) => {
              flushBuffer(controller);
            },
          }),
        ),
        ...rest,
      };
    },
  };
}