File size: 2,634 Bytes
3cd4fb7
0bd4051
97c4991
 
1778c9e
af1f386
97c4991
 
812d95a
bf1aea7
 
73a8db9
af1f386
 
bf1aea7
 
0bd4051
97c4991
0bd4051
97c4991
 
0bd4051
97c4991
 
 
 
 
 
 
 
 
e657b46
 
 
 
 
 
 
0bd4051
 
97c4991
 
 
 
0bd4051
 
 
 
 
 
 
 
 
97c4991
 
 
0bd4051
97c4991
 
 
 
 
 
 
 
0bd4051
97c4991
 
 
 
0bd4051
1778c9e
97c4991
1778c9e
97c4991
 
3b86586
 
af1f386
3b86586
af1f386
3b86586
 
 
73a8db9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import { type CustomModel, type Model } from "$lib/types.js";
import { edit, randomPick } from "$lib/utils/array.js";
import { safeParse } from "$lib/utils/json.js";
import typia from "typia";
import { conversations } from "./conversations.svelte";
import { getModels, getRouterData, type RouterData } from "$lib/remote/models.remote";

/** localStorage key under which the user's custom models are stored as a JSON array. */
const LOCAL_STORAGE_KEY = "hf_inference_playground_custom_models";

/**
 * Comparator ordering models by descending `trendingScore`, so the most
 * trending model comes first. Suitable for `sort` / `toSorted`.
 */
function trendingSort(a: Model, b: Model): number {
	const scoreOfA = a.trendingScore;
	const scoreOfB = b.trendingScore;
	return scoreOfB - scoreOfA;
}

/**
 * Reactive store for the models available in the playground.
 *
 * Holds two sources of models:
 * - `remote`: models fetched by `load()` via `getModels()`, together with router
 *   metadata from `getRouterData()`.
 * - `custom`: user-defined models, persisted to localStorage under `LOCAL_STORAGE_KEY`.
 */
class Models {
	/** Router metadata (per-model provider info); `undefined` until `load()` resolves. */
	routerData = $state<RouterData>();
	/** Models fetched from the remote API; empty until `load()` resolves. */
	remote: Model[] = $state([]);
	/** The five remote models with the highest `trendingScore`. */
	trending = $derived(this.remote.toSorted(trendingSort).slice(0, 5));
	/** Every remote model not in `trending`, ordered by descending `trendingScore`. */
	nonTrending = $derived(this.remote.filter(m => !this.trending.includes(m)).toSorted(trendingSort));
	/** Remote models followed by custom models. */
	all = $derived([...this.remote, ...this.custom]);

	// Backing state for `custom`. Within this class the list is only ever replaced
	// wholesale (never mutated in place), hence `$state.raw`. Go through the `custom`
	// setter to also persist a change to localStorage.
	#custom = $state.raw<CustomModel[]>([]);

	/**
	 * Restores custom models from localStorage.
	 *
	 * If a stored value exists but does not validate as `CustomModel[]`, the stored
	 * entry is reset to an empty array. Storage failures during that reset are logged
	 * rather than thrown, so constructing the store never fails because of it.
	 */
	constructor() {
		const savedData = localStorage.getItem(LOCAL_STORAGE_KEY);
		if (!savedData) return;

		const res = typia.validate<CustomModel[]>(safeParse(savedData));
		if (res.success) {
			// Use typia's validated (and correctly typed) payload, not the raw parse result.
			this.#custom = res.data;
			return;
		}

		// The stored payload is corrupt or has an incompatible shape: overwrite it so
		// the same bad data is not re-read on the next load.
		try {
			localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify([]));
		} catch (e) {
			console.error("Failed to reset custom models in localStorage:", e);
		}
	}

	/**
	 * Fetches remote models and router metadata in parallel and stores both.
	 * If either request rejects, the returned promise rejects and neither field is updated.
	 */
	async load(): Promise<void> {
		const [remoteModels, data] = await Promise.all([getModels(), getRouterData()]);
		this.remote = remoteModels;
		this.routerData = data;
	}

	/** The user's custom models. */
	get custom() {
		return this.#custom;
	}

	/**
	 * Replaces the custom model list and persists it to localStorage.
	 * Persistence failures are logged; the in-memory value is still updated.
	 */
	set custom(models: CustomModel[]) {
		this.#custom = models;

		try {
			localStorage.setItem(LOCAL_STORAGE_KEY, JSON.stringify(models));
		} catch (e) {
			console.error("Failed to save session to localStorage:", e);
		}
	}

	/**
	 * Appends `model` to the custom list unless a custom model with the same `id`
	 * already exists.
	 * @returns the added model, or `null` if an entry with the same `id` was found.
	 */
	addCustom(model: CustomModel) {
		if (this.#custom.some(m => m.id === model.id)) return null;
		this.custom = [...this.#custom, model];
		return model;
	}

	/**
	 * Replaces the custom model whose `_id` matches `model._id`, or adds `model` if none does.
	 * The add path goes through `addCustom`, so it is a no-op when a different entry
	 * already uses the same `id`.
	 */
	upsertCustom(model: CustomModel) {
		const index = this.#custom.findIndex(m => m._id === model._id);
		if (index === -1) {
			this.addCustom(model);
		} else {
			this.custom = edit(this.#custom, index, model);
		}
	}

	/**
	 * Removes the custom model with the given `_id`. Any active conversation that was
	 * using it is switched to a randomly picked trending model (the id passed may be
	 * `undefined` if `randomPick` yields nothing, e.g. before remote models load).
	 */
	removeCustom(uuid: CustomModel["_id"]) {
		this.custom = this.#custom.filter(m => m._id !== uuid);
		conversations.active.forEach(c => {
			if (c.model._id !== uuid) return;
			c.update({ modelId: randomPick(this.trending)?.id });
		});
	}

	/**
	 * Whether `model` supports structured output with the given `provider`.
	 *
	 * - Returns `false` until router data has been loaded.
	 *   NOTE(review): this also applies to custom models; confirm that is intended.
	 * - Returns `true` for any value that validates as a `CustomModel`.
	 * - Otherwise it reads the provider's `supports_structured_output` flag from the
	 *   router data, defaulting to `false` when the model or provider is not listed.
	 */
	supportsStructuredOutput(model: Model | CustomModel, provider?: string) {
		if (!this.routerData) return false;
		if (typia.is<CustomModel>(model)) return true;
		const routerDataEntry = this.routerData.data.find(d => d.id === model.id);
		if (!routerDataEntry) return false;
		return routerDataEntry.providers.find(p => p.provider === provider)?.supports_structured_output ?? false;
	}
}

/**
 * Shared store instance used throughout the playground. Constructing it reads any
 * previously saved custom models from localStorage.
 */
export const models: Models = new Models();