import { useState, useCallback } from "react";
import { pipeline, TextStreamer } from "@huggingface/transformers";
import type { TextSplitterStream } from "kokoro-js";

interface LLMState {
  isLoading: boolean;
  isReady: boolean;
  error: string | null;
  progress: number;
}

// Cache the pipeline on globalThis so it is shared by every useLLM() instance
// and survives hot-module reloads during development.
type LLMGlobal = { generator: any | null };
const g = globalThis as any;
const __LLM: LLMGlobal = g.__LLM || { generator: null };
g.__LLM = __LLM;

export type generateFn = (
  messages: Array<{ role: string; content: string }>,
  onToken?: (token: string) => void,
  splitter?: TextSplitterStream,
) => Promise<void>;

export const useLLM = () => {
  const [state, setState] = useState<LLMState>({
    isLoading: false,
    isReady: !!__LLM.generator,
    error: null,
    progress: __LLM.generator ? 100 : 0,
  });

  // Download the model (first call only) and build the text-generation pipeline.
  // Subsequent calls return the cached generator immediately.
  const load = async () => {
    if (__LLM.generator) return __LLM.generator;
    setState((p) => ({ ...p, isLoading: true, error: null, progress: 0 }));
    try {
      const generator = await pipeline("text-generation", "onnx-community/gemma-3-270m-it-ONNX", {
        dtype: "fp32",
        device: "webgpu",
        progress_callback: (item) => {
          // Only track progress for the main ONNX weights file, which makes up
          // the bulk of the download.
          if (item.status === "progress" && item.file?.endsWith?.("onnx_data")) {
            setState((p) => ({ ...p, progress: item.progress || 0 }));
          }
        },
      });
      __LLM.generator = generator;
      setState((p) => ({
        ...p,
        isLoading: false,
        isReady: true,
        progress: 100,
      }));
      return generator;
    } catch (error) {
      setState((p) => ({
        ...p,
        isLoading: false,
        error: error instanceof Error ? error.message : "Failed to load model",
      }));
      throw error;
    }
  };

  // Run chat-style generation, streaming each decoded token to the optional
  // onToken callback and to the optional kokoro-js sentence splitter.
  const generate: generateFn = useCallback(async (messages, onToken, splitter) => {
    const generator = __LLM.generator;
    if (!generator) throw new Error("Model not loaded. Call load() first.");
    const streamer = new TextStreamer(generator.tokenizer, {
      skip_prompt: true,
      skip_special_tokens: true,
      callback_function: (token: string) => {
        onToken?.(token);
        splitter?.push(token);
      },
    });
    await generator(messages, {
      max_new_tokens: 1024,
      do_sample: false, // greedy decoding for deterministic output
      streamer,
    });
    // Closing the splitter flushes any buffered text as the final sentence.
    splitter?.close();
  }, []);

  return {
    ...state,
    load,
    generate,
  };
};
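
// Usage sketch (illustrative only; the surrounding component and the
// setOutput setter below are assumed, not part of this hook):
//
//   const { isReady, isLoading, progress, load, generate } = useLLM();
//   if (!isReady && !isLoading) await load();
//   await generate(
//     [{ role: "user", content: "Write a haiku about the sea." }],
//     (token) => setOutput((prev) => prev + token),
//   );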