coyotte508 (HF Staff) committed
Commit 21dd449 · verified · 1 Parent(s): e22900f

Add 1 files

Note: this view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. consts.ts +1 -0
  2. error.ts +49 -0
  3. index.ts +25 -0
  4. lib/cache-management.spec.ts +137 -0
  5. lib/cache-management.ts +265 -0
  6. lib/check-repo-access.spec.ts +34 -0
  7. lib/check-repo-access.ts +32 -0
  8. lib/commit.spec.ts +271 -0
  9. lib/commit.ts +609 -0
  10. lib/count-commits.spec.ts +16 -0
  11. lib/count-commits.ts +35 -0
  12. lib/create-branch.spec.ts +159 -0
  13. lib/create-branch.ts +54 -0
  14. lib/create-repo.spec.ts +103 -0
  15. lib/create-repo.ts +78 -0
  16. lib/dataset-info.spec.ts +56 -0
  17. lib/dataset-info.ts +61 -0
  18. lib/delete-branch.spec.ts +43 -0
  19. lib/delete-branch.ts +32 -0
  20. lib/delete-file.spec.ts +64 -0
  21. lib/delete-file.ts +35 -0
  22. lib/delete-files.spec.ts +81 -0
  23. lib/delete-files.ts +33 -0
  24. lib/delete-repo.ts +37 -0
  25. lib/download-file-to-cache-dir.spec.ts +306 -0
  26. lib/download-file-to-cache-dir.ts +138 -0
  27. lib/download-file.spec.ts +82 -0
  28. lib/download-file.ts +77 -0
  29. lib/file-download-info.spec.ts +59 -0
  30. lib/file-download-info.ts +151 -0
  31. lib/file-exists.spec.ts +30 -0
  32. lib/file-exists.ts +41 -0
  33. lib/index.ts +32 -0
  34. lib/list-commits.spec.ts +117 -0
  35. lib/list-commits.ts +70 -0
  36. lib/list-datasets.spec.ts +47 -0
  37. lib/list-datasets.ts +121 -0
  38. lib/list-files.spec.ts +173 -0
  39. lib/list-files.ts +94 -0
  40. lib/list-models.spec.ts +118 -0
  41. lib/list-models.ts +139 -0
  42. lib/list-spaces.spec.ts +40 -0
  43. lib/list-spaces.ts +111 -0
  44. lib/model-info.spec.ts +59 -0
  45. lib/model-info.ts +62 -0
  46. lib/oauth-handle-redirect.spec.ts +60 -0
  47. lib/oauth-handle-redirect.ts +334 -0
  48. lib/oauth-login-url.ts +166 -0
  49. lib/parse-safetensors-metadata.spec.ts +122 -0
  50. lib/parse-safetensors-metadata.ts +274 -0
consts.ts ADDED
@@ -0,0 +1 @@
export const HUB_URL = "https://huggingface.co";
error.ts ADDED
@@ -0,0 +1,49 @@
import type { JsonObject } from "./vendor/type-fest/basic";

export async function createApiError(
	response: Response,
	opts?: { requestId?: string; message?: string }
): Promise<never> {
	const error = new HubApiError(response.url, response.status, response.headers.get("X-Request-Id") ?? opts?.requestId);

	error.message = `Api error with status ${error.statusCode}${opts?.message ? `. ${opts.message}` : ""}`;

	const trailer = [`URL: ${error.url}`, error.requestId ? `Request ID: ${error.requestId}` : undefined]
		.filter(Boolean)
		.join(". ");

	if (response.headers.get("Content-Type")?.startsWith("application/json")) {
		const json = await response.json();
		error.message = json.error || json.message || error.message;
		if (json.error_description) {
			error.message = error.message ? error.message + `: ${json.error_description}` : json.error_description;
		}
		error.data = json;
	} else {
		error.data = { message: await response.text() };
	}

	error.message += `. ${trailer}`;

	throw error;
}

/**
 * Error thrown when an API call to the Hugging Face Hub fails.
 */
export class HubApiError extends Error {
	statusCode: number;
	url: string;
	requestId?: string;
	data?: JsonObject;

	constructor(url: string, statusCode: number, requestId?: string, message?: string) {
		super(message);

		this.statusCode = statusCode;
		this.requestId = requestId;
		this.url = url;
	}
}

export class InvalidApiResponseFormatError extends Error {}
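Below is a minimal consumption sketch (not part of the commit): it mirrors the `throw await createApiError(res)` pattern used throughout the other files in this diff. The endpoint URL and repo name are illustrative.

import { createApiError, HubApiError } from "./error";

async function getModelInfo(name: string): Promise<unknown> {
	const res = await fetch(`https://huggingface.co/api/models/${name}`);
	if (!res.ok) {
		// createApiError always throws: it builds a HubApiError carrying the
		// parsed error body, the URL and the X-Request-Id header.
		throw await createApiError(res);
	}
	return res.json();
}

try {
	await getModelInfo("openai-community/gpt2");
} catch (err) {
	if (err instanceof HubApiError) {
		console.error(`HTTP ${err.statusCode} on ${err.url} (request ${err.requestId ?? "n/a"})`);
	} else {
		throw err;
	}
}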
index.ts ADDED
@@ -0,0 +1,25 @@
export * from "./lib";
// Typescript 5 will add 'export type *'
export type {
	AccessToken,
	AccessTokenRole,
	AuthType,
	Credentials,
	PipelineType,
	RepoDesignation,
	RepoFullName,
	RepoId,
	RepoType,
	SpaceHardwareFlavor,
	SpaceResourceConfig,
	SpaceResourceRequirement,
	SpaceRuntime,
	SpaceSdk,
	SpaceStage,
} from "./types/public";
export { HubApiError, InvalidApiResponseFormatError } from "./error";
/**
 * Only exported for E2Es convenience
 */
export { sha256 as __internal_sha256 } from "./utils/sha256";
export { XetBlob as __internal_XetBlob } from "./utils/XetBlob";
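For context, consumers import from this entry point rather than from individual files; a hypothetical example, assuming this entry point is published under the package's usual name `@huggingface/hub`:

import { commit, createRepo, HubApiError } from "@huggingface/hub";
import type { RepoDesignation, RepoId } from "@huggingface/hub";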
lib/cache-management.spec.ts ADDED
@@ -0,0 +1,137 @@
import { describe, test, expect, vi, beforeEach } from "vitest";
import {
	scanCacheDir,
	scanCachedRepo,
	scanSnapshotDir,
	parseRepoType,
	getBlobStat,
	type CachedFileInfo,
} from "./cache-management";
import { stat, readdir, realpath, lstat } from "node:fs/promises";
import type { Dirent, Stats } from "node:fs";
import { join } from "node:path";

// Mocks
vi.mock("node:fs/promises");

beforeEach(() => {
	vi.resetAllMocks();
	vi.restoreAllMocks();
});

describe("scanCacheDir", () => {
	test("should throw an error if cacheDir is not a directory", async () => {
		vi.mocked(stat).mockResolvedValueOnce({
			isDirectory: () => false,
		} as Stats);

		await expect(scanCacheDir("/fake/dir")).rejects.toThrow("Scan cache expects a directory");
	});

	test("empty directory should return an empty set of repositories and no warnings", async () => {
		vi.mocked(stat).mockResolvedValueOnce({
			isDirectory: () => true,
		} as Stats);

		// mock empty cache folder
		vi.mocked(readdir).mockResolvedValue([]);

		const result = await scanCacheDir("/fake/dir");

		// cacheDir must have been read
		expect(readdir).toHaveBeenCalledWith("/fake/dir");

		expect(result.warnings.length).toBe(0);
		expect(result.repos).toHaveLength(0);
		expect(result.size).toBe(0);
	});
});

describe("scanCachedRepo", () => {
	test("should throw an error for invalid repo path", async () => {
		await expect(() => {
			return scanCachedRepo("/fake/repo_path");
		}).rejects.toThrow("Repo path is not a valid HuggingFace cache directory");
	});

	test("should throw an error if the snapshot folder does not exist", async () => {
		vi.mocked(readdir).mockResolvedValue([]);
		vi.mocked(stat).mockResolvedValue({
			isDirectory: () => false,
		} as Stats);

		await expect(() => {
			return scanCachedRepo("/fake/cacheDir/models--hello-world--name");
		}).rejects.toThrow("Snapshots dir doesn't exist in cached repo");
	});

	test("should properly parse the repository name", async () => {
		const repoPath = "/fake/cacheDir/models--hello-world--name";
		vi.mocked(readdir).mockResolvedValue([]);
		vi.mocked(stat).mockResolvedValue({
			isDirectory: () => true,
		} as Stats);

		const result = await scanCachedRepo(repoPath);
		expect(readdir).toHaveBeenCalledWith(join(repoPath, "refs"), {
			withFileTypes: true,
		});

		expect(result.id.name).toBe("hello-world/name");
		expect(result.id.type).toBe("model");
	});
});

describe("scanSnapshotDir", () => {
	test("should scan a valid snapshot directory", async () => {
		const cachedFiles: CachedFileInfo[] = [];
		const blobStats = new Map<string, Stats>();
		vi.mocked(readdir).mockResolvedValueOnce([{ name: "file1", isDirectory: () => false } as Dirent]);

		vi.mocked(realpath).mockResolvedValueOnce("/fake/realpath");
		vi.mocked(lstat).mockResolvedValueOnce({ size: 1024, atimeMs: Date.now(), mtimeMs: Date.now() } as Stats);

		await scanSnapshotDir("/fake/revision", cachedFiles, blobStats);

		expect(cachedFiles).toHaveLength(1);
		expect(blobStats.size).toBe(1);
	});
});

describe("getBlobStat", () => {
	test("should retrieve blob stat if already cached", async () => {
		const blobStats = new Map<string, Stats>([["/fake/blob", { size: 1024 } as Stats]]);
		const result = await getBlobStat("/fake/blob", blobStats);

		expect(lstat).not.toHaveBeenCalled();
		expect(result.size).toBe(1024);
	});

	test("should fetch and cache blob stat if not cached", async () => {
		const blobStats = new Map();
		vi.mocked(lstat).mockResolvedValueOnce({ size: 2048 } as Stats);

		const result = await getBlobStat("/fake/blob", blobStats);

		expect(result.size).toBe(2048);
		expect(blobStats.size).toBe(1);
	});
});

describe("parseRepoType", () => {
	test("should parse models repo type", () => {
		expect(parseRepoType("models")).toBe("model");
	});

	test("should parse dataset repo type", () => {
		expect(parseRepoType("datasets")).toBe("dataset");
	});

	test("should parse space repo type", () => {
		expect(parseRepoType("spaces")).toBe("space");
	});

	test("should throw an error for invalid repo type", () => {
		expect(() => parseRepoType("invalid")).toThrowError("Invalid repo type: invalid");
	});
});
lib/cache-management.ts ADDED
@@ -0,0 +1,265 @@
import { homedir } from "node:os";
import { join, basename } from "node:path";
import { stat, readdir, readFile, realpath, lstat } from "node:fs/promises";
import type { Stats } from "node:fs";
import type { RepoType, RepoId } from "../types/public";

function getDefaultHome(): string {
	return join(homedir(), ".cache");
}

function getDefaultCachePath(): string {
	return join(process.env["HF_HOME"] ?? join(process.env["XDG_CACHE_HOME"] ?? getDefaultHome(), "huggingface"), "hub");
}

function getHuggingFaceHubCache(): string {
	return process.env["HUGGINGFACE_HUB_CACHE"] ?? getDefaultCachePath();
}

export function getHFHubCachePath(): string {
	return process.env["HF_HUB_CACHE"] ?? getHuggingFaceHubCache();
}

const FILES_TO_IGNORE: string[] = [".DS_Store"];

export const REPO_ID_SEPARATOR: string = "--";

export function getRepoFolderName({ name, type }: RepoId): string {
	const parts = [`${type}s`, ...name.split("/")];
	return parts.join(REPO_ID_SEPARATOR);
}

export interface CachedFileInfo {
	path: string;
	/**
	 * Underlying file - which `path` is symlinked to
	 */
	blob: {
		size: number;
		path: string;
		lastModifiedAt: Date;
		lastAccessedAt: Date;
	};
}

export interface CachedRevisionInfo {
	commitOid: string;
	path: string;
	size: number;
	files: CachedFileInfo[];
	refs: string[];

	lastModifiedAt: Date;
}

export interface CachedRepoInfo {
	id: RepoId;
	path: string;
	size: number;
	filesCount: number;
	revisions: CachedRevisionInfo[];

	lastAccessedAt: Date;
	lastModifiedAt: Date;
}

export interface HFCacheInfo {
	size: number;
	repos: CachedRepoInfo[];
	warnings: Error[];
}

export async function scanCacheDir(cacheDir: string | undefined = undefined): Promise<HFCacheInfo> {
	if (!cacheDir) cacheDir = getHFHubCachePath();

	const s = await stat(cacheDir);
	if (!s.isDirectory()) {
		throw new Error(
			`Scan cache expects a directory but found a file: ${cacheDir}. Please use \`cacheDir\` argument or set \`HF_HUB_CACHE\` environment variable.`
		);
	}

	const repos: CachedRepoInfo[] = [];
	const warnings: Error[] = [];

	const directories = await readdir(cacheDir);
	for (const repo of directories) {
		// skip .locks folder
		if (repo === ".locks") continue;

		// get the absolute path of the repo
		const absolute = join(cacheDir, repo);

		// ignore non-directory element
		const s = await stat(absolute);
		if (!s.isDirectory()) {
			continue;
		}

		try {
			const cached = await scanCachedRepo(absolute);
			repos.push(cached);
		} catch (err: unknown) {
			warnings.push(err as Error);
		}
	}

	return {
		repos: repos,
		size: [...repos.values()].reduce((sum, repo) => sum + repo.size, 0),
		warnings: warnings,
	};
}

export async function scanCachedRepo(repoPath: string): Promise<CachedRepoInfo> {
	// get the directory name
	const name = basename(repoPath);
	if (!name.includes(REPO_ID_SEPARATOR)) {
		throw new Error(`Repo path is not a valid HuggingFace cache directory: ${name}`);
	}

	// parse the repoId from directory name
	const [type, ...remaining] = name.split(REPO_ID_SEPARATOR);
	const repoType = parseRepoType(type);
	const repoId = remaining.join("/");

	const snapshotsPath = join(repoPath, "snapshots");
	const refsPath = join(repoPath, "refs");

	const snapshotStat = await stat(snapshotsPath);
	if (!snapshotStat.isDirectory()) {
		throw new Error(`Snapshots dir doesn't exist in cached repo ${snapshotsPath}`);
	}

	// Check if the refs directory exists and scan it
	const refsByHash: Map<string, string[]> = new Map();
	const refsStat = await stat(refsPath);
	if (refsStat.isDirectory()) {
		await scanRefsDir(refsPath, refsByHash);
	}

	// Scan snapshots directory and collect cached revision information
	const cachedRevisions: CachedRevisionInfo[] = [];
	const blobStats: Map<string, Stats> = new Map(); // Store blob stats

	const snapshotDirs = await readdir(snapshotsPath);
	for (const dir of snapshotDirs) {
		if (FILES_TO_IGNORE.includes(dir)) continue; // Ignore unwanted files

		const revisionPath = join(snapshotsPath, dir);
		const revisionStat = await stat(revisionPath);
		if (!revisionStat.isDirectory()) {
			throw new Error(`Snapshots folder corrupted. Found a file: ${revisionPath}`);
		}

		const cachedFiles: CachedFileInfo[] = [];
		await scanSnapshotDir(revisionPath, cachedFiles, blobStats);

		const revisionLastModified =
			cachedFiles.length > 0
				? Math.max(...[...cachedFiles].map((file) => file.blob.lastModifiedAt.getTime()))
				: revisionStat.mtimeMs;

		cachedRevisions.push({
			commitOid: dir,
			files: cachedFiles,
			refs: refsByHash.get(dir) || [],
			size: [...cachedFiles].reduce((sum, file) => sum + file.blob.size, 0),
			path: revisionPath,
			lastModifiedAt: new Date(revisionLastModified),
		});

		refsByHash.delete(dir);
	}

	// Verify that all refs refer to a valid revision
	if (refsByHash.size > 0) {
		throw new Error(
			`Reference(s) refer to missing commit hashes: ${JSON.stringify(Object.fromEntries(refsByHash))} (${repoPath})`
		);
	}

	const repoStats = await stat(repoPath);
	const repoLastAccessed =
		blobStats.size > 0 ? Math.max(...[...blobStats.values()].map((stat) => stat.atimeMs)) : repoStats.atimeMs;

	const repoLastModified =
		blobStats.size > 0 ? Math.max(...[...blobStats.values()].map((stat) => stat.mtimeMs)) : repoStats.mtimeMs;

	// Return the constructed CachedRepoInfo object
	return {
		id: {
			name: repoId,
			type: repoType,
		},
		path: repoPath,
		filesCount: blobStats.size,
		revisions: cachedRevisions,
		size: [...blobStats.values()].reduce((sum, stat) => sum + stat.size, 0),
		lastAccessedAt: new Date(repoLastAccessed),
		lastModifiedAt: new Date(repoLastModified),
	};
}

export async function scanRefsDir(refsPath: string, refsByHash: Map<string, string[]>): Promise<void> {
	const refFiles = await readdir(refsPath, { withFileTypes: true });
	for (const refFile of refFiles) {
		const refFilePath = join(refsPath, refFile.name);
		if (refFile.isDirectory()) continue; // Skip directories

		const commitHash = await readFile(refFilePath, "utf-8");
		const refName = refFile.name;
		if (!refsByHash.has(commitHash)) {
			refsByHash.set(commitHash, []);
		}
		refsByHash.get(commitHash)?.push(refName);
	}
}

export async function scanSnapshotDir(
	revisionPath: string,
	cachedFiles: CachedFileInfo[],
	blobStats: Map<string, Stats>
): Promise<void> {
	const files = await readdir(revisionPath, { withFileTypes: true });
	for (const file of files) {
		if (file.isDirectory()) continue; // Skip directories

		const filePath = join(revisionPath, file.name);
		const blobPath = await realpath(filePath);
		const blobStat = await getBlobStat(blobPath, blobStats);

		cachedFiles.push({
			path: filePath,
			blob: {
				path: blobPath,
				size: blobStat.size,
				lastAccessedAt: new Date(blobStat.atimeMs),
				lastModifiedAt: new Date(blobStat.mtimeMs),
			},
		});
	}
}

export async function getBlobStat(blobPath: string, blobStats: Map<string, Stats>): Promise<Stats> {
	const blob = blobStats.get(blobPath);
	if (!blob) {
		const statResult = await lstat(blobPath);
		blobStats.set(blobPath, statResult);
		return statResult;
	}
	return blob;
}

export function parseRepoType(type: string): RepoType {
	switch (type) {
		case "models":
			return "model";
		case "datasets":
			return "dataset";
		case "spaces":
			return "space";
		default:
			throw new TypeError(`Invalid repo type: ${type}`);
	}
}
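A rough usage sketch (assuming an in-package relative import; the report format is illustrative). With no argument, `scanCacheDir` resolves the cache root from `HF_HUB_CACHE`, `HUGGINGFACE_HUB_CACHE` or `HF_HOME`, and collects per-repo scan failures into `warnings` instead of aborting:

import { scanCacheDir } from "./lib/cache-management";

const info = await scanCacheDir(); // defaults to e.g. ~/.cache/huggingface/hub
console.log(`${info.repos.length} cached repo(s), ${info.size} bytes on disk`);
for (const repo of info.repos) {
	console.log(
		`- ${repo.id.type} ${repo.id.name}: ${repo.revisions.length} revision(s), ` +
			`${repo.filesCount} blob(s), last accessed ${repo.lastAccessedAt.toISOString()}`
	);
}
for (const warning of info.warnings) {
	console.warn(`skipped: ${warning.message}`);
}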
lib/check-repo-access.spec.ts ADDED
@@ -0,0 +1,34 @@
import { assert, describe, expect, it } from "vitest";
import { checkRepoAccess } from "./check-repo-access";
import { HubApiError } from "../error";
import { TEST_ACCESS_TOKEN, TEST_HUB_URL } from "../test/consts";

describe("checkRepoAccess", () => {
	it("should throw 401 when accessing a nonexistent repo while unauthenticated", async () => {
		try {
			await checkRepoAccess({ repo: { name: "i--d/dont", type: "model" } });
			assert(false, "should have thrown");
		} catch (err) {
			expect(err).toBeInstanceOf(HubApiError);
			expect((err as HubApiError).statusCode).toBe(401);
		}
	});

	it("should throw 404 when accessing a nonexistent repo while authenticated", async () => {
		try {
			await checkRepoAccess({
				repo: { name: "i--d/dont", type: "model" },
				hubUrl: TEST_HUB_URL,
				accessToken: TEST_ACCESS_TOKEN,
			});
			assert(false, "should have thrown");
		} catch (err) {
			expect(err).toBeInstanceOf(HubApiError);
			expect((err as HubApiError).statusCode).toBe(404);
		}
	});

	it("should not throw when accessing a public repo", async () => {
		await checkRepoAccess({ repo: { name: "openai-community/gpt2", type: "model" } });
	});
});
lib/check-repo-access.ts ADDED
@@ -0,0 +1,32 @@
import { HUB_URL } from "../consts";
// eslint-disable-next-line @typescript-eslint/no-unused-vars
import { createApiError, type HubApiError } from "../error";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

/**
 * Check if we have read access to a repository.
 *
 * Throw a {@link HubApiError} error if we don't have access. HubApiError.statusCode will be 401, 403 or 404.
 */
export async function checkRepoAccess(
	params: {
		repo: RepoDesignation;
		hubUrl?: string;
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<void> {
	const accessToken = params && checkCredentials(params);
	const repoId = toRepoId(params.repo);

	const response = await (params.fetch || fetch)(`${params?.hubUrl || HUB_URL}/api/${repoId.type}s/${repoId.name}`, {
		headers: {
			...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
		},
	});

	if (!response.ok) {
		throw await createApiError(response);
	}
}
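A minimal calling sketch, following the doc comment above (the token and gated repo name are placeholders):

import { checkRepoAccess } from "./lib/check-repo-access";
import { HubApiError } from "./error";

try {
	await checkRepoAccess({
		repo: { type: "model", name: "some-org/some-gated-model" }, // placeholder
		accessToken: "hf_xxx", // placeholder
	});
	console.log("read access granted");
} catch (err) {
	if (err instanceof HubApiError) {
		// 401 = not authenticated, 403 = not authorized, 404 = repo not found
		console.error(`no access: HTTP ${err.statusCode}`);
	} else {
		throw err;
	}
}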
lib/commit.spec.ts ADDED
@@ -0,0 +1,271 @@
import { assert, it, describe } from "vitest";

import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts";
import type { RepoId } from "../types/public";
import type { CommitFile } from "./commit";
import { commit } from "./commit";
import { createRepo } from "./create-repo";
import { deleteRepo } from "./delete-repo";
import { downloadFile } from "./download-file";
import { fileDownloadInfo } from "./file-download-info";
import { insecureRandomString } from "../utils/insecureRandomString";
import { isFrontend } from "../utils/isFrontend";

const lfsContent = "O123456789".repeat(100_000);

describe("commit", () => {
	it("should commit to a repo with blobs", async function () {
		const tokenizerJsonUrl = new URL(
			"https://huggingface.co/spaces/aschen/push-model-from-web/raw/main/mobilenet/model.json"
		);
		const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
		const repo: RepoId = {
			name: repoName,
			type: "model",
		};

		await createRepo({
			accessToken: TEST_ACCESS_TOKEN,
			hubUrl: TEST_HUB_URL,
			repo,
			license: "mit",
		});

		try {
			const readme1 = await downloadFile({ repo, path: "README.md", hubUrl: TEST_HUB_URL });
			assert(readme1, "Readme doesn't exist");

			const nodeOperation: CommitFile[] = isFrontend
				? []
				: [
						{
							operation: "addOrUpdate",
							path: "tsconfig.json",
							content: (await import("node:url")).pathToFileURL("./tsconfig.json") as URL,
						},
				  ];

			await commit({
				repo,
				title: "Some commit",
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
				operations: [
					{
						operation: "addOrUpdate",
						content: new Blob(["This is me"]),
						path: "test.txt",
					},
					{
						operation: "addOrUpdate",
						content: new Blob([lfsContent]),
						path: "test.lfs.txt",
					},
					...nodeOperation,
					{
						operation: "addOrUpdate",
						content: tokenizerJsonUrl,
						path: "lamaral.json",
					},
					{
						operation: "delete",
						path: "README.md",
					},
				],
				// To test web workers in the front-end
				useWebWorkers: { minSize: 5_000 },
			});

			const fileContent = await downloadFile({ repo, path: "test.txt", hubUrl: TEST_HUB_URL });
			assert.strictEqual(await fileContent?.text(), "This is me");

			const lfsFileContent = await downloadFile({ repo, path: "test.lfs.txt", hubUrl: TEST_HUB_URL });
			assert.strictEqual(await lfsFileContent?.text(), lfsContent);

			const lfsFileUrl = `${TEST_HUB_URL}/${repoName}/raw/main/test.lfs.txt`;
			const lfsFilePointer = await fetch(lfsFileUrl);
			assert.strictEqual(lfsFilePointer.status, 200);
			assert.strictEqual(
				(await lfsFilePointer.text()).trim(),
				`
version https://git-lfs.github.com/spec/v1
oid sha256:a3bbce7ee1df7233d85b5f4d60faa3755f93f537804f8b540c72b0739239ddf8
size ${lfsContent.length}
				`.trim()
			);

			if (!isFrontend) {
				const fileUrlContent = await downloadFile({ repo, path: "tsconfig.json", hubUrl: TEST_HUB_URL });
				assert.strictEqual(
					await fileUrlContent?.text(),
					(await import("node:fs")).readFileSync("./tsconfig.json", "utf-8")
				);
			}

			const webResourceContent = await downloadFile({ repo, path: "lamaral.json", hubUrl: TEST_HUB_URL });
			assert.strictEqual(await webResourceContent?.text(), await (await fetch(tokenizerJsonUrl)).text());

			const readme2 = await downloadFile({ repo, path: "README.md", hubUrl: TEST_HUB_URL });
			assert.strictEqual(readme2, null);
		} finally {
			await deleteRepo({
				repo: {
					name: repoName,
					type: "model",
				},
				hubUrl: TEST_HUB_URL,
				credentials: { accessToken: TEST_ACCESS_TOKEN },
			});
		}
	}, 60_000);

	it("should commit a full repo from HF with web urls", async function () {
		const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
		const repo: RepoId = {
			name: repoName,
			type: "model",
		};

		await createRepo({
			accessToken: TEST_ACCESS_TOKEN,
			repo,
			hubUrl: TEST_HUB_URL,
		});

		try {
			const FILES_TO_UPLOAD = [
				`https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/model.json`,
				`https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/group1-shard1of2`,
				`https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/group1-shard2of2`,
				`https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/coffee.jpg`,
				`https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/README.md`,
			];

			const operations: CommitFile[] = await Promise.all(
				FILES_TO_UPLOAD.map(async (file) => {
					return {
						operation: "addOrUpdate",
						path: file.slice(file.indexOf("main/") + "main/".length),
						// upload remote file
						content: new URL(file),
					};
				})
			);
			await commit({
				repo,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
				title: "upload model",
				operations,
			});

			const LFSSize = (await fileDownloadInfo({ repo, path: "mobilenet/group1-shard1of2", hubUrl: TEST_HUB_URL }))
				?.size;

			assert.strictEqual(LFSSize, 4_194_304);

			const pointerFile = await downloadFile({
				repo,
				path: "mobilenet/group1-shard1of2",
				raw: true,
				hubUrl: TEST_HUB_URL,
			});

			// Make sure SHA is computed properly as well
			assert.strictEqual(
				(await pointerFile?.text())?.trim(),
				`
version https://git-lfs.github.com/spec/v1
oid sha256:3fb621eb9b37478239504ee083042d5b18699e8b8618e569478b03b119a85a69
size 4194304
				`.trim()
			);
		} finally {
			await deleteRepo({
				repo: {
					name: repoName,
					type: "model",
				},
				hubUrl: TEST_HUB_URL,
				credentials: { accessToken: TEST_ACCESS_TOKEN },
			});
		}
		// https://huggingfacejs-push-model-from-web.hf.space/
	}, 60_000);

	it("should be able to create a PR and then commit to it", async function () {
		const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
		const repo: RepoId = {
			name: repoName,
			type: "model",
		};

		await createRepo({
			credentials: {
				accessToken: TEST_ACCESS_TOKEN,
			},
			repo,
			hubUrl: TEST_HUB_URL,
		});

		try {
			const pr = await commit({
				repo,
				credentials: {
					accessToken: TEST_ACCESS_TOKEN,
				},
				hubUrl: TEST_HUB_URL,
				title: "Create PR",
				isPullRequest: true,
				operations: [
					{
						operation: "addOrUpdate",
						content: new Blob(["This is me"]),
						path: "test.txt",
					},
				],
			});

			if (!pr) {
				throw new Error("PR creation failed");
			}

			if (!pr.pullRequestUrl) {
				throw new Error("No pull request url");
			}

			const prNumber = pr.pullRequestUrl.split("/").pop();
			const prRef = `refs/pr/${prNumber}`;

			await commit({
				repo,
				credentials: {
					accessToken: TEST_ACCESS_TOKEN,
				},
				hubUrl: TEST_HUB_URL,
				branch: prRef,
				title: "Some commit",
				operations: [
					{
						operation: "addOrUpdate",
						content: new URL(
							`https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/group1-shard1of2`
						),
						path: "mobilenet/group1-shard1of2",
					},
				],
			});

			assert(commit, "PR commit failed");
		} finally {
			await deleteRepo({
				repo: {
					name: repoName,
					type: "model",
				},
				hubUrl: TEST_HUB_URL,
				credentials: { accessToken: TEST_ACCESS_TOKEN },
			});
		}
	}, 60_000);
});
lib/commit.ts ADDED
@@ -0,0 +1,609 @@
import { HUB_URL } from "../consts";
import { HubApiError, createApiError, InvalidApiResponseFormatError } from "../error";
import type {
	ApiCommitHeader,
	ApiCommitLfsFile,
	ApiCommitOperation,
	ApiLfsBatchRequest,
	ApiLfsBatchResponse,
	ApiLfsCompleteMultipartRequest,
	ApiPreuploadRequest,
	ApiPreuploadResponse,
} from "../types/api/api-commit";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { chunk } from "../utils/chunk";
import { promisesQueue } from "../utils/promisesQueue";
import { promisesQueueStreaming } from "../utils/promisesQueueStreaming";
import { sha256 } from "../utils/sha256";
import { toRepoId } from "../utils/toRepoId";
import { WebBlob } from "../utils/WebBlob";
import { eventToGenerator } from "../utils/eventToGenerator";
import { base64FromBytes } from "../utils/base64FromBytes";
import { isFrontend } from "../utils/isFrontend";
import { createBlobs } from "../utils/createBlobs";

const CONCURRENT_SHAS = 5;
const CONCURRENT_LFS_UPLOADS = 5;
const MULTIPART_PARALLEL_UPLOAD = 5;

export interface CommitDeletedEntry {
	operation: "delete";
	path: string;
}

export type ContentSource = Blob | URL;

export interface CommitFile {
	operation: "addOrUpdate";
	path: string;
	content: ContentSource;
	// forceLfs?: boolean
}

type CommitBlob = Omit<CommitFile, "content"> & { content: Blob };

// TODO: find a nice way to handle LFS & non-LFS files in a uniform manner, see https://github.com/huggingface/moon-landing/issues/4370
// export type CommitRenameFile = {
//   operation: "rename";
//   path: string;
//   oldPath: string;
//   content?: ContentSource;
// };

export type CommitOperation = CommitDeletedEntry | CommitFile /* | CommitRenameFile */;
type CommitBlobOperation = Exclude<CommitOperation, CommitFile> | CommitBlob;

export type CommitParams = {
	title: string;
	description?: string;
	repo: RepoDesignation;
	operations: CommitOperation[];
	/** @default "main" */
	branch?: string;
	/**
	 * Parent commit. Optional
	 *
	 * - When opening a PR: will use parentCommit as the parent commit
	 * - When committing on a branch: Will make sure that there were no intermediate commits
	 */
	parentCommit?: string;
	isPullRequest?: boolean;
	hubUrl?: string;
	/**
	 * Whether to use web workers to compute SHA256 hashes.
	 *
	 * @default false
	 */
	useWebWorkers?: boolean | { minSize?: number; poolSize?: number };
	/**
	 * Maximum depth of folders to upload. Files deeper than this will be ignored
	 *
	 * @default 5
	 */
	maxFolderDepth?: number;
	/**
	 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
	 */
	fetch?: typeof fetch;
	abortSignal?: AbortSignal;
	// Credentials are optional due to custom fetch functions or cookie auth
} & Partial<CredentialsParams>;

export interface CommitOutput {
	pullRequestUrl?: string;
	commit: {
		oid: string;
		url: string;
	};
	hookOutput: string;
}

function isFileOperation(op: CommitOperation): op is CommitBlob {
	const ret = op.operation === "addOrUpdate";

	if (ret && !(op.content instanceof Blob)) {
		throw new TypeError("Precondition failed: op.content should be a Blob");
	}

	return ret;
}

export type CommitProgressEvent =
	| {
			event: "phase";
			phase: "preuploading" | "uploadingLargeFiles" | "committing";
	  }
	| {
			event: "fileProgress";
			path: string;
			progress: number;
			state: "hashing" | "uploading";
	  };

/**
 * Internal function for now, used by commit.
 *
 * Can be exposed later to offer fine-tuned progress info
 */
export async function* commitIter(params: CommitParams): AsyncGenerator<CommitProgressEvent, CommitOutput> {
	const accessToken = checkCredentials(params);
	const repoId = toRepoId(params.repo);
	yield { event: "phase", phase: "preuploading" };

	const lfsShas = new Map<string, string | null>();

	const abortController = new AbortController();
	const abortSignal = abortController.signal;

	// Polyfill see https://discuss.huggingface.co/t/why-cant-i-upload-a-parquet-file-to-my-dataset-error-o-throwifaborted-is-not-a-function/62245
	if (!abortSignal.throwIfAborted) {
		abortSignal.throwIfAborted = () => {
			if (abortSignal.aborted) {
				throw new DOMException("Aborted", "AbortError");
			}
		};
	}

	if (params.abortSignal) {
		params.abortSignal.addEventListener("abort", () => abortController.abort());
	}

	try {
		const allOperations = (
			await Promise.all(
				params.operations.map(async (operation) => {
					if (operation.operation !== "addOrUpdate") {
						return operation;
					}

					if (!(operation.content instanceof URL)) {
						/** TS trick to enforce `content` to be a `Blob` */
						return { ...operation, content: operation.content };
					}

					const lazyBlobs = await createBlobs(operation.content, operation.path, {
						fetch: params.fetch,
						maxFolderDepth: params.maxFolderDepth,
					});

					abortSignal?.throwIfAborted();

					return lazyBlobs.map((blob) => ({
						...operation,
						content: blob.blob,
						path: blob.path,
					}));
				})
			)
		).flat(1);

		const gitAttributes = allOperations.filter(isFileOperation).find((op) => op.path === ".gitattributes")?.content;

		for (const operations of chunk(allOperations.filter(isFileOperation), 100)) {
			const payload: ApiPreuploadRequest = {
				gitAttributes: gitAttributes && (await gitAttributes.text()),
				files: await Promise.all(
					operations.map(async (operation) => ({
						path: operation.path,
						size: operation.content.size,
						sample: base64FromBytes(new Uint8Array(await operation.content.slice(0, 512).arrayBuffer())),
					}))
				),
			};

			abortSignal?.throwIfAborted();

			const res = await (params.fetch ?? fetch)(
				`${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/preupload/${encodeURIComponent(
					params.branch ?? "main"
				)}` + (params.isPullRequest ? "?create_pr=1" : ""),
				{
					method: "POST",
					headers: {
						...(accessToken && { Authorization: `Bearer ${accessToken}` }),
						"Content-Type": "application/json",
					},
					body: JSON.stringify(payload),
					signal: abortSignal,
				}
			);

			if (!res.ok) {
				throw await createApiError(res);
			}

			const json: ApiPreuploadResponse = await res.json();

			for (const file of json.files) {
				if (file.uploadMode === "lfs") {
					lfsShas.set(file.path, null);
				}
			}
		}

		yield { event: "phase", phase: "uploadingLargeFiles" };

		for (const operations of chunk(
			allOperations.filter(isFileOperation).filter((op) => lfsShas.has(op.path)),
			100
		)) {
			const shas = yield* eventToGenerator<
				{ event: "fileProgress"; state: "hashing"; path: string; progress: number },
				string[]
			>((yieldCallback, returnCallback, rejectCallback) => {
				return promisesQueue(
					operations.map((op) => async () => {
						const iterator = sha256(op.content, { useWebWorker: params.useWebWorkers, abortSignal: abortSignal });
						let res: IteratorResult<number, string>;
						do {
							res = await iterator.next();
							if (!res.done) {
								yieldCallback({ event: "fileProgress", path: op.path, progress: res.value, state: "hashing" });
							}
						} while (!res.done);
						const sha = res.value;
						lfsShas.set(op.path, res.value);
						return sha;
					}),
					CONCURRENT_SHAS
				).then(returnCallback, rejectCallback);
			});

			abortSignal?.throwIfAborted();

			const payload: ApiLfsBatchRequest = {
				operation: "upload",
				// multipart is a custom protocol for HF
				transfers: ["basic", "multipart"],
				hash_algo: "sha_256",
				...(!params.isPullRequest && {
					ref: {
						name: params.branch ?? "main",
					},
				}),
				objects: operations.map((op, i) => ({
					oid: shas[i],
					size: op.content.size,
				})),
			};

			const res = await (params.fetch ?? fetch)(
				`${params.hubUrl ?? HUB_URL}/${repoId.type === "model" ? "" : repoId.type + "s/"}${
					repoId.name
				}.git/info/lfs/objects/batch`,
				{
					method: "POST",
					headers: {
						...(accessToken && { Authorization: `Bearer ${accessToken}` }),
						Accept: "application/vnd.git-lfs+json",
						"Content-Type": "application/vnd.git-lfs+json",
					},
					body: JSON.stringify(payload),
					signal: abortSignal,
				}
			);

			if (!res.ok) {
				throw await createApiError(res);
			}

			const json: ApiLfsBatchResponse = await res.json();
			const batchRequestId = res.headers.get("X-Request-Id") || undefined;

			const shaToOperation = new Map(operations.map((op, i) => [shas[i], op]));

			yield* eventToGenerator<CommitProgressEvent, void>((yieldCallback, returnCallback, rejectCallback) => {
				return promisesQueueStreaming(
					json.objects.map((obj) => async () => {
						const op = shaToOperation.get(obj.oid);

						if (!op) {
							throw new InvalidApiResponseFormatError("Unrequested object ID in response");
						}

						abortSignal?.throwIfAborted();

						if (obj.error) {
							const errorMessage = `Error while doing LFS batch call for ${operations[shas.indexOf(obj.oid)].path}: ${
								obj.error.message
							}${batchRequestId ? ` - Request ID: ${batchRequestId}` : ""}`;
							throw new HubApiError(res.url, obj.error.code, batchRequestId, errorMessage);
						}
						if (!obj.actions?.upload) {
							// Already uploaded
							yieldCallback({
								event: "fileProgress",
								path: op.path,
								progress: 1,
								state: "uploading",
							});
							return;
						}
						yieldCallback({
							event: "fileProgress",
							path: op.path,
							progress: 0,
							state: "uploading",
						});
						const content = op.content;
						const header = obj.actions.upload.header;
						if (header?.chunk_size) {
							const chunkSize = parseInt(header.chunk_size);

							// multipart upload
							// parts are in upload.header['00001'] to upload.header['99999']

							const completionUrl = obj.actions.upload.href;
							const parts = Object.keys(header).filter((key) => /^[0-9]+$/.test(key));

							if (parts.length !== Math.ceil(content.size / chunkSize)) {
								throw new Error("Invalid server response to upload large LFS file, wrong number of parts");
							}

							const completeReq: ApiLfsCompleteMultipartRequest = {
								oid: obj.oid,
								parts: parts.map((part) => ({
									partNumber: +part,
									etag: "",
								})),
							};

							// Defined here so that it's not redefined at each iteration (and the caller can tell it's for the same file)
							const progressCallback = (progress: number) =>
								yieldCallback({ event: "fileProgress", path: op.path, progress, state: "uploading" });

							await promisesQueueStreaming(
								parts.map((part) => async () => {
									abortSignal?.throwIfAborted();

									const index = parseInt(part) - 1;
									const slice = content.slice(index * chunkSize, (index + 1) * chunkSize);

									const res = await (params.fetch ?? fetch)(header[part], {
										method: "PUT",
										/** Unfortunately, browsers don't support our inherited version of Blob in fetch calls */
										body: slice instanceof WebBlob && isFrontend ? await slice.arrayBuffer() : slice,
										signal: abortSignal,
										...({
											progressHint: {
												path: op.path,
												part: index,
												numParts: parts.length,
												progressCallback,
											},
											// eslint-disable-next-line @typescript-eslint/no-explicit-any
										} as any),
									});

									if (!res.ok) {
										throw await createApiError(res, {
											requestId: batchRequestId,
											message: `Error while uploading part ${part} of ${
												operations[shas.indexOf(obj.oid)].path
											} to LFS storage`,
										});
									}

									const eTag = res.headers.get("ETag");

									if (!eTag) {
										throw new Error("Cannot get ETag of part during multipart upload");
									}

									completeReq.parts[Number(part) - 1].etag = eTag;
								}),
								MULTIPART_PARALLEL_UPLOAD
							);

							abortSignal?.throwIfAborted();

							const res = await (params.fetch ?? fetch)(completionUrl, {
								method: "POST",
								body: JSON.stringify(completeReq),
								headers: {
									Accept: "application/vnd.git-lfs+json",
									"Content-Type": "application/vnd.git-lfs+json",
								},
								signal: abortSignal,
							});

							if (!res.ok) {
								throw await createApiError(res, {
									requestId: batchRequestId,
									message: `Error completing multipart upload of ${
										operations[shas.indexOf(obj.oid)].path
									} to LFS storage`,
								});
							}

							yieldCallback({
								event: "fileProgress",
								path: op.path,
								progress: 1,
								state: "uploading",
							});
						} else {
							const res = await (params.fetch ?? fetch)(obj.actions.upload.href, {
								method: "PUT",
								headers: {
									...(batchRequestId ? { "X-Request-Id": batchRequestId } : undefined),
								},
								/** Unfortunately, browsers don't support our inherited version of Blob in fetch calls */
								body: content instanceof WebBlob && isFrontend ? await content.arrayBuffer() : content,
								signal: abortSignal,
								...({
									progressHint: {
										path: op.path,
										progressCallback: (progress: number) =>
											yieldCallback({
												event: "fileProgress",
												path: op.path,
												progress,
												state: "uploading",
											}),
									},
									// eslint-disable-next-line @typescript-eslint/no-explicit-any
								} as any),
							});

							if (!res.ok) {
								throw await createApiError(res, {
									requestId: batchRequestId,
									message: `Error while uploading ${operations[shas.indexOf(obj.oid)].path} to LFS storage`,
								});
							}

							yieldCallback({
								event: "fileProgress",
								path: op.path,
								progress: 1,
								state: "uploading",
							});
						}
					}),
					CONCURRENT_LFS_UPLOADS
				).then(returnCallback, rejectCallback);
			});
		}

		abortSignal?.throwIfAborted();

		yield { event: "phase", phase: "committing" };

		return yield* eventToGenerator<CommitProgressEvent, CommitOutput>(
			async (yieldCallback, returnCallback, rejectCallback) =>
				(params.fetch ?? fetch)(
					`${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commit/${encodeURIComponent(
						params.branch ?? "main"
					)}` + (params.isPullRequest ? "?create_pr=1" : ""),
					{
						method: "POST",
						headers: {
							...(accessToken && { Authorization: `Bearer ${accessToken}` }),
							"Content-Type": "application/x-ndjson",
						},
						body: [
							{
								key: "header",
								value: {
									summary: params.title,
									description: params.description,
									parentCommit: params.parentCommit,
								} satisfies ApiCommitHeader,
							},
							...((await Promise.all(
								allOperations.map((operation) => {
									if (isFileOperation(operation)) {
										const sha = lfsShas.get(operation.path);
										if (sha) {
											return {
												key: "lfsFile",
												value: {
													path: operation.path,
													algo: "sha256",
													size: operation.content.size,
													oid: sha,
												} satisfies ApiCommitLfsFile,
											};
										}
									}

									return convertOperationToNdJson(operation);
								})
							)) satisfies ApiCommitOperation[]),
						]
							.map((x) => JSON.stringify(x))
							.join("\n"),
						signal: abortSignal,
						...({
							progressHint: {
								progressCallback: (progress: number) => {
									// For now, we display equal progress for all files
									// We could compute the progress based on the size of `convertOperationToNdJson` for each of the files instead
									for (const op of allOperations) {
										if (isFileOperation(op) && !lfsShas.has(op.path)) {
											yieldCallback({
												event: "fileProgress",
												path: op.path,
												progress,
												state: "uploading",
											});
										}
									}
								},
							},
							// eslint-disable-next-line @typescript-eslint/no-explicit-any
						} as any),
					}
				)
					.then(async (res) => {
						if (!res.ok) {
							throw await createApiError(res);
						}

						const json = await res.json();

						returnCallback({
							pullRequestUrl: json.pullRequestUrl,
							commit: {
								oid: json.commitOid,
								url: json.commitUrl,
							},
							hookOutput: json.hookOutput,
						});
					})
					.catch(rejectCallback)
		);
	} catch (err) {
		// For parallel requests, cancel them all if one fails
		abortController.abort();
		throw err;
	}
}

export async function commit(params: CommitParams): Promise<CommitOutput> {
	const iterator = commitIter(params);
	let res = await iterator.next();
	while (!res.done) {
		res = await iterator.next();
	}
	return res.value;
}

async function convertOperationToNdJson(operation: CommitBlobOperation): Promise<ApiCommitOperation> {
	switch (operation.operation) {
		case "addOrUpdate": {
			// todo: handle LFS
			return {
				key: "file",
				value: {
					content: base64FromBytes(new Uint8Array(await operation.content.arrayBuffer())),
					path: operation.path,
					encoding: "base64",
				},
			};
		}
		// case "rename": {
		//   // todo: detect when remote file is already LFS, and in that case rename as LFS
		//   return {
		//     key: "file",
		//     value: {
		//       content: operation.content,
		//       path: operation.path,
		//       oldPath: operation.oldPath
		//     }
		//   };
		// }
		case "delete": {
			return {
				key: "deletedFile",
				value: {
					path: operation.path,
				},
			};
		}
		default:
			throw new TypeError("Unknown operation: " + (operation as { operation: string }).operation);
	}
}
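A sketch of driving `commitIter` directly to surface the progress events (repo name, token and file contents are placeholders); per the code above, `commit()` runs the same loop while discarding the events:

import { commitIter } from "./lib/commit";

const iterator = commitIter({
	repo: { type: "model", name: "my-user/my-model" }, // placeholder
	title: "Upload weights",
	accessToken: "hf_xxx", // placeholder
	operations: [
		{ operation: "addOrUpdate", path: "weights.bin", content: new Blob(["..."]) },
		{ operation: "delete", path: "obsolete.bin" },
	],
});

let step = await iterator.next();
while (!step.done) {
	const event = step.value;
	if (event.event === "phase") {
		// "preuploading" -> "uploadingLargeFiles" -> "committing"
		console.log(`phase: ${event.phase}`);
	} else {
		console.log(`${event.path}: ${(event.progress * 100).toFixed(0)}% (${event.state})`);
	}
	step = await iterator.next();
}
console.log("commit created:", step.value.commit.url);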
lib/count-commits.spec.ts ADDED
@@ -0,0 +1,16 @@
import { assert, it, describe } from "vitest";
import { countCommits } from "./count-commits";

describe("countCommits", () => {
	it("should fetch paginated commits from the repo", async () => {
		const count = await countCommits({
			repo: {
				name: "openai-community/gpt2",
				type: "model",
			},
			revision: "607a30d783dfa663caf39e06633721c8d4cfcd7e",
		});

		assert.equal(count, 26);
	});
});
lib/count-commits.ts ADDED
@@ -0,0 +1,35 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

export async function countCommits(
	params: {
		repo: RepoDesignation;
		/**
		 * Revision to list commits from. Defaults to the default branch.
		 */
		revision?: string;
		hubUrl?: string;
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<number> {
	const accessToken = checkCredentials(params);
	const repoId = toRepoId(params.repo);

	// Could upgrade to 1000 commits per page
	const url: string | undefined = `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commits/${
		params.revision ?? "main"
	}?limit=1`;

	const res: Response = await (params.fetch ?? fetch)(url, {
		headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
	});

	if (!res.ok) {
		throw await createApiError(res);
	}

	return parseInt(res.headers.get("x-total-count") ?? "0", 10);
}
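A short usage sketch. Note that the implementation issues a single `?limit=1` request and reads the total from the `x-total-count` response header, so the cost is one request regardless of history length:

import { countCommits } from "./lib/count-commits";

const count = await countCommits({
	repo: { type: "model", name: "openai-community/gpt2" },
	// revision defaults to "main"
});
console.log(`history contains ${count} commit(s)`);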
lib/create-branch.spec.ts ADDED
@@ -0,0 +1,159 @@
import { assert, it, describe } from "vitest";
import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts";
import type { RepoId } from "../types/public";
import { insecureRandomString } from "../utils/insecureRandomString";
import { createRepo } from "./create-repo";
import { deleteRepo } from "./delete-repo";
import { createBranch } from "./create-branch";
import { uploadFile } from "./upload-file";
import { downloadFile } from "./download-file";

describe("createBranch", () => {
	it("should create a new branch from the default branch", async () => {
		const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
		const repo = { type: "model", name: repoName } satisfies RepoId;

		try {
			await createRepo({
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
				repo,
			});

			await uploadFile({
				repo,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
				file: {
					path: "file.txt",
					content: new Blob(["file content"]),
				},
			});

			await createBranch({
				repo,
				branch: "new-branch",
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
			});

			const content = await downloadFile({
				repo,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
				path: "file.txt",
				revision: "new-branch",
			});

			assert.equal(await content?.text(), "file content");
		} finally {
			await deleteRepo({
				repo,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
			});
		}
	});

	it("should create an empty branch", async () => {
		const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
		const repo = { type: "model", name: repoName } satisfies RepoId;

		try {
			await createRepo({
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
				repo,
			});

			await uploadFile({
				repo,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
				file: {
					path: "file.txt",
					content: new Blob(["file content"]),
				},
			});

			await createBranch({
				repo,
				branch: "empty-branch",
				empty: true,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
			});

			const content = await downloadFile({
				repo,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
				path: "file.txt",
				revision: "empty-branch",
			});

			assert.equal(content, null);
		} finally {
			await deleteRepo({
				repo,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
			});
		}
	});

	it("should overwrite an existing branch", async () => {
		const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
		const repo = { type: "model", name: repoName } satisfies RepoId;

		try {
			await createRepo({
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
				repo,
			});

			await uploadFile({
				repo,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
				file: {
					path: "file.txt",
					content: new Blob(["file content"]),
				},
			});

			await createBranch({
				repo,
				branch: "overwrite-branch",
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
			});

			await createBranch({
				repo,
				branch: "overwrite-branch",
				overwrite: true,
				empty: true,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
			});

			const content = await downloadFile({
				repo,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
				path: "file.txt",
				revision: "overwrite-branch",
			});

			assert.equal(content, null);
		} finally {
			await deleteRepo({
				repo,
				accessToken: TEST_ACCESS_TOKEN,
				hubUrl: TEST_HUB_URL,
			});
		}
	});
});
lib/create-branch.ts ADDED
@@ -0,0 +1,54 @@
+ import { HUB_URL } from "../consts";
+ import { createApiError } from "../error";
+ import type { AccessToken, RepoDesignation } from "../types/public";
+ import { toRepoId } from "../utils/toRepoId";
+
+ export async function createBranch(params: {
+   repo: RepoDesignation;
+   /**
+    * Revision to create the branch from. Defaults to the default branch.
+    *
+    * Use empty: true to create an empty branch.
+    */
+   revision?: string;
+   hubUrl?: string;
+   accessToken?: AccessToken;
+   fetch?: typeof fetch;
+   /**
+    * The name of the branch to create
+    */
+   branch: string;
+   /**
+    * Use this to create an empty branch, with no commits.
+    */
+   empty?: boolean;
+   /**
+    * Use this to overwrite the branch if it already exists.
+    *
+    * If you only specify `overwrite` and no `revision`/`empty`, and the branch already exists, it will be a no-op.
+    */
+   overwrite?: boolean;
+ }): Promise<void> {
+   const repoId = toRepoId(params.repo);
+   const res = await (params.fetch ?? fetch)(
+     `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/branch/${encodeURIComponent(params.branch)}`,
+     {
+       method: "POST",
+       headers: {
+         "Content-Type": "application/json",
+         ...(params.accessToken && {
+           Authorization: `Bearer ${params.accessToken}`,
+         }),
+       },
+       body: JSON.stringify({
+         startingPoint: params.revision,
+         ...(params.empty && { emptyBranch: true }),
+         overwrite: params.overwrite,
+       }),
+     }
+   );
+
+   if (!res.ok) {
+     throw await createApiError(res);
+   }
+ }
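For context, a minimal usage sketch of `createBranch`, assuming the module is consumed as the published `@huggingface/hub` package; the repo name, branch names, and token below are placeholders:

import { createBranch } from "@huggingface/hub";

// create "dev" from the default branch (placeholder repo and token)
await createBranch({ repo: "my-user/my-model", branch: "dev", accessToken: "hf_..." });

// create an orphan branch with no commits
await createBranch({ repo: "my-user/my-model", branch: "scratch", empty: true, accessToken: "hf_..." });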
lib/create-repo.spec.ts ADDED
@@ -0,0 +1,103 @@
+ import { assert, it, describe, expect } from "vitest";
+
+ import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts";
+ import { insecureRandomString } from "../utils/insecureRandomString";
+ import { createRepo } from "./create-repo";
+ import { deleteRepo } from "./delete-repo";
+ import { downloadFile } from "./download-file";
+
+ describe("createRepo", () => {
+   it("should create a repo", async () => {
+     const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
+
+     const result = await createRepo({
+       accessToken: TEST_ACCESS_TOKEN,
+       repo: {
+         name: repoName,
+         type: "model",
+       },
+       hubUrl: TEST_HUB_URL,
+       files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
+     });
+
+     assert.deepStrictEqual(result, {
+       repoUrl: `${TEST_HUB_URL}/${repoName}`,
+     });
+
+     const content = await downloadFile({
+       repo: {
+         name: repoName,
+         type: "model",
+       },
+       path: ".gitattributes",
+       hubUrl: TEST_HUB_URL,
+     });
+
+     assert(content);
+     assert.strictEqual(await content.text(), "*.html filter=lfs diff=lfs merge=lfs -text");
+
+     await deleteRepo({
+       repo: {
+         name: repoName,
+         type: "model",
+       },
+       credentials: { accessToken: TEST_ACCESS_TOKEN },
+       hubUrl: TEST_HUB_URL,
+     });
+   });
+
+   it("should throw a client error when trying to create a repo without a fully-qualified name", async () => {
+     const tryCreate = createRepo({
+       repo: { name: "canonical", type: "model" },
+       credentials: { accessToken: TEST_ACCESS_TOKEN },
+       hubUrl: TEST_HUB_URL,
+     });
+
+     await expect(tryCreate).rejects.toBeInstanceOf(TypeError);
+   });
+
+   it("should create a model with a string as name", async () => {
+     const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
+
+     const result = await createRepo({
+       accessToken: TEST_ACCESS_TOKEN,
+       hubUrl: TEST_HUB_URL,
+       repo: repoName,
+       files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
+     });
+
+     assert.deepStrictEqual(result, {
+       repoUrl: `${TEST_HUB_URL}/${repoName}`,
+     });
+
+     await deleteRepo({
+       repo: {
+         name: repoName,
+         type: "model",
+       },
+       hubUrl: TEST_HUB_URL,
+       credentials: { accessToken: TEST_ACCESS_TOKEN },
+     });
+   });
+
+   it("should create a dataset with a string as name", async () => {
+     const repoName = `datasets/${TEST_USER}/TEST-${insecureRandomString()}`;
+
+     const result = await createRepo({
+       accessToken: TEST_ACCESS_TOKEN,
+       hubUrl: TEST_HUB_URL,
+       repo: repoName,
+       files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
+     });
+
+     assert.deepStrictEqual(result, {
+       repoUrl: `${TEST_HUB_URL}/${repoName}`,
+     });
+
+     await deleteRepo({
+       repo: repoName,
+       hubUrl: TEST_HUB_URL,
+       credentials: { accessToken: TEST_ACCESS_TOKEN },
+     });
+   });
+ });
lib/create-repo.ts ADDED
@@ -0,0 +1,78 @@
+ import { HUB_URL } from "../consts";
+ import { createApiError } from "../error";
+ import type { ApiCreateRepoPayload } from "../types/api/api-create-repo";
+ import type { CredentialsParams, RepoDesignation, SpaceSdk } from "../types/public";
+ import { base64FromBytes } from "../utils/base64FromBytes";
+ import { checkCredentials } from "../utils/checkCredentials";
+ import { toRepoId } from "../utils/toRepoId";
+
+ export async function createRepo(
+   params: {
+     repo: RepoDesignation;
+     /**
+      * If unset, will follow the organization's default setting. (typically public, except for some Enterprise organizations)
+      */
+     private?: boolean;
+     license?: string;
+     /**
+      * Only a few lightweight files are supported at repo creation
+      */
+     files?: Array<{ content: ArrayBuffer | Blob; path: string }>;
+     /** @required for when {@link repo.type} === "space" */
+     sdk?: SpaceSdk;
+     hubUrl?: string;
+     /**
+      * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+      */
+     fetch?: typeof fetch;
+   } & CredentialsParams
+ ): Promise<{ repoUrl: string }> {
+   const accessToken = checkCredentials(params);
+   const repoId = toRepoId(params.repo);
+   const [namespace, repoName] = repoId.name.split("/");
+
+   if (!namespace || !repoName) {
+     throw new TypeError(
+       `"${repoId.name}" is not a fully qualified repo name. It should be of the form "{namespace}/{repoName}".`
+     );
+   }
+
+   const res = await (params.fetch ?? fetch)(`${params.hubUrl ?? HUB_URL}/api/repos/create`, {
+     method: "POST",
+     body: JSON.stringify({
+       name: repoName,
+       private: params.private,
+       organization: namespace,
+       license: params.license,
+       ...(repoId.type === "space"
+         ? {
+             type: "space",
+             sdk: "static",
+           }
+         : {
+             type: repoId.type,
+           }),
+       files: params.files
+         ? await Promise.all(
+             params.files.map(async (file) => ({
+               encoding: "base64",
+               path: file.path,
+               content: base64FromBytes(
+                 new Uint8Array(file.content instanceof Blob ? await file.content.arrayBuffer() : file.content)
+               ),
+             }))
+           )
+         : undefined,
+     } satisfies ApiCreateRepoPayload),
+     headers: {
+       Authorization: `Bearer ${accessToken}`,
+       "Content-Type": "application/json",
+     },
+   });
+
+   if (!res.ok) {
+     throw await createApiError(res);
+   }
+   const output = await res.json();
+   return { repoUrl: output.url };
+ }
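A minimal usage sketch of `createRepo`, assuming consumption as the `@huggingface/hub` package; the repo name, token, and file content are placeholders:

import { createRepo } from "@huggingface/hub";

const { repoUrl } = await createRepo({
  repo: { type: "model", name: "my-user/my-model" }, // placeholder name
  accessToken: "hf_...", // placeholder token
  private: true,
  files: [{ path: "README.md", content: new Blob(["# my model"]) }],
});
console.log(repoUrl);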
lib/dataset-info.spec.ts ADDED
@@ -0,0 +1,56 @@
+ import { describe, expect, it } from "vitest";
+ import { datasetInfo } from "./dataset-info";
+ import type { DatasetEntry } from "./list-datasets";
+ import type { ApiDatasetInfo } from "../types/api/api-dataset";
+
+ describe("datasetInfo", () => {
+   it("should return the dataset info", async () => {
+     const info = await datasetInfo({
+       name: "nyu-mll/glue",
+     });
+     expect(info).toEqual({
+       id: "621ffdd236468d709f181e3f",
+       downloads: expect.any(Number),
+       gated: false,
+       name: "nyu-mll/glue",
+       updatedAt: expect.any(Date),
+       likes: expect.any(Number),
+       private: false,
+     });
+   });
+
+   it("should return the dataset info with author", async () => {
+     const info: DatasetEntry & Pick<ApiDatasetInfo, "author"> = await datasetInfo({
+       name: "nyu-mll/glue",
+       additionalFields: ["author"],
+     });
+     expect(info).toEqual({
+       id: "621ffdd236468d709f181e3f",
+       downloads: expect.any(Number),
+       gated: false,
+       name: "nyu-mll/glue",
+       updatedAt: expect.any(Date),
+       likes: expect.any(Number),
+       private: false,
+       author: "nyu-mll",
+     });
+   });
+
+   it("should return the dataset info for a specific revision", async () => {
+     const info: DatasetEntry & Pick<ApiDatasetInfo, "sha"> = await datasetInfo({
+       name: "nyu-mll/glue",
+       revision: "cb2099c76426ff97da7aa591cbd317d91fb5fcb7",
+       additionalFields: ["sha"],
+     });
+     expect(info).toEqual({
+       id: "621ffdd236468d709f181e3f",
+       downloads: expect.any(Number),
+       gated: false,
+       name: "nyu-mll/glue",
+       updatedAt: expect.any(Date),
+       likes: expect.any(Number),
+       private: false,
+       sha: "cb2099c76426ff97da7aa591cbd317d91fb5fcb7",
+     });
+   });
+ });
lib/dataset-info.ts ADDED
@@ -0,0 +1,61 @@
+ import { HUB_URL } from "../consts";
+ import { createApiError } from "../error";
+ import type { ApiDatasetInfo } from "../types/api/api-dataset";
+ import type { CredentialsParams } from "../types/public";
+ import { checkCredentials } from "../utils/checkCredentials";
+ import { pick } from "../utils/pick";
+ import { type DATASET_EXPANDABLE_KEYS, DATASET_EXPAND_KEYS, type DatasetEntry } from "./list-datasets";
+
+ export async function datasetInfo<
+   const T extends Exclude<(typeof DATASET_EXPANDABLE_KEYS)[number], (typeof DATASET_EXPAND_KEYS)[number]> = never,
+ >(
+   params: {
+     name: string;
+     hubUrl?: string;
+     additionalFields?: T[];
+     /**
+      * An optional Git revision id which can be a branch name, a tag, or a commit hash.
+      */
+     revision?: string;
+     /**
+      * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+      */
+     fetch?: typeof fetch;
+   } & Partial<CredentialsParams>
+ ): Promise<DatasetEntry & Pick<ApiDatasetInfo, T>> {
+   const accessToken = params && checkCredentials(params);
+
+   const search = new URLSearchParams([
+     ...DATASET_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
+     ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
+   ]).toString();
+
+   const response = await (params.fetch || fetch)(
+     `${params?.hubUrl || HUB_URL}/api/datasets/${params.name}/revision/${encodeURIComponent(
+       params.revision ?? "HEAD"
+     )}?${search.toString()}`,
+     {
+       headers: {
+         ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
+         Accept: "application/json",
+       },
+     }
+   );
+
+   if (!response.ok) {
+     throw await createApiError(response);
+   }
+
+   const data = await response.json();
+
+   return {
+     ...(params?.additionalFields && pick(data, params.additionalFields)),
+     id: data._id,
+     name: data.id,
+     private: data.private,
+     downloads: data.downloads,
+     likes: data.likes,
+     gated: data.gated,
+     updatedAt: new Date(data.lastModified),
+   } as DatasetEntry & Pick<ApiDatasetInfo, T>;
+ }
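A minimal usage sketch of `datasetInfo`, assuming consumption as the `@huggingface/hub` package (the dataset name matches the one used in the spec above):

import { datasetInfo } from "@huggingface/hub";

const info = await datasetInfo({ name: "nyu-mll/glue", additionalFields: ["author"] });
console.log(info.name, info.downloads, info.author);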
lib/delete-branch.spec.ts ADDED
@@ -0,0 +1,43 @@
+ import { it, describe } from "vitest";
+ import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts";
+ import type { RepoId } from "../types/public";
+ import { insecureRandomString } from "../utils/insecureRandomString";
+ import { createRepo } from "./create-repo";
+ import { deleteRepo } from "./delete-repo";
+ import { createBranch } from "./create-branch";
+ import { deleteBranch } from "./delete-branch";
+
+ describe("deleteBranch", () => {
+   it("should delete an existing branch", async () => {
+     const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
+     const repo = { type: "model", name: repoName } satisfies RepoId;
+
+     try {
+       await createRepo({
+         accessToken: TEST_ACCESS_TOKEN,
+         hubUrl: TEST_HUB_URL,
+         repo,
+       });
+
+       await createBranch({
+         repo,
+         branch: "branch-to-delete",
+         accessToken: TEST_ACCESS_TOKEN,
+         hubUrl: TEST_HUB_URL,
+       });
+
+       await deleteBranch({
+         repo,
+         branch: "branch-to-delete",
+         accessToken: TEST_ACCESS_TOKEN,
+         hubUrl: TEST_HUB_URL,
+       });
+     } finally {
+       await deleteRepo({
+         repo,
+         accessToken: TEST_ACCESS_TOKEN,
+         hubUrl: TEST_HUB_URL,
+       });
+     }
+   });
+ });
lib/delete-branch.ts ADDED
@@ -0,0 +1,32 @@
+ import { HUB_URL } from "../consts";
+ import { createApiError } from "../error";
+ import type { AccessToken, RepoDesignation } from "../types/public";
+ import { toRepoId } from "../utils/toRepoId";
+
+ export async function deleteBranch(params: {
+   repo: RepoDesignation;
+   /**
+    * The name of the branch to delete
+    */
+   branch: string;
+   hubUrl?: string;
+   accessToken?: AccessToken;
+   fetch?: typeof fetch;
+ }): Promise<void> {
+   const repoId = toRepoId(params.repo);
+   const res = await (params.fetch ?? fetch)(
+     `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/branch/${encodeURIComponent(params.branch)}`,
+     {
+       method: "DELETE",
+       headers: {
+         ...(params.accessToken && {
+           Authorization: `Bearer ${params.accessToken}`,
+         }),
+       },
+     }
+   );
+
+   if (!res.ok) {
+     throw await createApiError(res);
+   }
+ }
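A minimal usage sketch of `deleteBranch` (package name, repo, branch, and token are placeholders, assuming the `@huggingface/hub` package):

import { deleteBranch } from "@huggingface/hub";

await deleteBranch({ repo: "my-user/my-model", branch: "dev", accessToken: "hf_..." });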
lib/delete-file.spec.ts ADDED
@@ -0,0 +1,64 @@
+ import { assert, it, describe } from "vitest";
+
+ import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts";
+ import type { RepoId } from "../types/public";
+ import { insecureRandomString } from "../utils/insecureRandomString";
+ import { createRepo } from "./create-repo";
+ import { deleteRepo } from "./delete-repo";
+ import { deleteFile } from "./delete-file";
+ import { downloadFile } from "./download-file";
+
+ describe("deleteFile", () => {
+   it("should delete a file", async () => {
+     const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
+     const repo = { type: "model", name: repoName } satisfies RepoId;
+
+     try {
+       const result = await createRepo({
+         accessToken: TEST_ACCESS_TOKEN,
+         hubUrl: TEST_HUB_URL,
+         repo,
+         files: [
+           { path: "file1", content: new Blob(["file1"]) },
+           { path: "file2", content: new Blob(["file2"]) },
+         ],
+       });
+
+       assert.deepStrictEqual(result, {
+         repoUrl: `${TEST_HUB_URL}/${repoName}`,
+       });
+
+       let content = await downloadFile({
+         hubUrl: TEST_HUB_URL,
+         repo,
+         path: "file1",
+       });
+
+       assert.strictEqual(await content?.text(), "file1");
+
+       await deleteFile({ path: "file1", repo, accessToken: TEST_ACCESS_TOKEN, hubUrl: TEST_HUB_URL });
+
+       content = await downloadFile({
+         repo,
+         path: "file1",
+         hubUrl: TEST_HUB_URL,
+       });
+
+       assert.strictEqual(content, null);
+
+       content = await downloadFile({
+         repo,
+         path: "file2",
+         hubUrl: TEST_HUB_URL,
+       });
+
+       assert.strictEqual(await content?.text(), "file2");
+     } finally {
+       await deleteRepo({
+         repo,
+         accessToken: TEST_ACCESS_TOKEN,
+         hubUrl: TEST_HUB_URL,
+       });
+     }
+   });
+ });
lib/delete-file.ts ADDED
@@ -0,0 +1,35 @@
+ import type { CredentialsParams } from "../types/public";
+ import type { CommitOutput, CommitParams } from "./commit";
+ import { commit } from "./commit";
+
+ export function deleteFile(
+   params: {
+     repo: CommitParams["repo"];
+     path: string;
+     commitTitle?: CommitParams["title"];
+     commitDescription?: CommitParams["description"];
+     hubUrl?: CommitParams["hubUrl"];
+     fetch?: CommitParams["fetch"];
+     branch?: CommitParams["branch"];
+     isPullRequest?: CommitParams["isPullRequest"];
+     parentCommit?: CommitParams["parentCommit"];
+   } & CredentialsParams
+ ): Promise<CommitOutput> {
+   return commit({
+     ...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }),
+     repo: params.repo,
+     operations: [
+       {
+         operation: "delete",
+         path: params.path,
+       },
+     ],
+     title: params.commitTitle ?? `Delete ${params.path}`,
+     description: params.commitDescription,
+     hubUrl: params.hubUrl,
+     branch: params.branch,
+     isPullRequest: params.isPullRequest,
+     parentCommit: params.parentCommit,
+     fetch: params.fetch,
+   });
+ }
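A minimal usage sketch of `deleteFile`, which under the hood creates a single-operation commit (placeholder repo, path, and token, assuming the `@huggingface/hub` package):

import { deleteFile } from "@huggingface/hub";

await deleteFile({ repo: "my-user/my-model", path: "old-weights.bin", accessToken: "hf_..." });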
lib/delete-files.spec.ts ADDED
@@ -0,0 +1,81 @@
+ import { assert, it, describe } from "vitest";
+
+ import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts";
+ import type { RepoId } from "../types/public";
+ import { insecureRandomString } from "../utils/insecureRandomString";
+ import { createRepo } from "./create-repo";
+ import { deleteRepo } from "./delete-repo";
+ import { deleteFiles } from "./delete-files";
+ import { downloadFile } from "./download-file";
+
+ describe("deleteFiles", () => {
+   it("should delete multiple files", async () => {
+     const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
+     const repo = { type: "model", name: repoName } satisfies RepoId;
+
+     try {
+       const result = await createRepo({
+         accessToken: TEST_ACCESS_TOKEN,
+         repo,
+         files: [
+           { path: "file1", content: new Blob(["file1"]) },
+           { path: "file2", content: new Blob(["file2"]) },
+           { path: "file3", content: new Blob(["file3"]) },
+         ],
+         hubUrl: TEST_HUB_URL,
+       });
+
+       assert.deepStrictEqual(result, {
+         repoUrl: `${TEST_HUB_URL}/${repoName}`,
+       });
+
+       let content = await downloadFile({
+         repo,
+         path: "file1",
+         hubUrl: TEST_HUB_URL,
+       });
+
+       assert.strictEqual(await content?.text(), "file1");
+
+       content = await downloadFile({
+         repo,
+         path: "file2",
+         hubUrl: TEST_HUB_URL,
+       });
+
+       assert.strictEqual(await content?.text(), "file2");
+
+       await deleteFiles({ paths: ["file1", "file2"], repo, accessToken: TEST_ACCESS_TOKEN, hubUrl: TEST_HUB_URL });
+
+       content = await downloadFile({
+         repo,
+         path: "file1",
+         hubUrl: TEST_HUB_URL,
+       });
+
+       assert.strictEqual(content, null);
+
+       content = await downloadFile({
+         repo,
+         path: "file2",
+         hubUrl: TEST_HUB_URL,
+       });
+
+       assert.strictEqual(content, null);
+
+       content = await downloadFile({
+         repo,
+         path: "file3",
+         hubUrl: TEST_HUB_URL,
+       });
+
+       assert.strictEqual(await content?.text(), "file3");
+     } finally {
+       await deleteRepo({
+         repo,
+         accessToken: TEST_ACCESS_TOKEN,
+         hubUrl: TEST_HUB_URL,
+       });
+     }
+   });
+ });
lib/delete-files.ts ADDED
@@ -0,0 +1,33 @@
+ import type { CredentialsParams } from "../types/public";
+ import type { CommitOutput, CommitParams } from "./commit";
+ import { commit } from "./commit";
+
+ export function deleteFiles(
+   params: {
+     repo: CommitParams["repo"];
+     paths: string[];
+     commitTitle?: CommitParams["title"];
+     commitDescription?: CommitParams["description"];
+     hubUrl?: CommitParams["hubUrl"];
+     branch?: CommitParams["branch"];
+     isPullRequest?: CommitParams["isPullRequest"];
+     parentCommit?: CommitParams["parentCommit"];
+     fetch?: CommitParams["fetch"];
+   } & CredentialsParams
+ ): Promise<CommitOutput> {
+   return commit({
+     ...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }),
+     repo: params.repo,
+     operations: params.paths.map((path) => ({
+       operation: "delete",
+       path,
+     })),
+     title: params.commitTitle ?? `Deletes ${params.paths.length} files`,
+     description: params.commitDescription,
+     hubUrl: params.hubUrl,
+     branch: params.branch,
+     isPullRequest: params.isPullRequest,
+     parentCommit: params.parentCommit,
+     fetch: params.fetch,
+   });
+ }
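A minimal usage sketch of `deleteFiles`, which batches all deletions into one commit (placeholder repo, paths, and token, assuming the `@huggingface/hub` package):

import { deleteFiles } from "@huggingface/hub";

await deleteFiles({
  repo: "my-user/my-model",
  paths: ["file1.txt", "file2.txt"],
  commitTitle: "Remove temporary files", // optional; defaults to "Deletes N files"
  accessToken: "hf_...",
});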
lib/delete-repo.ts ADDED
@@ -0,0 +1,37 @@
+ import { HUB_URL } from "../consts";
+ import { createApiError } from "../error";
+ import type { CredentialsParams, RepoDesignation } from "../types/public";
+ import { checkCredentials } from "../utils/checkCredentials";
+ import { toRepoId } from "../utils/toRepoId";
+
+ export async function deleteRepo(
+   params: {
+     repo: RepoDesignation;
+     hubUrl?: string;
+     /**
+      * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+      */
+     fetch?: typeof fetch;
+   } & CredentialsParams
+ ): Promise<void> {
+   const accessToken = checkCredentials(params);
+   const repoId = toRepoId(params.repo);
+   const [namespace, repoName] = repoId.name.split("/");
+
+   const res = await (params.fetch ?? fetch)(`${params.hubUrl ?? HUB_URL}/api/repos/delete`, {
+     method: "DELETE",
+     body: JSON.stringify({
+       name: repoName,
+       organization: namespace,
+       type: repoId.type,
+     }),
+     headers: {
+       Authorization: `Bearer ${accessToken}`,
+       "Content-Type": "application/json",
+     },
+   });
+
+   if (!res.ok) {
+     throw await createApiError(res);
+   }
+ }
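A minimal usage sketch of `deleteRepo` (placeholder repo and token, assuming the `@huggingface/hub` package):

import { deleteRepo } from "@huggingface/hub";

await deleteRepo({ repo: { type: "dataset", name: "my-user/my-dataset" }, accessToken: "hf_..." });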
lib/download-file-to-cache-dir.spec.ts ADDED
@@ -0,0 +1,306 @@
+ import { expect, test, describe, vi, beforeEach } from "vitest";
+ import type { RepoDesignation, RepoId } from "../types/public";
+ import { dirname, join } from "node:path";
+ import { lstat, mkdir, stat, symlink, rename } from "node:fs/promises";
+ import { pathsInfo } from "./paths-info";
+ import { createWriteStream, type Stats } from "node:fs";
+ import { getHFHubCachePath, getRepoFolderName } from "./cache-management";
+ import { toRepoId } from "../utils/toRepoId";
+ import { downloadFileToCacheDir } from "./download-file-to-cache-dir";
+ import { createSymlink } from "../utils/symlink";
+
+ vi.mock("node:fs/promises", () => ({
+   rename: vi.fn(),
+   symlink: vi.fn(),
+   lstat: vi.fn(),
+   mkdir: vi.fn(),
+   stat: vi.fn(),
+ }));
+
+ vi.mock("node:fs", () => ({
+   createWriteStream: vi.fn(),
+ }));
+
+ vi.mock("./paths-info", () => ({
+   pathsInfo: vi.fn(),
+ }));
+
+ vi.mock("../utils/symlink", () => ({
+   createSymlink: vi.fn(),
+ }));
+
+ const DUMMY_REPO: RepoId = {
+   name: "hello-world",
+   type: "model",
+ };
+
+ const DUMMY_ETAG = "dummy-etag";
+
+ // utility test method to get the blob file path
+ function _getBlobFile(params: {
+   repo: RepoDesignation;
+   etag: string;
+   cacheDir?: string; // defaults to {@link getHFHubCachePath}
+ }) {
+   return join(params.cacheDir ?? getHFHubCachePath(), getRepoFolderName(toRepoId(params.repo)), "blobs", params.etag);
+ }
+
+ // utility test method to get the snapshot file path
+ function _getSnapshotFile(params: {
+   repo: RepoDesignation;
+   path: string;
+   revision: string;
+   cacheDir?: string; // defaults to {@link getHFHubCachePath}
+ }) {
+   return join(
+     params.cacheDir ?? getHFHubCachePath(),
+     getRepoFolderName(toRepoId(params.repo)),
+     "snapshots",
+     params.revision,
+     params.path
+   );
+ }
+
+ describe("downloadFileToCacheDir", () => {
+   const fetchMock: typeof fetch = vi.fn();
+   beforeEach(() => {
+     vi.resetAllMocks();
+     // mock a 200 response
+     vi.mocked(fetchMock).mockResolvedValue(
+       new Response("dummy-body", {
+         status: 200,
+         headers: {
+           etag: DUMMY_ETAG,
+           "Content-Range": "bytes 0-54/55",
+         },
+       })
+     );
+
+     // prevent cache hits by default
+     vi.mocked(stat).mockRejectedValue(new Error("Does not exist"));
+     vi.mocked(lstat).mockRejectedValue(new Error("Does not exist"));
+   });
+
+   test("should throw an error if fileDownloadInfo returns nothing", async () => {
+     await expect(async () => {
+       await downloadFileToCacheDir({
+         repo: DUMMY_REPO,
+         path: "/README.md",
+         fetch: fetchMock,
+       });
+     }).rejects.toThrowError("cannot get path info for /README.md");
+
+     expect(pathsInfo).toHaveBeenCalledWith(
+       expect.objectContaining({
+         repo: DUMMY_REPO,
+         paths: ["/README.md"],
+         fetch: fetchMock,
+       })
+     );
+   });
+
+   test("existing symlink and blob should not be re-downloaded", async () => {
+     // <cache>/<repo>/<revision>/snapshots/README.md
+     const expectPointer = _getSnapshotFile({
+       repo: DUMMY_REPO,
+       path: "/README.md",
+       revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
+     });
+     // stat ensures the symlink and the pointed file exist
+     vi.mocked(stat).mockResolvedValue({} as Stats); // prevent default mocked reject
+
+     const output = await downloadFileToCacheDir({
+       repo: DUMMY_REPO,
+       path: "/README.md",
+       fetch: fetchMock,
+       revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
+     });
+
+     expect(stat).toHaveBeenCalledOnce();
+     // get the call argument for stat
+     const statArg = vi.mocked(stat).mock.calls[0][0];
+
+     expect(statArg).toBe(expectPointer);
+     expect(fetchMock).not.toHaveBeenCalled();
+
+     expect(output).toBe(expectPointer);
+   });
+
+   test("existing symlink and blob with default revision should not be re-downloaded", async () => {
+     // <cache>/<repo>/<revision>/snapshots/README.md
+     const expectPointer = _getSnapshotFile({
+       repo: DUMMY_REPO,
+       path: "/README.md",
+       revision: "main",
+     });
+     // stat ensures the symlink and the pointed file exist
+     vi.mocked(stat).mockResolvedValue({} as Stats); // prevent default mocked reject
+     vi.mocked(lstat).mockResolvedValue({} as Stats);
+     vi.mocked(pathsInfo).mockResolvedValue([
+       {
+         oid: DUMMY_ETAG,
+         size: 55,
+         path: "README.md",
+         type: "file",
+         lastCommit: {
+           date: new Date(),
+           id: "main",
+           title: "Commit msg",
+         },
+       },
+     ]);
+
+     const output = await downloadFileToCacheDir({
+       repo: DUMMY_REPO,
+       path: "/README.md",
+       fetch: fetchMock,
+     });
+
+     expect(stat).toHaveBeenCalledOnce();
+     expect(symlink).not.toHaveBeenCalled();
+     // get the call argument for stat
+     const statArg = vi.mocked(stat).mock.calls[0][0];
+
+     expect(statArg).toBe(expectPointer);
+     expect(fetchMock).not.toHaveBeenCalled();
+
+     expect(output).toBe(expectPointer);
+   });
+
+   test("existing blob should only create the symlink", async () => {
+     // <cache>/<repo>/<revision>/snapshots/README.md
+     const expectPointer = _getSnapshotFile({
+       repo: DUMMY_REPO,
+       path: "/README.md",
+       revision: "dummy-commit-hash",
+     });
+     // <cache>/<repo>/blobs/<etag>
+     const expectedBlob = _getBlobFile({
+       repo: DUMMY_REPO,
+       etag: DUMMY_ETAG,
+     });
+
+     // mock existing blob only, no symlink
+     vi.mocked(lstat).mockResolvedValue({} as Stats);
+     // mock pathsInfo resolved content
+     vi.mocked(pathsInfo).mockResolvedValue([
+       {
+         oid: DUMMY_ETAG,
+         size: 55,
+         path: "README.md",
+         type: "file",
+         lastCommit: {
+           date: new Date(),
+           id: "dummy-commit-hash",
+           title: "Commit msg",
+         },
+       },
+     ]);
+
+     const output = await downloadFileToCacheDir({
+       repo: DUMMY_REPO,
+       path: "/README.md",
+       fetch: fetchMock,
+     });
+
+     // should have checked for the blob
+     expect(lstat).toHaveBeenCalled();
+     expect(vi.mocked(lstat).mock.calls[0][0]).toBe(expectedBlob);
+
+     // symlink should have been created
+     expect(createSymlink).toHaveBeenCalledOnce();
+     // no download done
+     expect(fetchMock).not.toHaveBeenCalled();
+
+     expect(output).toBe(expectPointer);
+   });
+
+   test("expect resolve value to be the pointer path of the downloaded file", async () => {
+     // <cache>/<repo>/<revision>/snapshots/README.md
+     const expectPointer = _getSnapshotFile({
+       repo: DUMMY_REPO,
+       path: "/README.md",
+       revision: "dummy-commit-hash",
+     });
+     // <cache>/<repo>/blobs/<etag>
+     const expectedBlob = _getBlobFile({
+       repo: DUMMY_REPO,
+       etag: DUMMY_ETAG,
+     });
+
+     vi.mocked(pathsInfo).mockResolvedValue([
+       {
+         oid: DUMMY_ETAG,
+         size: 55,
+         path: "README.md",
+         type: "file",
+         lastCommit: {
+           date: new Date(),
+           id: "dummy-commit-hash",
+           title: "Commit msg",
+         },
+       },
+     ]);
+
+     // eslint-disable-next-line @typescript-eslint/no-explicit-any
+     vi.mocked(createWriteStream).mockReturnValue(async function* () {} as any);
+
+     const output = await downloadFileToCacheDir({
+       repo: DUMMY_REPO,
+       path: "/README.md",
+       fetch: fetchMock,
+     });
+
+     // expect the blobs and snapshots folders to have been created
+     expect(vi.mocked(mkdir).mock.calls[0][0]).toBe(dirname(expectedBlob));
+     expect(vi.mocked(mkdir).mock.calls[1][0]).toBe(dirname(expectPointer));
+
+     expect(output).toBe(expectPointer);
+   });
+
+   test("should write fetch response to blob", async () => {
+     // <cache>/<repo>/<revision>/snapshots/README.md
+     const expectPointer = _getSnapshotFile({
+       repo: DUMMY_REPO,
+       path: "/README.md",
+       revision: "dummy-commit-hash",
+     });
+     // <cache>/<repo>/blobs/<etag>
+     const expectedBlob = _getBlobFile({
+       repo: DUMMY_REPO,
+       etag: DUMMY_ETAG,
+     });
+
+     // mock pathsInfo resolved content
+     vi.mocked(pathsInfo).mockResolvedValue([
+       {
+         oid: DUMMY_ETAG,
+         size: 55,
+         path: "README.md",
+         type: "file",
+         lastCommit: {
+           date: new Date(),
+           id: "dummy-commit-hash",
+           title: "Commit msg",
+         },
+       },
+     ]);
+
+     // eslint-disable-next-line @typescript-eslint/no-explicit-any
+     vi.mocked(createWriteStream).mockReturnValue(async function* () {} as any);
+
+     await downloadFileToCacheDir({
+       repo: DUMMY_REPO,
+       path: "/README.md",
+       fetch: fetchMock,
+     });
+
+     const incomplete = `${expectedBlob}.incomplete`;
+     // 1. should write fetch#response#body to the incomplete file
+     expect(createWriteStream).toHaveBeenCalledWith(incomplete);
+     // 2. should rename the incomplete file to the expected blob name
+     expect(rename).toHaveBeenCalledWith(incomplete, expectedBlob);
+     // 3. should create a symlink pointing to the blob
+     expect(createSymlink).toHaveBeenCalledWith({ sourcePath: expectedBlob, finalPath: expectPointer });
+   });
+ });
lib/download-file-to-cache-dir.ts ADDED
@@ -0,0 +1,138 @@
+ import { getHFHubCachePath, getRepoFolderName } from "./cache-management";
+ import { dirname, join } from "node:path";
+ import { rename, lstat, mkdir, stat } from "node:fs/promises";
+ import type { CommitInfo, PathInfo } from "./paths-info";
+ import { pathsInfo } from "./paths-info";
+ import type { CredentialsParams, RepoDesignation } from "../types/public";
+ import { toRepoId } from "../utils/toRepoId";
+ import { downloadFile } from "./download-file";
+ import { createSymlink } from "../utils/symlink";
+ import { Readable } from "node:stream";
+ import type { ReadableStream } from "node:stream/web";
+ import { pipeline } from "node:stream/promises";
+ import { createWriteStream } from "node:fs";
+
+ export const REGEX_COMMIT_HASH: RegExp = new RegExp("^[0-9a-f]{40}$");
+
+ function getFilePointer(storageFolder: string, revision: string, relativeFilename: string): string {
+   const snapshotPath = join(storageFolder, "snapshots");
+   return join(snapshotPath, revision, relativeFilename);
+ }
+
+ /**
+  * Handy method to check that a file exists, or that the target of a symlink exists
+  * @param path
+  * @param followSymlinks
+  */
+ async function exists(path: string, followSymlinks?: boolean): Promise<boolean> {
+   try {
+     if (followSymlinks) {
+       await stat(path);
+     } else {
+       await lstat(path);
+     }
+     return true;
+   } catch {
+     return false;
+   }
+ }
+
+ /**
+  * Download a given file if it's not already present in the local cache.
+  * @param params
+  * @return the symlink to the blob object
+  */
+ export async function downloadFileToCacheDir(
+   params: {
+     repo: RepoDesignation;
+     path: string;
+     /**
+      * If true, will download the raw git file.
+      *
+      * For example, when calling on a file stored with Git LFS, the pointer file will be downloaded instead.
+      */
+     raw?: boolean;
+     /**
+      * An optional Git revision id which can be a branch name, a tag, or a commit hash.
+      *
+      * @default "main"
+      */
+     revision?: string;
+     hubUrl?: string;
+     cacheDir?: string;
+     /**
+      * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+      */
+     fetch?: typeof fetch;
+   } & Partial<CredentialsParams>
+ ): Promise<string> {
+   // get the revision provided or default to main
+   const revision = params.revision ?? "main";
+   const cacheDir = params.cacheDir ?? getHFHubCachePath();
+   // get repo id
+   const repoId = toRepoId(params.repo);
+   // get storage folder
+   const storageFolder = join(cacheDir, getRepoFolderName(repoId));
+
+   let commitHash: string | undefined;
+
+   // if the user provides a commitHash as revision, and they already have the file on disk, shortcut everything.
+   if (REGEX_COMMIT_HASH.test(revision)) {
+     commitHash = revision;
+     const pointerPath = getFilePointer(storageFolder, revision, params.path);
+     if (await exists(pointerPath, true)) return pointerPath;
+   }
+
+   const pathsInformation: (PathInfo & { lastCommit: CommitInfo })[] = await pathsInfo({
+     ...params,
+     paths: [params.path],
+     revision: revision,
+     expand: true,
+   });
+   if (!pathsInformation || pathsInformation.length !== 1) throw new Error(`cannot get path info for ${params.path}`);
+
+   let etag: string;
+   if (pathsInformation[0].lfs) {
+     etag = pathsInformation[0].lfs.oid; // get the oid of the file pointed to by LFS
+   } else {
+     etag = pathsInformation[0].oid; // get the repo file if not an LFS pointer
+   }
+
+   const pointerPath = getFilePointer(storageFolder, commitHash ?? pathsInformation[0].lastCommit.id, params.path);
+   const blobPath = join(storageFolder, "blobs", etag);
+
+   // if we have the pointer file, we can shortcut the download
+   if (await exists(pointerPath, true)) return pointerPath;
+
+   // mkdir the parent directories of the blob and pointer paths
+   await mkdir(dirname(blobPath), { recursive: true });
+   await mkdir(dirname(pointerPath), { recursive: true });
+
+   // We might already have the blob but not the pointer
+   // shortcut the download if needed
+   if (await exists(blobPath)) {
+     // create a symlink in the snapshot folder to the blob object
+     await createSymlink({ sourcePath: blobPath, finalPath: pointerPath });
+     return pointerPath;
+   }
+
+   const incomplete = `${blobPath}.incomplete`;
+   console.debug(`Downloading ${params.path} to ${incomplete}`);
+
+   const blob: Blob | null = await downloadFile({
+     ...params,
+     revision: commitHash,
+   });
+
+   if (!blob) {
+     throw new Error(`invalid response for file ${params.path}`);
+   }
+
+   await pipeline(Readable.fromWeb(blob.stream() as ReadableStream), createWriteStream(incomplete));
+
+   // rename the .incomplete file to the expected blob name
+   await rename(incomplete, blobPath);
+   // create a symlink in the snapshot folder to the blob object
+   await createSymlink({ sourcePath: blobPath, finalPath: pointerPath });
+   return pointerPath;
+ }
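A minimal usage sketch of `downloadFileToCacheDir` in Node.js, assuming the `@huggingface/hub` package; the resolved path is a symlink inside the snapshot folder pointing at the cached blob, under the default HF hub cache directory unless `cacheDir` is given:

import { downloadFileToCacheDir } from "@huggingface/hub";

const localPath = await downloadFileToCacheDir({
  repo: "openai-community/gpt2",
  path: "config.json",
});
console.log(localPath); // symlink to the cached blob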
lib/download-file.spec.ts ADDED
@@ -0,0 +1,82 @@
+ import { expect, test, describe, assert } from "vitest";
+ import { downloadFile } from "./download-file";
+ import { deleteRepo } from "./delete-repo";
+ import { createRepo } from "./create-repo";
+ import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts";
+ import { insecureRandomString } from "../utils/insecureRandomString";
+
+ describe("downloadFile", () => {
+   test("should download regular file", async () => {
+     const blob = await downloadFile({
+       repo: {
+         type: "model",
+         name: "openai-community/gpt2",
+       },
+       path: "README.md",
+     });
+
+     const text = await blob?.slice(0, 1000).text();
+     assert(
+       text?.includes(`---
+ language: en
+ tags:
+ - exbert
+
+ license: mit
+ ---
+
+
+ # GPT-2
+
+ Test the whole generation capabilities here: https://transformer.huggingface.co/doc/gpt2-large`)
+     );
+   });
+
+   test("should download xet file", async () => {
+     const blob = await downloadFile({
+       repo: {
+         type: "model",
+         name: "celinah/xet-experiments",
+       },
+       path: "large_text.txt",
+     });
+
+     const text = await blob?.slice(0, 100).text();
+     expect(text).toMatch("this is a text file.".repeat(10).slice(0, 100));
+   });
+
+   test("should download private file", async () => {
+     const repoName = `datasets/${TEST_USER}/TEST-${insecureRandomString()}`;
+
+     const result = await createRepo({
+       accessToken: TEST_ACCESS_TOKEN,
+       hubUrl: TEST_HUB_URL,
+       private: true,
+       repo: repoName,
+       files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
+     });
+
+     assert.deepStrictEqual(result, {
+       repoUrl: `${TEST_HUB_URL}/${repoName}`,
+     });
+
+     try {
+       const blob = await downloadFile({
+         repo: repoName,
+         path: ".gitattributes",
+         hubUrl: TEST_HUB_URL,
+         accessToken: TEST_ACCESS_TOKEN,
+       });
+
+       assert(blob, "File should be found");
+
+       const text = await blob?.text();
+       assert.strictEqual(text, "*.html filter=lfs diff=lfs merge=lfs -text");
+     } finally {
+       await deleteRepo({
+         repo: repoName,
+         hubUrl: TEST_HUB_URL,
+         accessToken: TEST_ACCESS_TOKEN,
+       });
+     }
+   });
+ });
lib/download-file.ts ADDED
@@ -0,0 +1,77 @@
+ import type { CredentialsParams, RepoDesignation } from "../types/public";
+ import { checkCredentials } from "../utils/checkCredentials";
+ import { WebBlob } from "../utils/WebBlob";
+ import { XetBlob } from "../utils/XetBlob";
+ import type { FileDownloadInfoOutput } from "./file-download-info";
+ import { fileDownloadInfo } from "./file-download-info";
+
+ /**
+  * @returns null when the file doesn't exist
+  */
+ export async function downloadFile(
+   params: {
+     repo: RepoDesignation;
+     path: string;
+     /**
+      * If true, will download the raw git file.
+      *
+      * For example, when calling on a file stored with Git LFS, the pointer file will be downloaded instead.
+      */
+     raw?: boolean;
+     /**
+      * An optional Git revision id which can be a branch name, a tag, or a commit hash.
+      *
+      * @default "main"
+      */
+     revision?: string;
+     hubUrl?: string;
+     /**
+      * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+      */
+     fetch?: typeof fetch;
+     /**
+      * Whether to use the xet protocol to download the file (if applicable).
+      *
+      * Currently there's experimental support for it, so it's not enabled by default.
+      *
+      * It will be enabled automatically in a future minor version.
+      *
+      * @default false
+      */
+     xet?: boolean;
+     /**
+      * Can save an HTTP request if provided
+      */
+     downloadInfo?: FileDownloadInfoOutput;
+   } & Partial<CredentialsParams>
+ ): Promise<Blob | null> {
+   const accessToken = checkCredentials(params);
+
+   const info =
+     params.downloadInfo ??
+     (await fileDownloadInfo({
+       accessToken,
+       repo: params.repo,
+       path: params.path,
+       revision: params.revision,
+       hubUrl: params.hubUrl,
+       fetch: params.fetch,
+       raw: params.raw,
+     }));
+
+   if (!info) {
+     return null;
+   }
+
+   if (info.xet && params.xet) {
+     return new XetBlob({
+       refreshUrl: info.xet.refreshUrl.href,
+       reconstructionUrl: info.xet.reconstructionUrl.href,
+       fetch: params.fetch,
+       accessToken,
+       size: info.size,
+     });
+   }
+
+   return new WebBlob(new URL(info.url), 0, info.size, "", true, params.fetch ?? fetch, accessToken);
+ }
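A minimal usage sketch of `downloadFile`, assuming the `@huggingface/hub` package; the returned `Blob` is lazy, so bytes are only fetched when read, and `null` means the file does not exist:

import { downloadFile } from "@huggingface/hub";

const blob = await downloadFile({ repo: "openai-community/gpt2", path: "README.md" });
if (blob) {
  console.log(await blob.text());
}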
lib/file-download-info.spec.ts ADDED
@@ -0,0 +1,59 @@
+ import { assert, it, describe } from "vitest";
+ import { fileDownloadInfo } from "./file-download-info";
+
+ describe("fileDownloadInfo", () => {
+   it("should fetch LFS file info", async () => {
+     const info = await fileDownloadInfo({
+       repo: {
+         name: "bert-base-uncased",
+         type: "model",
+       },
+       path: "tf_model.h5",
+       revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
+     });
+
+     assert.strictEqual(info?.size, 536063208);
+     assert.strictEqual(info?.etag, '"a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2"');
+   });
+
+   it("should fetch raw LFS pointer info", async () => {
+     const info = await fileDownloadInfo({
+       repo: {
+         name: "bert-base-uncased",
+         type: "model",
+       },
+       path: "tf_model.h5",
+       revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
+       raw: true,
+     });
+
+     assert.strictEqual(info?.size, 134);
+     assert.strictEqual(info?.etag, '"9eb98c817f04b051b3bcca591bcd4e03cec88018"');
+   });
+
+   it("should fetch non-LFS file info", async () => {
+     const info = await fileDownloadInfo({
+       repo: {
+         name: "bert-base-uncased",
+         type: "model",
+       },
+       path: "tokenizer_config.json",
+       revision: "1a7dd4986e3dab699c24ca19b2afd0f5e1a80f37",
+     });
+
+     assert.strictEqual(info?.size, 28);
+     assert.strictEqual(info?.etag, '"a661b1a138dac6dc5590367402d100765010ffd6"');
+   });
+
+   it("should fetch xet file info", async () => {
+     const info = await fileDownloadInfo({
+       repo: {
+         type: "model",
+         name: "celinah/xet-experiments",
+       },
+       path: "large_text.txt",
+     });
+     assert.strictEqual(info?.size, 62914580);
+     assert.strictEqual(info?.etag, '"c27f98578d9363b27db0bc1cbd9c692f8e6e90ae98c38cee7bc0a88829debd17"');
+   });
+ });
lib/file-download-info.ts ADDED
@@ -0,0 +1,151 @@
+ import { HUB_URL } from "../consts";
+ import { createApiError, InvalidApiResponseFormatError } from "../error";
+ import type { CredentialsParams, RepoDesignation } from "../types/public";
+ import { checkCredentials } from "../utils/checkCredentials";
+ import { parseLinkHeader } from "../utils/parseLinkHeader";
+ import { toRepoId } from "../utils/toRepoId";
+
+ export interface XetFileInfo {
+   hash: string;
+   refreshUrl: URL;
+   /**
+    * Can be directly used instead of the hash.
+    */
+   reconstructionUrl: URL;
+ }
+
+ export interface FileDownloadInfoOutput {
+   size: number;
+   etag: string;
+   xet?: XetFileInfo;
+   // URL to fetch (with the access token if private file)
+   url: string;
+ }
+
+ /**
+  * @returns null when the file doesn't exist
+  */
+ export async function fileDownloadInfo(
+   params: {
+     repo: RepoDesignation;
+     path: string;
+     revision?: string;
+     hubUrl?: string;
+     /**
+      * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+      */
+     fetch?: typeof fetch;
+     /**
+      * To get the raw pointer file behind an LFS file
+      */
+     raw?: boolean;
+     /**
+      * To avoid the content-disposition header in the returned URL for LFS files
+      *
+      * So that on browsers you can use the URL in an iframe, for example
+      */
+     noContentDisposition?: boolean;
+   } & Partial<CredentialsParams>
+ ): Promise<FileDownloadInfoOutput | null> {
+   const accessToken = checkCredentials(params);
+   const repoId = toRepoId(params.repo);
+
+   const hubUrl = params.hubUrl ?? HUB_URL;
+   const url =
+     `${hubUrl}/${repoId.type === "model" ? "" : `${repoId.type}s/`}${repoId.name}/${
+       params.raw ? "raw" : "resolve"
+     }/${encodeURIComponent(params.revision ?? "main")}/${params.path}` +
+     (params.noContentDisposition ? "?noContentDisposition=1" : "");
+
+   const resp = await (params.fetch ?? fetch)(url, {
+     method: "GET",
+     headers: {
+       ...(accessToken && {
+         Authorization: `Bearer ${accessToken}`,
+       }),
+       Range: "bytes=0-0",
+       Accept: "application/vnd.xet-fileinfo+json, */*",
+     },
+   });
+
+   if (resp.status === 404 && resp.headers.get("X-Error-Code") === "EntryNotFound") {
+     return null;
+   }
+
+   if (!resp.ok) {
+     throw await createApiError(resp);
+   }
+
+   let size: number | undefined;
+   let xetInfo: XetFileInfo | undefined;
+
+   if (resp.headers.get("Content-Type")?.includes("application/vnd.xet-fileinfo+json")) {
+     size = parseInt(resp.headers.get("X-Linked-Size") ?? "invalid");
+     if (isNaN(size)) {
+       throw new InvalidApiResponseFormatError("Invalid file size received in X-Linked-Size header");
+     }
+
+     const hash = resp.headers.get("X-Xet-Hash");
+     const links = parseLinkHeader(resp.headers.get("Link") ?? "");
+
+     const reconstructionUrl = (() => {
+       try {
+         return new URL(links["xet-reconstruction-info"]);
+       } catch {
+         return null;
+       }
+     })();
+     const refreshUrl = (() => {
+       try {
+         return new URL(links["xet-auth"]);
+       } catch {
+         return null;
+       }
+     })();
+
+     if (!hash) {
+       throw new InvalidApiResponseFormatError("No hash received in X-Xet-Hash header");
+     }
+
+     if (!reconstructionUrl || !refreshUrl) {
+       throw new InvalidApiResponseFormatError("No xet-reconstruction-info or xet-auth link header");
+     }
+     xetInfo = {
+       hash,
+       refreshUrl,
+       reconstructionUrl,
+     };
+   }
+
+   if (size === undefined || isNaN(size)) {
+     const contentRangeHeader = resp.headers.get("content-range");
+
+     if (!contentRangeHeader) {
+       throw new InvalidApiResponseFormatError("Expected size information");
+     }
+
+     const [, parsedSize] = contentRangeHeader.split("/");
+     size = parseInt(parsedSize);
+
+     if (isNaN(size)) {
+       throw new InvalidApiResponseFormatError("Invalid file size received");
+     }
+   }
+
+   const etag = resp.headers.get("X-Linked-ETag") ?? resp.headers.get("ETag") ?? undefined;
+
+   if (!etag) {
+     throw new InvalidApiResponseFormatError("Expected ETag");
+   }
+
+   return {
+     etag,
+     size,
+     xet: xetInfo,
+     // Cannot use resp.url in case it's a S3 url and the user adds an Authorization header to it.
+     url:
+       resp.url &&
+       (new URL(resp.url).origin === new URL(hubUrl).origin || resp.headers.get("X-Cache")?.endsWith(" cloudfront"))
+         ? resp.url
+         : url,
+   };
+ }
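A minimal usage sketch of `fileDownloadInfo`, assuming the `@huggingface/hub` package; it issues a single ranged GET and returns metadata without downloading the file body:

import { fileDownloadInfo } from "@huggingface/hub";

const info = await fileDownloadInfo({
  repo: { type: "model", name: "bert-base-uncased" },
  path: "tf_model.h5",
});
console.log(info?.size, info?.etag);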
lib/file-exists.spec.ts ADDED
@@ -0,0 +1,30 @@
+ import { assert, it, describe } from "vitest";
+ import { fileExists } from "./file-exists";
+
+ describe("fileExists", () => {
+   it("should return true for file that exists", async () => {
+     const info = await fileExists({
+       repo: {
+         name: "bert-base-uncased",
+         type: "model",
+       },
+       path: "tf_model.h5",
+       revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
+     });
+
+     assert(info, "file should exist");
+   });
+
+   it("should return false for file that does not exist", async () => {
+     const info = await fileExists({
+       repo: {
+         name: "bert-base-uncased",
+         type: "model",
+       },
+       path: "tf_model.h5dadazdzazd",
+       revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
+     });
+
+     assert(!info, "file should not exist");
+   });
+ });
lib/file-exists.ts ADDED
@@ -0,0 +1,41 @@
+ import { HUB_URL } from "../consts";
+ import { createApiError } from "../error";
+ import type { CredentialsParams, RepoDesignation } from "../types/public";
+ import { checkCredentials } from "../utils/checkCredentials";
+ import { toRepoId } from "../utils/toRepoId";
+
+ export async function fileExists(
+   params: {
+     repo: RepoDesignation;
+     path: string;
+     revision?: string;
+     hubUrl?: string;
+     /**
+      * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+      */
+     fetch?: typeof fetch;
+   } & Partial<CredentialsParams>
+ ): Promise<boolean> {
+   const accessToken = checkCredentials(params);
+   const repoId = toRepoId(params.repo);
+
+   const hubUrl = params.hubUrl ?? HUB_URL;
+   const url = `${hubUrl}/${repoId.type === "model" ? "" : `${repoId.type}s/`}${repoId.name}/raw/${encodeURIComponent(
+     params.revision ?? "main"
+   )}/${params.path}`;
+
+   const resp = await (params.fetch ?? fetch)(url, {
+     method: "HEAD",
+     headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
+   });
+
+   if (resp.status === 404) {
+     return false;
+   }
+
+   if (!resp.ok) {
+     throw await createApiError(resp);
+   }
+
+   return true;
+ }
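A minimal usage sketch of `fileExists`, assuming the `@huggingface/hub` package; it performs a cheap HEAD request and never downloads the file:

import { fileExists } from "@huggingface/hub";

const exists = await fileExists({
  repo: { type: "model", name: "bert-base-uncased" },
  path: "tf_model.h5",
});
console.log(exists); // true or false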
lib/index.ts ADDED
@@ -0,0 +1,32 @@
+ export * from "./cache-management";
+ export * from "./check-repo-access";
+ export * from "./commit";
+ export * from "./count-commits";
+ export * from "./create-repo";
+ export * from "./create-branch";
+ export * from "./dataset-info";
+ export * from "./delete-branch";
+ export * from "./delete-file";
+ export * from "./delete-files";
+ export * from "./delete-repo";
+ export * from "./download-file";
+ export * from "./download-file-to-cache-dir";
+ export * from "./file-download-info";
+ export * from "./file-exists";
+ export * from "./list-commits";
+ export * from "./list-datasets";
+ export * from "./list-files";
+ export * from "./list-models";
+ export * from "./list-spaces";
+ export * from "./model-info";
+ export * from "./oauth-handle-redirect";
+ export * from "./oauth-login-url";
+ export * from "./parse-safetensors-metadata";
+ export * from "./paths-info";
+ export * from "./repo-exists";
+ export * from "./snapshot-download";
+ export * from "./space-info";
+ export * from "./upload-file";
+ export * from "./upload-files";
+ export * from "./upload-files-with-progress";
+ export * from "./who-am-i";
lib/list-commits.spec.ts ADDED
@@ -0,0 +1,117 @@
+ import { assert, it, describe } from "vitest";
+ import type { CommitData } from "./list-commits";
+ import { listCommits } from "./list-commits";
+
+ describe("listCommits", () => {
+   it("should fetch paginated commits from the repo", async () => {
+     const commits: CommitData[] = [];
+     for await (const commit of listCommits({
+       repo: {
+         name: "openai-community/gpt2",
+         type: "model",
+       },
+       revision: "607a30d783dfa663caf39e06633721c8d4cfcd7e",
+       batchSize: 5,
+     })) {
+       commits.push(commit);
+     }
+
+     assert.equal(commits.length, 26);
+     assert.deepEqual(commits.slice(0, 6), [
+       {
+         oid: "607a30d783dfa663caf39e06633721c8d4cfcd7e",
+         title: "Adds the tokenizer configuration file (#80)",
+         message: "\n\n\n- Adds tokenizer_config.json file (db6d57930088fb63e52c010bd9ac77c955ac55e7)\n\n",
+         authors: [
+           {
+             username: "lysandre",
+             avatarUrl:
+               "https://cdn-avatars.huggingface.co/v1/production/uploads/5e3aec01f55e2b62848a5217/PMKS0NNB4MJQlTSFzh918.jpeg",
+           },
+         ],
+         date: new Date("2024-02-19T10:57:45.000Z"),
+       },
+       {
+         oid: "11c5a3d5811f50298f278a704980280950aedb10",
+         title: "Adding ONNX file of this model (#60)",
+         message: "\n\n\n- Adding ONNX file of this model (9411f419c589519e1a46c94ac7789ea20fd7c322)\n\n",
+         authors: [
+           {
+             username: "fxmarty",
+             avatarUrl:
+               "https://cdn-avatars.huggingface.co/v1/production/uploads/1651743336129-624c60cba8ec93a7ac188b56.png",
+           },
+         ],
+         date: new Date("2023-06-30T02:19:43.000Z"),
+       },
+       {
+         oid: "e7da7f221d5bf496a48136c0cd264e630fe9fcc8",
+         title: "Update generation_config.json",
+         message: "",
+         authors: [
+           {
+             username: "joaogante",
+             avatarUrl: "https://cdn-avatars.huggingface.co/v1/production/uploads/1641203017724-noauth.png",
+           },
+         ],
+         date: new Date("2022-12-16T15:44:21.000Z"),
+       },
+       {
+         oid: "f27b190eeac4c2302d24068eabf5e9d6044389ae",
+         title: "Add note that this is the smallest version of the model (#18)",
+         message:
+           "\n\n\n- Add note that this is the smallest version of the model (611838ef095a5bb35bf2027d05e1194b7c9d37ac)\n\n\nCo-authored-by: helen <mathemakitten@users.noreply.huggingface.co>\n",
+         authors: [
+           {
+             username: "sgugger",
+             avatarUrl:
+               "https://cdn-avatars.huggingface.co/v1/production/uploads/1593126474392-5ef50182b71947201082a4e5.jpeg",
+           },
+           {
+             username: "mathemakitten",
+             avatarUrl:
+               "https://cdn-avatars.huggingface.co/v1/production/uploads/1658248499901-6079afe2d2cd8c150e6ae05e.jpeg",
+           },
+         ],
+         date: new Date("2022-11-23T12:55:26.000Z"),
+       },
+       {
+         oid: "0dd7bcc7a64e4350d8859c9a2813132fbf6ae591",
+         title: "Our very first generation_config.json (#17)",
+         message:
+           "\n\n\n- Our very first generation_config.json (671851b7e9d56ef062890732065d7bd5f4628bd6)\n\n\nCo-authored-by: Joao Gante <joaogante@users.noreply.huggingface.co>\n",
+         authors: [
+           {
+             username: "sgugger",
+             avatarUrl:
+               "https://cdn-avatars.huggingface.co/v1/production/uploads/1593126474392-5ef50182b71947201082a4e5.jpeg",
+           },
+           {
+             username: "joaogante",
+             avatarUrl: "https://cdn-avatars.huggingface.co/v1/production/uploads/1641203017724-noauth.png",
+           },
+         ],
+         date: new Date("2022-11-18T18:19:30.000Z"),
+       },
+       {
+         oid: "75e09b43581151bd1d9ef6700faa605df408979f",
+         title: "Upload model.safetensors with huggingface_hub (#12)",
+         message:
+           "\n\n\n- Upload model.safetensors with huggingface_hub (ba2f794b2e4ea09ef932a6628fa0815dfaf09661)\n\n\nCo-authored-by: Nicolas Patry <Narsil@users.noreply.huggingface.co>\n",
+         authors: [
+           {
+             username: "julien-c",
+             avatarUrl:
+               "https://cdn-avatars.huggingface.co/v1/production/uploads/5dd96eb166059660ed1ee413/NQtzmrDdbG0H8qkZvRyGk.jpeg",
+           },
+           {
+             username: "Narsil",
+             avatarUrl:
+               "https://cdn-avatars.huggingface.co/v1/production/uploads/1608285816082-5e2967b819407e3277369b95.png",
+           },
+         ],
+         date: new Date("2022-10-20T09:34:54.000Z"),
+       },
+     ]);
+   });
+ });
lib/list-commits.ts ADDED
@@ -0,0 +1,70 @@
+import { HUB_URL } from "../consts";
+import { createApiError } from "../error";
+import type { ApiCommitData } from "../types/api/api-commit";
+import type { CredentialsParams, RepoDesignation } from "../types/public";
+import { checkCredentials } from "../utils/checkCredentials";
+import { parseLinkHeader } from "../utils/parseLinkHeader";
+import { toRepoId } from "../utils/toRepoId";
+
+export interface CommitData {
+  oid: string;
+  title: string;
+  message: string;
+  authors: Array<{ username: string; avatarUrl: string }>;
+  date: Date;
+}
+
+export async function* listCommits(
+  params: {
+    repo: RepoDesignation;
+    /**
+     * Revision to list commits from. Defaults to the default branch.
+     */
+    revision?: string;
+    hubUrl?: string;
+    /**
+     * Number of commits to fetch from the Hub per HTTP call. Defaults to 100. Can be set up to 1000.
+     */
+    batchSize?: number;
+    /**
+     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+     */
+    fetch?: typeof fetch;
+  } & Partial<CredentialsParams>
+): AsyncGenerator<CommitData> {
+  const accessToken = checkCredentials(params);
+  const repoId = toRepoId(params.repo);
+
+  // Could upgrade to 1000 commits per page
+  let url: string | undefined = `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commits/${
+    params.revision ?? "main"
+  }?limit=${params.batchSize ?? 100}`;
+
+  while (url) {
+    const res: Response = await (params.fetch ?? fetch)(url, {
+      headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
+    });
+
+    if (!res.ok) {
+      throw await createApiError(res);
+    }
+
+    const resJson: ApiCommitData[] = await res.json();
+    for (const commit of resJson) {
+      yield {
+        oid: commit.id,
+        title: commit.title,
+        message: commit.message,
+        authors: commit.authors.map((author) => ({
+          username: author.user,
+          avatarUrl: author.avatar,
+        })),
+        date: new Date(commit.date),
+      };
+    }
+
+    const linkHeader = res.headers.get("Link");
+
+    url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
+  }
+}
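For context, a minimal consumption sketch: `listCommits` is an async generator that follows the `Link` response header across pages, so callers just iterate. This assumes the package is imported as `@huggingface/hub` (as in the library's own OAuth docs); the repo name is only an example.

```ts
import { listCommits } from "@huggingface/hub";

// Iterate lazily; each HTTP call fetches `batchSize` commits (default 100).
for await (const commit of listCommits({
  repo: { type: "model", name: "openai-community/gpt2" },
  batchSize: 100,
})) {
  console.log(commit.oid, commit.title, commit.date.toISOString());
}
```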
lib/list-datasets.spec.ts ADDED
@@ -0,0 +1,47 @@
+import { describe, expect, it } from "vitest";
+import type { DatasetEntry } from "./list-datasets";
+import { listDatasets } from "./list-datasets";
+
+describe("listDatasets", () => {
+  it("should list datasets from hf-doc-build", async () => {
+    const results: DatasetEntry[] = [];
+
+    for await (const entry of listDatasets({ search: { owner: "hf-doc-build" } })) {
+      if (entry.name === "hf-doc-build/doc-build-dev-test") {
+        continue;
+      }
+      if (typeof entry.downloads === "number") {
+        entry.downloads = 0;
+      }
+      if (typeof entry.likes === "number") {
+        entry.likes = 0;
+      }
+      if (entry.updatedAt instanceof Date && !isNaN(entry.updatedAt.getTime())) {
+        entry.updatedAt = new Date(0);
+      }
+
+      results.push(entry);
+    }
+
+    expect(results).deep.equal([
+      {
+        id: "6356b19985da6f13863228bd",
+        name: "hf-doc-build/doc-build",
+        private: false,
+        gated: false,
+        downloads: 0,
+        likes: 0,
+        updatedAt: new Date(0),
+      },
+      {
+        id: "636a1b69f2f9ec4289c4c19e",
+        name: "hf-doc-build/doc-build-dev",
+        gated: false,
+        private: false,
+        downloads: 0,
+        likes: 0,
+        updatedAt: new Date(0),
+      },
+    ]);
+  });
+});
lib/list-datasets.ts ADDED
@@ -0,0 +1,121 @@
+import { HUB_URL } from "../consts";
+import { createApiError } from "../error";
+import type { ApiDatasetInfo } from "../types/api/api-dataset";
+import type { CredentialsParams } from "../types/public";
+import { checkCredentials } from "../utils/checkCredentials";
+import { parseLinkHeader } from "../utils/parseLinkHeader";
+import { pick } from "../utils/pick";
+
+export const DATASET_EXPAND_KEYS = [
+  "private",
+  "downloads",
+  "gated",
+  "likes",
+  "lastModified",
+] as const satisfies readonly (keyof ApiDatasetInfo)[];
+
+export const DATASET_EXPANDABLE_KEYS = [
+  "author",
+  "cardData",
+  "citation",
+  "createdAt",
+  "disabled",
+  "description",
+  "downloads",
+  "downloadsAllTime",
+  "gated",
+  "gitalyUid",
+  "lastModified",
+  "likes",
+  "paperswithcode_id",
+  "private",
+  // "siblings",
+  "sha",
+  "tags",
+] as const satisfies readonly (keyof ApiDatasetInfo)[];
+
+export interface DatasetEntry {
+  id: string;
+  name: string;
+  private: boolean;
+  downloads: number;
+  gated: false | "auto" | "manual";
+  likes: number;
+  updatedAt: Date;
+}
+
+export async function* listDatasets<
+  const T extends Exclude<(typeof DATASET_EXPANDABLE_KEYS)[number], (typeof DATASET_EXPAND_KEYS)[number]> = never,
+>(
+  params?: {
+    search?: {
+      /**
+       * Will search in the dataset name for matches
+       */
+      query?: string;
+      owner?: string;
+      tags?: string[];
+    };
+    hubUrl?: string;
+    additionalFields?: T[];
+    /**
+     * Set to limit the number of datasets returned.
+     */
+    limit?: number;
+    /**
+     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+     */
+    fetch?: typeof fetch;
+  } & Partial<CredentialsParams>
+): AsyncGenerator<DatasetEntry & Pick<ApiDatasetInfo, T>> {
+  const accessToken = params && checkCredentials(params);
+  let totalToFetch = params?.limit ?? Infinity;
+  const search = new URLSearchParams([
+    ...Object.entries({
+      limit: String(Math.min(totalToFetch, 500)),
+      ...(params?.search?.owner ? { author: params.search.owner } : undefined),
+      ...(params?.search?.query ? { search: params.search.query } : undefined),
+    }),
+    ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
+    ...DATASET_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
+    ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
+  ]).toString();
+  let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/datasets` + (search ? "?" + search : "");
+
+  while (url) {
+    const res: Response = await (params?.fetch ?? fetch)(url, {
+      headers: {
+        accept: "application/json",
+        ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
+      },
+    });
+
+    if (!res.ok) {
+      throw await createApiError(res);
+    }
+
+    const items: ApiDatasetInfo[] = await res.json();
+
+    for (const item of items) {
+      yield {
+        ...(params?.additionalFields && pick(item, params.additionalFields)),
+        id: item._id,
+        name: item.id,
+        private: item.private,
+        downloads: item.downloads,
+        likes: item.likes,
+        gated: item.gated,
+        updatedAt: new Date(item.lastModified),
+      } as DatasetEntry & Pick<ApiDatasetInfo, T>;
+      totalToFetch--;
+      if (totalToFetch <= 0) {
+        return;
+      }
+    }
+
+    const linkHeader = res.headers.get("Link");
+
+    url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
+    // Could update limit in url to fetch fewer items if not all items of the next page are needed.
+  }
+}
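A minimal usage sketch of the expand mechanism above: `additionalFields` must be drawn from `DATASET_EXPANDABLE_KEYS`, and each extra key is sent as an `expand` query parameter. Assumes the package is imported as `@huggingface/hub`; the owner is only an example.

```ts
import { listDatasets } from "@huggingface/hub";

// Fetch up to 10 datasets from one owner, with the optional "description" field expanded.
for await (const dataset of listDatasets({
  search: { owner: "hf-doc-build" },
  additionalFields: ["description"],
  limit: 10,
})) {
  console.log(dataset.name, dataset.downloads, dataset.description);
}
```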
lib/list-files.spec.ts ADDED
@@ -0,0 +1,173 @@
+import { assert, it, describe } from "vitest";
+import type { ListFileEntry } from "./list-files";
+import { listFiles } from "./list-files";
+
+describe("listFiles", () => {
+  it("should fetch the list of files from the repo", async () => {
+    const cursor = listFiles({
+      repo: {
+        name: "bert-base-uncased",
+        type: "model",
+      },
+      revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
+    });
+
+    const files: ListFileEntry[] = [];
+
+    for await (const entry of cursor) {
+      files.push(entry);
+    }
+
+    assert.deepStrictEqual(files, [
+      {
+        oid: "dc08351d4dc0732d9c8af04070ced089b201ce2f",
+        path: ".gitattributes",
+        size: 345,
+        type: "file",
+      },
+      {
+        oid: "fca794a5f07ff8f963fe8b61e3694b0fb7f955df",
+        path: "config.json",
+        size: 313,
+        type: "file",
+      },
+      {
+        lfs: {
+          oid: "097417381d6c7230bd9e3557456d726de6e83245ec8b24f529f60198a67b203a",
+          size: 440473133,
+          pointerSize: 134,
+        },
+        xetHash: "2d8408d3a894d02517d04956e2f7546ff08362594072f3527ce144b5212a3296",
+        oid: "ba5d19791be1dd7992e33bd61f20207b0f7f50a5",
+        path: "pytorch_model.bin",
+        size: 440473133,
+        type: "file",
+      },
+      {
+        lfs: {
+          oid: "a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2",
+          size: 536063208,
+          pointerSize: 134,
+        },
+        xetHash: "879c5715c18a0b7f051dd33f70f0a5c8dd1522e0a43f6f75520f16167f29279b",
+        oid: "9eb98c817f04b051b3bcca591bcd4e03cec88018",
+        path: "tf_model.h5",
+        size: 536063208,
+        type: "file",
+      },
+      {
+        oid: "fb140275c155a9c7c5a3b3e0e77a9e839594a938",
+        path: "vocab.txt",
+        size: 231508,
+        type: "file",
+      },
+    ]);
+  });
+
+  it("should fetch the list of files from the repo, including last commit", async () => {
+    const cursor = listFiles({
+      repo: {
+        name: "bert-base-uncased",
+        type: "model",
+      },
+      revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
+      expand: true,
+    });
+
+    const files: ListFileEntry[] = [];
+
+    for await (const entry of cursor) {
+      delete entry.securityFileStatus; // flaky
+      files.push(entry);
+    }
+
+    assert.deepStrictEqual(files, [
+      {
+        lastCommit: {
+          date: "2018-11-14T23:35:08.000Z",
+          id: "504939aa53e8ce310dba3dd2296dbe266c575de4",
+          title: "initial commit",
+        },
+        oid: "dc08351d4dc0732d9c8af04070ced089b201ce2f",
+        path: ".gitattributes",
+        size: 345,
+        type: "file",
+      },
+      {
+        lastCommit: {
+          date: "2019-06-18T09:06:51.000Z",
+          id: "bb3c1c3256d2598217df9889a14a2e811587891d",
+          title: "Update config.json",
+        },
+        oid: "fca794a5f07ff8f963fe8b61e3694b0fb7f955df",
+        path: "config.json",
+        size: 313,
+        type: "file",
+      },
+      {
+        lastCommit: {
+          date: "2019-06-18T09:06:34.000Z",
+          id: "3d2477d72b675a999d1b13ca822aaaf4908634ad",
+          title: "Update pytorch_model.bin",
+        },
+        lfs: {
+          oid: "097417381d6c7230bd9e3557456d726de6e83245ec8b24f529f60198a67b203a",
+          size: 440473133,
+          pointerSize: 134,
+        },
+        xetHash: "2d8408d3a894d02517d04956e2f7546ff08362594072f3527ce144b5212a3296",
+        oid: "ba5d19791be1dd7992e33bd61f20207b0f7f50a5",
+        path: "pytorch_model.bin",
+        size: 440473133,
+        type: "file",
+      },
+      {
+        lastCommit: {
+          date: "2019-09-23T19:48:44.000Z",
+          id: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
+          title: "Update tf_model.h5",
+        },
+        lfs: {
+          oid: "a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2",
+          size: 536063208,
+          pointerSize: 134,
+        },
+        xetHash: "879c5715c18a0b7f051dd33f70f0a5c8dd1522e0a43f6f75520f16167f29279b",
+        oid: "9eb98c817f04b051b3bcca591bcd4e03cec88018",
+        path: "tf_model.h5",
+        size: 536063208,
+        type: "file",
+      },
+      {
+        lastCommit: {
+          date: "2018-11-14T23:35:08.000Z",
+          id: "2f07d813ca87c8c709147704c87210359ccf2309",
+          title: "Update vocab.txt",
+        },
+        oid: "fb140275c155a9c7c5a3b3e0e77a9e839594a938",
+        path: "vocab.txt",
+        size: 231508,
+        type: "file",
+      },
+    ]);
+  });
+
+  it("should fetch the list of files from the repo, including subfolders", async () => {
+    const cursor = listFiles({
+      repo: {
+        name: "xsum",
+        type: "dataset",
+      },
+      revision: "0f3ea2f2b55fcb11e71fb1e3aec6822e44ddcb0f",
+      recursive: true,
+    });
+
+    const files: ListFileEntry[] = [];
+
+    for await (const entry of cursor) {
+      files.push(entry);
+    }
+
+    assert(files.some((file) => file.path === "data/XSUM-EMNLP18-Summary-Data-Original.tar.gz"));
+  });
+});
lib/list-files.ts ADDED
@@ -0,0 +1,94 @@
+import { HUB_URL } from "../consts";
+import { createApiError } from "../error";
+import type { ApiIndexTreeEntry } from "../types/api/api-index-tree";
+import type { CredentialsParams, RepoDesignation } from "../types/public";
+import { checkCredentials } from "../utils/checkCredentials";
+import { parseLinkHeader } from "../utils/parseLinkHeader";
+import { toRepoId } from "../utils/toRepoId";
+
+export interface ListFileEntry {
+  type: "file" | "directory" | "unknown";
+  size: number;
+  path: string;
+  oid: string;
+  lfs?: {
+    oid: string;
+    size: number;
+    /** Size of the raw pointer file, 100~200 bytes */
+    pointerSize: number;
+  };
+  /**
+   * Xet-backed hash, a new protocol replacing LFS for big files.
+   */
+  xetHash?: string;
+  /**
+   * Only fetched if `expand` is set to `true` in the `listFiles` call.
+   */
+  lastCommit?: {
+    date: string;
+    id: string;
+    title: string;
+  };
+  /**
+   * Only fetched if `expand` is set to `true` in the `listFiles` call.
+   */
+  securityFileStatus?: unknown;
+}
+
+/**
+ * List files in a folder. To list ALL files in the directory, call it
+ * with {@link params.recursive} set to `true`.
+ */
+export async function* listFiles(
+  params: {
+    repo: RepoDesignation;
+    /**
+     * Do we want to list files in subdirectories?
+     */
+    recursive?: boolean;
+    /**
+     * E.g. 'data' for listing all files in the 'data' folder. Leave it empty to list all
+     * files in the repo.
+     */
+    path?: string;
+    /**
+     * Fetch `lastCommit` and `securityFileStatus` for each file.
+     */
+    expand?: boolean;
+    revision?: string;
+    hubUrl?: string;
+    /**
+     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+     */
+    fetch?: typeof fetch;
+  } & Partial<CredentialsParams>
+): AsyncGenerator<ListFileEntry> {
+  const accessToken = checkCredentials(params);
+  const repoId = toRepoId(params.repo);
+  let url: string | undefined = `${params.hubUrl || HUB_URL}/api/${repoId.type}s/${repoId.name}/tree/${
+    params.revision || "main"
+  }${params.path ? "/" + params.path : ""}?recursive=${!!params.recursive}&expand=${!!params.expand}`;
+
+  while (url) {
+    const res: Response = await (params.fetch ?? fetch)(url, {
+      headers: {
+        accept: "application/json",
+        ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
+      },
+    });
+
+    if (!res.ok) {
+      throw await createApiError(res);
+    }
+
+    const items: ApiIndexTreeEntry[] = await res.json();
+
+    for (const item of items) {
+      yield item;
+    }
+
+    const linkHeader = res.headers.get("Link");
+
+    url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
+  }
+}
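A minimal usage sketch, mirroring the spec above: walk a repo tree recursively and distinguish LFS-backed files (which carry an `lfs` field) from regular blobs. Assumes the package is imported as `@huggingface/hub`.

```ts
import { listFiles } from "@huggingface/hub";

// Recursively list every file in the repo at the default revision.
for await (const file of listFiles({
  repo: { type: "model", name: "bert-base-uncased" },
  recursive: true,
})) {
  if (file.type === "file") {
    console.log(file.path, file.lfs ? `LFS, ${file.lfs.size} bytes` : `${file.size} bytes`);
  }
}
```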
lib/list-models.spec.ts ADDED
@@ -0,0 +1,118 @@
+import { describe, expect, it } from "vitest";
+import type { ModelEntry } from "./list-models";
+import { listModels } from "./list-models";
+
+describe("listModels", () => {
+  it("should list models for depth estimation", async () => {
+    const results: ModelEntry[] = [];
+
+    for await (const entry of listModels({
+      search: { owner: "Intel", task: "depth-estimation" },
+    })) {
+      if (typeof entry.downloads === "number") {
+        entry.downloads = 0;
+      }
+      if (typeof entry.likes === "number") {
+        entry.likes = 0;
+      }
+      if (entry.updatedAt instanceof Date && !isNaN(entry.updatedAt.getTime())) {
+        entry.updatedAt = new Date(0);
+      }
+
+      if (!["Intel/dpt-large", "Intel/dpt-hybrid-midas"].includes(entry.name)) {
+        expect(entry.task).to.equal("depth-estimation");
+        continue;
+      }
+
+      results.push(entry);
+    }
+
+    results.sort((a, b) => a.id.localeCompare(b.id));
+
+    expect(results).deep.equal([
+      {
+        id: "621ffdc136468d709f17e709",
+        name: "Intel/dpt-large",
+        private: false,
+        gated: false,
+        downloads: 0,
+        likes: 0,
+        task: "depth-estimation",
+        updatedAt: new Date(0),
+      },
+      {
+        id: "638f07977559bf9a2b2b04ac",
+        name: "Intel/dpt-hybrid-midas",
+        gated: false,
+        private: false,
+        downloads: 0,
+        likes: 0,
+        task: "depth-estimation",
+        updatedAt: new Date(0),
+      },
+    ]);
+  });
+
+  it("should list Indonesian models with gguf format", async () => {
+    let count = 0;
+    for await (const entry of listModels({
+      search: { tags: ["gguf", "id"] },
+      additionalFields: ["tags"],
+      limit: 2,
+    })) {
+      count++;
+      expect(entry.tags).to.include("gguf");
+      expect(entry.tags).to.include("id");
+    }
+
+    expect(count).to.equal(2);
+  });
+
+  it("should search model by name", async () => {
+    let count = 0;
+    for await (const entry of listModels({
+      search: { query: "t5" },
+      limit: 10,
+    })) {
+      count++;
+      expect(entry.name.toLocaleLowerCase()).to.include("t5");
+    }
+
+    expect(count).to.equal(10);
+  });
+
+  it("should search model by inference provider", async () => {
+    let count = 0;
+    for await (const entry of listModels({
+      search: { inferenceProviders: ["together"] },
+      additionalFields: ["inferenceProviderMapping"],
+      limit: 10,
+    })) {
+      count++;
+      if (Array.isArray(entry.inferenceProviderMapping)) {
+        expect(entry.inferenceProviderMapping.map(({ provider }) => provider)).to.include("together");
+      }
+    }
+
+    expect(count).to.equal(10);
+  });
+
+  it("should search model by several inference providers", async () => {
+    let count = 0;
+    const inferenceProviders = ["together", "replicate"];
+    for await (const entry of listModels({
+      search: { inferenceProviders },
+      additionalFields: ["inferenceProviderMapping"],
+      limit: 10,
+    })) {
+      count++;
+      if (Array.isArray(entry.inferenceProviderMapping)) {
+        expect(
+          entry.inferenceProviderMapping.filter(({ provider }) => inferenceProviders.includes(provider)).length
+        ).toBeGreaterThan(0);
+      }
+    }
+
+    expect(count).to.equal(10);
+  });
+});
lib/list-models.ts ADDED
@@ -0,0 +1,139 @@
+import { HUB_URL } from "../consts";
+import { createApiError } from "../error";
+import type { ApiModelInfo } from "../types/api/api-model";
+import type { CredentialsParams, PipelineType } from "../types/public";
+import { checkCredentials } from "../utils/checkCredentials";
+import { parseLinkHeader } from "../utils/parseLinkHeader";
+import { pick } from "../utils/pick";
+
+export const MODEL_EXPAND_KEYS = [
+  "pipeline_tag",
+  "private",
+  "gated",
+  "downloads",
+  "likes",
+  "lastModified",
+] as const satisfies readonly (keyof ApiModelInfo)[];
+
+export const MODEL_EXPANDABLE_KEYS = [
+  "author",
+  "cardData",
+  "config",
+  "createdAt",
+  "disabled",
+  "downloads",
+  "downloadsAllTime",
+  "gated",
+  "gitalyUid",
+  "inferenceProviderMapping",
+  "lastModified",
+  "library_name",
+  "likes",
+  "model-index",
+  "pipeline_tag",
+  "private",
+  "safetensors",
+  "sha",
+  // "siblings",
+  "spaces",
+  "tags",
+  "transformersInfo",
+] as const satisfies readonly (keyof ApiModelInfo)[];
+
+export interface ModelEntry {
+  id: string;
+  name: string;
+  private: boolean;
+  gated: false | "auto" | "manual";
+  task?: PipelineType;
+  likes: number;
+  downloads: number;
+  updatedAt: Date;
+}
+
+export async function* listModels<
+  const T extends Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]> = never,
+>(
+  params?: {
+    search?: {
+      /**
+       * Will search in the model name for matches
+       */
+      query?: string;
+      owner?: string;
+      task?: PipelineType;
+      tags?: string[];
+      /**
+       * Will search for models that have one of the inference providers in the list.
+       */
+      inferenceProviders?: string[];
+    };
+    hubUrl?: string;
+    additionalFields?: T[];
+    /**
+     * Set to limit the number of models returned.
+     */
+    limit?: number;
+    /**
+     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+     */
+    fetch?: typeof fetch;
+  } & Partial<CredentialsParams>
+): AsyncGenerator<ModelEntry & Pick<ApiModelInfo, T>> {
+  const accessToken = params && checkCredentials(params);
+  let totalToFetch = params?.limit ?? Infinity;
+  const search = new URLSearchParams([
+    ...Object.entries({
+      limit: String(Math.min(totalToFetch, 500)),
+      ...(params?.search?.owner ? { author: params.search.owner } : undefined),
+      ...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
+      ...(params?.search?.query ? { search: params.search.query } : undefined),
+      ...(params?.search?.inferenceProviders
+        ? { inference_provider: params.search.inferenceProviders.join(",") }
+        : undefined),
+    }),
+    ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
+    ...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
+    ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
+  ]).toString();
+  let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;
+
+  while (url) {
+    const res: Response = await (params?.fetch ?? fetch)(url, {
+      headers: {
+        accept: "application/json",
+        ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
+      },
+    });
+
+    if (!res.ok) {
+      throw await createApiError(res);
+    }
+
+    const items: ApiModelInfo[] = await res.json();
+
+    for (const item of items) {
+      yield {
+        ...(params?.additionalFields && pick(item, params.additionalFields)),
+        id: item._id,
+        name: item.id,
+        private: item.private,
+        task: item.pipeline_tag,
+        downloads: item.downloads,
+        gated: item.gated,
+        likes: item.likes,
+        updatedAt: new Date(item.lastModified),
+      } as ModelEntry & Pick<ApiModelInfo, T>;
+      totalToFetch--;
+
+      if (totalToFetch <= 0) {
+        return;
+      }
+    }
+
+    const linkHeader = res.headers.get("Link");
+
+    url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
+    // Could update url to reduce the limit if we don't need the whole 500 of the next batch.
+  }
+}
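A minimal search sketch combining the filters above: owner plus pipeline task, with the optional `tags` field expanded. Assumes the package is imported as `@huggingface/hub`; the owner and task are only examples.

```ts
import { listModels } from "@huggingface/hub";

// Up to 5 depth-estimation models from one owner, with tags included in each entry.
for await (const model of listModels({
  search: { owner: "Intel", task: "depth-estimation" },
  additionalFields: ["tags"],
  limit: 5,
})) {
  console.log(model.name, model.likes, model.tags);
}
```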
lib/list-spaces.spec.ts ADDED
@@ -0,0 +1,40 @@
+import { describe, expect, it } from "vitest";
+import type { SpaceEntry } from "./list-spaces";
+import { listSpaces } from "./list-spaces";
+
+describe("listSpaces", () => {
+  it("should list spaces for Microsoft", async () => {
+    const results: SpaceEntry[] = [];
+
+    for await (const entry of listSpaces({
+      search: { owner: "microsoft" },
+      additionalFields: ["subdomain"],
+    })) {
+      if (entry.name !== "microsoft/visual_chatgpt") {
+        continue;
+      }
+      if (typeof entry.likes === "number") {
+        entry.likes = 0;
+      }
+      if (entry.updatedAt instanceof Date && !isNaN(entry.updatedAt.getTime())) {
+        entry.updatedAt = new Date(0);
+      }
+
+      results.push(entry);
+    }
+
+    results.sort((a, b) => a.id.localeCompare(b.id));
+
+    expect(results).deep.equal([
+      {
+        id: "6409a392bbc73d022c58c980",
+        name: "microsoft/visual_chatgpt",
+        private: false,
+        likes: 0,
+        sdk: "gradio",
+        subdomain: "microsoft-visual-chatgpt",
+        updatedAt: new Date(0),
+      },
+    ]);
+  });
+});
lib/list-spaces.ts ADDED
@@ -0,0 +1,111 @@
+import { HUB_URL } from "../consts";
+import { createApiError } from "../error";
+import type { ApiSpaceInfo } from "../types/api/api-space";
+import type { CredentialsParams, SpaceSdk } from "../types/public";
+import { checkCredentials } from "../utils/checkCredentials";
+import { parseLinkHeader } from "../utils/parseLinkHeader";
+import { pick } from "../utils/pick";
+
+export const SPACE_EXPAND_KEYS = [
+  "sdk",
+  "likes",
+  "private",
+  "lastModified",
+] as const satisfies readonly (keyof ApiSpaceInfo)[];
+export const SPACE_EXPANDABLE_KEYS = [
+  "author",
+  "cardData",
+  "datasets",
+  "disabled",
+  "gitalyUid",
+  "lastModified",
+  "createdAt",
+  "likes",
+  "private",
+  "runtime",
+  "sdk",
+  // "siblings",
+  "sha",
+  "subdomain",
+  "tags",
+  "models",
+] as const satisfies readonly (keyof ApiSpaceInfo)[];
+
+export interface SpaceEntry {
+  id: string;
+  name: string;
+  sdk?: SpaceSdk;
+  likes: number;
+  private: boolean;
+  updatedAt: Date;
+  // Use additionalFields to fetch the fields from ApiSpaceInfo
+}
+
+export async function* listSpaces<
+  const T extends Exclude<(typeof SPACE_EXPANDABLE_KEYS)[number], (typeof SPACE_EXPAND_KEYS)[number]> = never,
+>(
+  params?: {
+    search?: {
+      /**
+       * Will search in the space name for matches
+       */
+      query?: string;
+      owner?: string;
+      tags?: string[];
+    };
+    hubUrl?: string;
+    /**
+     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+     */
+    fetch?: typeof fetch;
+    /**
+     * Additional fields to fetch from huggingface.co.
+     */
+    additionalFields?: T[];
+  } & Partial<CredentialsParams>
+): AsyncGenerator<SpaceEntry & Pick<ApiSpaceInfo, T>> {
+  const accessToken = params && checkCredentials(params);
+  const search = new URLSearchParams([
+    ...Object.entries({
+      limit: "500",
+      ...(params?.search?.owner ? { author: params.search.owner } : undefined),
+      ...(params?.search?.query ? { search: params.search.query } : undefined),
+    }),
+    ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
+    ...[...SPACE_EXPAND_KEYS, ...(params?.additionalFields ?? [])].map(
+      (val) => ["expand", val] satisfies [string, string]
+    ),
+  ]).toString();
+  let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/spaces?${search}`;
+
+  while (url) {
+    const res: Response = await (params?.fetch ?? fetch)(url, {
+      headers: {
+        accept: "application/json",
+        ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
+      },
+    });
+
+    if (!res.ok) {
+      throw await createApiError(res);
+    }
+
+    const items: ApiSpaceInfo[] = await res.json();
+
+    for (const item of items) {
+      yield {
+        ...(params?.additionalFields && pick(item, params.additionalFields)),
+        id: item._id,
+        name: item.id,
+        sdk: item.sdk,
+        likes: item.likes,
+        private: item.private,
+        updatedAt: new Date(item.lastModified),
+      } as SpaceEntry & Pick<ApiSpaceInfo, T>;
+    }
+
+    const linkHeader = res.headers.get("Link");
+
+    url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
+  }
+}
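A minimal usage sketch, following the spec above. Note that unlike `listModels` and `listDatasets`, this generator takes no `limit` parameter, so stop iterating yourself when you have enough. Assumes the package is imported as `@huggingface/hub`.

```ts
import { listSpaces } from "@huggingface/hub";

// List an owner's Spaces, expanding the optional "subdomain" field.
for await (const space of listSpaces({
  search: { owner: "microsoft" },
  additionalFields: ["subdomain"],
})) {
  console.log(space.name, space.sdk, space.subdomain);
}
```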
lib/model-info.spec.ts ADDED
@@ -0,0 +1,59 @@
+import { describe, expect, it } from "vitest";
+import { modelInfo } from "./model-info";
+import type { ModelEntry } from "./list-models";
+import type { ApiModelInfo } from "../types/api/api-model";
+
+describe("modelInfo", () => {
+  it("should return the model info", async () => {
+    const info = await modelInfo({
+      name: "openai-community/gpt2",
+    });
+    expect(info).toEqual({
+      id: "621ffdc036468d709f17434d",
+      downloads: expect.any(Number),
+      gated: false,
+      name: "openai-community/gpt2",
+      updatedAt: expect.any(Date),
+      likes: expect.any(Number),
+      task: "text-generation",
+      private: false,
+    });
+  });
+
+  it("should return the model info with author", async () => {
+    const info: ModelEntry & Pick<ApiModelInfo, "author"> = await modelInfo({
+      name: "openai-community/gpt2",
+      additionalFields: ["author"],
+    });
+    expect(info).toEqual({
+      id: "621ffdc036468d709f17434d",
+      downloads: expect.any(Number),
+      author: "openai-community",
+      gated: false,
+      name: "openai-community/gpt2",
+      updatedAt: expect.any(Date),
+      likes: expect.any(Number),
+      task: "text-generation",
+      private: false,
+    });
+  });
+
+  it("should return the model info for a specific revision", async () => {
+    const info: ModelEntry & Pick<ApiModelInfo, "sha"> = await modelInfo({
+      name: "openai-community/gpt2",
+      additionalFields: ["sha"],
+      revision: "f27b190eeac4c2302d24068eabf5e9d6044389ae",
+    });
+    expect(info).toEqual({
+      id: "621ffdc036468d709f17434d",
+      downloads: expect.any(Number),
+      gated: false,
+      name: "openai-community/gpt2",
+      updatedAt: expect.any(Date),
+      likes: expect.any(Number),
+      task: "text-generation",
+      private: false,
+      sha: "f27b190eeac4c2302d24068eabf5e9d6044389ae",
+    });
+  });
+});
lib/model-info.ts ADDED
@@ -0,0 +1,62 @@
+import { HUB_URL } from "../consts";
+import { createApiError } from "../error";
+import type { ApiModelInfo } from "../types/api/api-model";
+import type { CredentialsParams } from "../types/public";
+import { checkCredentials } from "../utils/checkCredentials";
+import { pick } from "../utils/pick";
+import { MODEL_EXPAND_KEYS, type MODEL_EXPANDABLE_KEYS, type ModelEntry } from "./list-models";
+
+export async function modelInfo<
+  const T extends Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]> = never,
+>(
+  params: {
+    name: string;
+    hubUrl?: string;
+    additionalFields?: T[];
+    /**
+     * An optional Git revision id which can be a branch name, a tag, or a commit hash.
+     */
+    revision?: string;
+    /**
+     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+     */
+    fetch?: typeof fetch;
+  } & Partial<CredentialsParams>
+): Promise<ModelEntry & Pick<ApiModelInfo, T>> {
+  const accessToken = params && checkCredentials(params);
+
+  const search = new URLSearchParams([
+    ...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
+    ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
+  ]).toString();
+
+  const response = await (params.fetch || fetch)(
+    `${params?.hubUrl || HUB_URL}/api/models/${params.name}/revision/${encodeURIComponent(
+      params.revision ?? "HEAD"
+    )}?${search}`,
+    {
+      headers: {
+        ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
+        Accept: "application/json",
+      },
+    }
+  );
+
+  if (!response.ok) {
+    throw await createApiError(response);
+  }
+
+  const data = await response.json();
+
+  return {
+    ...(params?.additionalFields && pick(data, params.additionalFields)),
+    id: data._id,
+    name: data.id,
+    private: data.private,
+    task: data.pipeline_tag,
+    downloads: data.downloads,
+    gated: data.gated,
+    likes: data.likes,
+    updatedAt: new Date(data.lastModified),
+  } as ModelEntry & Pick<ApiModelInfo, T>;
+}
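A minimal usage sketch: fetch metadata for a single model, pinned to a revision, with an extra expandable field. Assumes the package is imported as `@huggingface/hub`; `"sha"` is one of the `MODEL_EXPANDABLE_KEYS` defined in `list-models.ts`.

```ts
import { modelInfo } from "@huggingface/hub";

// Resolve one model at a given revision and expand its commit SHA.
const info = await modelInfo({
  name: "openai-community/gpt2",
  revision: "main",
  additionalFields: ["sha"],
});
console.log(info.task, info.downloads, info.sha);
```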
lib/oauth-handle-redirect.spec.ts ADDED
@@ -0,0 +1,60 @@
+import { describe, expect, it } from "vitest";
+import { TEST_COOKIE, TEST_HUB_URL } from "../test/consts";
+import { oauthLoginUrl } from "./oauth-login-url";
+import { oauthHandleRedirect } from "./oauth-handle-redirect";
+
+describe("oauthHandleRedirect", () => {
+  it("should work", async () => {
+    const localStorage = {
+      nonce: undefined,
+      codeVerifier: undefined,
+    };
+    const url = await oauthLoginUrl({
+      clientId: "dummy-app",
+      redirectUrl: "http://localhost:3000",
+      localStorage,
+      scopes: "openid profile email",
+      hubUrl: TEST_HUB_URL,
+    });
+    const resp = await fetch(url, {
+      method: "POST",
+      headers: {
+        Cookie: `token=${TEST_COOKIE}`,
+      },
+      redirect: "manual",
+    });
+    if (resp.status !== 303) {
+      throw new Error(`Failed to fetch url ${url}: ${resp.status} ${resp.statusText}`);
+    }
+    const location = resp.headers.get("Location");
+    if (!location) {
+      throw new Error(`No location header in response`);
+    }
+    const result = await oauthHandleRedirect({
+      redirectedUrl: location,
+      codeVerifier: localStorage.codeVerifier,
+      nonce: localStorage.nonce,
+      hubUrl: TEST_HUB_URL,
+    });
+
+    if (!result) {
+      throw new Error("Expected result to be defined");
+    }
+    expect(result.accessToken).toEqual(expect.any(String));
+    expect(result.accessTokenExpiresAt).toBeInstanceOf(Date);
+    expect(result.accessTokenExpiresAt.getTime()).toBeGreaterThan(Date.now());
+    expect(result.scope).toEqual(expect.any(String));
+    expect(result.userInfo).toEqual({
+      sub: "62f264b9f3c90f4b6514a269",
+      name: "@huggingface/hub CI bot",
+      preferred_username: "hub.js",
+      email_verified: true,
+      email: "eliott@huggingface.co",
+      isPro: false,
+      picture: "https://hub-ci.huggingface.co/avatars/934b830e9fdaa879487852f79eef7165.svg",
+      profile: "https://hub-ci.huggingface.co/hub.js",
+      website: "https://github.com/huggingface/hub.js",
+      orgs: [],
+    });
+  });
+});
lib/oauth-handle-redirect.ts ADDED
@@ -0,0 +1,334 @@
+import { HUB_URL } from "../consts";
+import { createApiError } from "../error";
+
+export interface UserInfo {
+  /**
+   * OpenID Connect field. Unique identifier for the user, even in case of rename.
+   */
+  sub: string;
+  /**
+   * OpenID Connect field. The user's full name.
+   */
+  name: string;
+  /**
+   * OpenID Connect field. The user's username.
+   */
+  preferred_username: string;
+  /**
+   * OpenID Connect field, available if scope "email" was granted.
+   */
+  email_verified?: boolean;
+  /**
+   * OpenID Connect field, available if scope "email" was granted.
+   */
+  email?: string;
+  /**
+   * OpenID Connect field. The user's profile picture URL.
+   */
+  picture: string;
+  /**
+   * OpenID Connect field. The user's profile URL.
+   */
+  profile: string;
+  /**
+   * OpenID Connect field. The user's website URL.
+   */
+  website?: string;
+
+  /**
+   * Hugging Face field. Whether the user is a pro user.
+   */
+  isPro: boolean;
+  /**
+   * Hugging Face field. Whether the user has a payment method set up. Needs "read-billing" scope.
+   */
+  canPay?: boolean;
+  /**
+   * Hugging Face field. The user's orgs
+   */
+  orgs?: Array<{
+    /**
+     * OpenID Connect field. Unique identifier for the org.
+     */
+    sub: string;
+    /**
+     * OpenID Connect field. The org's full name.
+     */
+    name: string;
+    /**
+     * OpenID Connect field. The org's username.
+     */
+    preferred_username: string;
+    /**
+     * OpenID Connect field. The org's profile picture URL.
+     */
+    picture: string;
+
+    /**
+     * Hugging Face field. Whether the org is an enterprise org.
+     */
+    isEnterprise: boolean;
+    /**
+     * Hugging Face field. Whether the org has a payment method set up. Needs "read-billing" scope, and the user needs to approve access to the org in the OAuth page.
+     */
+    canPay?: boolean;
+    /**
+     * Hugging Face field. The user's role in the org. The user needs to approve access to the org in the OAuth page.
+     */
+    roleInOrg?: string;
+    /**
+     * Hugging Face field. When the user granted the oauth app access to the org, but didn't complete SSO.
+     *
+     * Should never happen directly after the oauth flow.
+     */
+    pendingSSO?: boolean;
+    /**
+     * Hugging Face field. When the user granted the oauth app access to the org, but didn't complete MFA.
+     *
+     * Should never happen directly after the oauth flow.
+     */
+    missingMFA?: boolean;
+  }>;
+}
+
+export interface OAuthResult {
+  accessToken: string;
+  accessTokenExpiresAt: Date;
+  userInfo: UserInfo;
+  /**
+   * State passed to the OAuth provider in the original request to the OAuth provider.
+   */
+  state?: string;
+  /**
+   * Granted scope
+   */
+  scope: string;
+}
+
+/**
+ * To call after the OAuth provider redirects back to the app.
+ *
+ * There is also a helper function {@link oauthHandleRedirectIfPresent}, which will call `oauthHandleRedirect` if the URL contains an oauth code
+ * in the query parameters and return `false` otherwise.
+ */
+export async function oauthHandleRedirect(opts?: {
+  /**
+   * The URL of the hub. Defaults to {@link HUB_URL}.
+   */
+  hubUrl?: string;
+  /**
+   * The URL to analyze.
+   *
+   * @default window.location.href
+   */
+  redirectedUrl?: string;
+  /**
+   * nonce generated by oauthLoginUrl
+   *
+   * @default localStorage.getItem("huggingface.co:oauth:nonce")
+   */
+  nonce?: string;
+  /**
+   * codeVerifier generated by oauthLoginUrl
+   *
+   * @default localStorage.getItem("huggingface.co:oauth:code_verifier")
+   */
+  codeVerifier?: string;
+}): Promise<OAuthResult> {
+  if (typeof window === "undefined" && !opts?.redirectedUrl) {
+    throw new Error("oauthHandleRedirect is only available in the browser, unless you provide redirectedUrl");
+  }
+  if (typeof localStorage === "undefined" && (!opts?.nonce || !opts?.codeVerifier)) {
+    throw new Error(
+      "oauthHandleRedirect requires localStorage to be available, unless you provide nonce and codeVerifier"
+    );
+  }
+
+  const redirectedUrl = opts?.redirectedUrl ?? window.location.href;
+  const searchParams = (() => {
+    try {
+      return new URL(redirectedUrl).searchParams;
+    } catch (err) {
+      throw new Error("Failed to parse redirected URL: " + redirectedUrl);
+    }
+  })();
+
+  const [error, errorDescription] = [searchParams.get("error"), searchParams.get("error_description")];
+
+  if (error) {
+    throw new Error(`${error}: ${errorDescription}`);
+  }
+
+  const code = searchParams.get("code");
+  const nonce = opts?.nonce ?? localStorage.getItem("huggingface.co:oauth:nonce");
+
+  if (!code) {
+    throw new Error("Missing oauth code from query parameters in redirected URL: " + redirectedUrl);
+  }
+
+  if (!nonce) {
+    throw new Error("Missing oauth nonce from localStorage");
+  }
+
+  const codeVerifier = opts?.codeVerifier ?? localStorage.getItem("huggingface.co:oauth:code_verifier");
+
+  if (!codeVerifier) {
+    throw new Error("Missing oauth code_verifier from localStorage");
+  }
+
+  const state = searchParams.get("state");
+
+  if (!state) {
+    throw new Error("Missing oauth state from query parameters in redirected URL");
+  }
+
+  let parsedState: { nonce: string; redirectUri: string; state?: string };
+
+  try {
+    parsedState = JSON.parse(state);
+  } catch {
+    throw new Error("Invalid oauth state in redirected URL, unable to parse JSON: " + state);
+  }
+
+  if (parsedState.nonce !== nonce) {
+    throw new Error("Invalid oauth state in redirected URL");
+  }
+
+  const hubUrl = opts?.hubUrl || HUB_URL;
+
+  const openidConfigUrl = `${new URL(hubUrl).origin}/.well-known/openid-configuration`;
+  const openidConfigRes = await fetch(openidConfigUrl, {
+    headers: {
+      Accept: "application/json",
+    },
+  });
+
+  if (!openidConfigRes.ok) {
+    throw await createApiError(openidConfigRes);
+  }
+
+  const openidConfig: {
+    authorization_endpoint: string;
+    token_endpoint: string;
+    userinfo_endpoint: string;
+  } = await openidConfigRes.json();
+
+  const tokenRes = await fetch(openidConfig.token_endpoint, {
+    method: "POST",
+    headers: {
+      "Content-Type": "application/x-www-form-urlencoded",
+    },
+    body: new URLSearchParams({
+      grant_type: "authorization_code",
+      code,
+      redirect_uri: parsedState.redirectUri,
+      code_verifier: codeVerifier,
+    }).toString(),
+  });
+
+  if (!opts?.codeVerifier) {
+    localStorage.removeItem("huggingface.co:oauth:code_verifier");
+  }
+  if (!opts?.nonce) {
+    localStorage.removeItem("huggingface.co:oauth:nonce");
+  }
+
+  if (!tokenRes.ok) {
+    throw await createApiError(tokenRes);
+  }
+
+  const token: {
+    access_token: string;
+    expires_in: number;
+    id_token: string;
+    // refresh_token: string;
+    scope: string;
+    token_type: string;
+  } = await tokenRes.json();
+
+  const accessTokenExpiresAt = new Date(Date.now() + token.expires_in * 1000);
+
+  const userInfoRes = await fetch(openidConfig.userinfo_endpoint, {
+    headers: {
+      Authorization: `Bearer ${token.access_token}`,
+    },
+  });
+
+  if (!userInfoRes.ok) {
+    throw await createApiError(userInfoRes);
+  }
+
+  const userInfo: UserInfo = await userInfoRes.json();
+
+  return {
+    accessToken: token.access_token,
+    accessTokenExpiresAt,
+    userInfo: userInfo,
+    state: parsedState.state,
+    scope: token.scope,
+  };
+}
+
+// if (code && !nonce) {
+//   console.warn("Missing oauth nonce from localStorage");
+// }
+
+/**
+ * To call after the OAuth provider redirects back to the app.
+ *
+ * It returns false if the URL does not contain an oauth code in the query parameters, otherwise
+ * it calls {@link oauthHandleRedirect}.
+ *
+ * Depending on your app, you may want to call {@link oauthHandleRedirect} directly instead.
+ */
+export async function oauthHandleRedirectIfPresent(opts?: {
+  /**
+   * The URL of the hub. Defaults to {@link HUB_URL}.
+   */
+  hubUrl?: string;
+  /**
+   * The URL to analyze.
+   *
+   * @default window.location.href
+   */
+  redirectedUrl?: string;
+  /**
+   * nonce generated by oauthLoginUrl
+   *
+   * @default localStorage.getItem("huggingface.co:oauth:nonce")
+   */
+  nonce?: string;
+  /**
+   * codeVerifier generated by oauthLoginUrl
+   *
+   * @default localStorage.getItem("huggingface.co:oauth:code_verifier")
+   */
+  codeVerifier?: string;
+}): Promise<OAuthResult | false> {
+  if (typeof window === "undefined" && !opts?.redirectedUrl) {
+    throw new Error("oauthHandleRedirect is only available in the browser, unless you provide redirectedUrl");
+  }
+  if (typeof localStorage === "undefined" && (!opts?.nonce || !opts?.codeVerifier)) {
+    throw new Error(
+      "oauthHandleRedirect requires localStorage to be available, unless you provide nonce and codeVerifier"
+    );
+  }
+  const searchParams = new URLSearchParams(opts?.redirectedUrl ?? window.location.search);
+
+  if (searchParams.has("error")) {
+    return oauthHandleRedirect(opts);
+  }
+
+  if (searchParams.has("code")) {
+    if (!localStorage.getItem("huggingface.co:oauth:nonce")) {
+      console.warn(
+        "Missing oauth nonce from localStorage. This can happen when the user refreshes the page after logging in, without changing the URL."
+      );
+      return false;
+    }
+
+    return oauthHandleRedirect(opts);
+  }
+
+  return false;
+}
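A minimal browser-side flow sketch, adapted from the `@example` in `oauth-login-url.ts`: handle the redirect if one is pending, otherwise start the login. The client ID is a hypothetical placeholder.

```ts
import { oauthLoginUrl, oauthHandleRedirectIfPresent } from "@huggingface/hub";

// Returns false when the current URL carries no oauth code.
const oauth = await oauthHandleRedirectIfPresent();
if (!oauth) {
  // Redirect the user to the Hub's authorization endpoint.
  window.location.href = await oauthLoginUrl({ clientId: "your-client-id", scopes: "openid profile" });
} else {
  console.log(oauth.userInfo.preferred_username, oauth.accessTokenExpiresAt.toISOString());
}
```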
lib/oauth-login-url.ts ADDED
@@ -0,0 +1,166 @@
+import { HUB_URL } from "../consts";
+import { createApiError } from "../error";
+import { base64FromBytes } from "../utils/base64FromBytes";
+
+/**
+ * Use "Sign in with Hub" to authenticate a user, and get oauth user info / access token.
+ *
+ * Returns a URL to redirect to. After the user is redirected back to your app, call `oauthHandleRedirect` to get the oauth user info / access token.
+ *
+ * When called from inside a static Space with OAuth enabled, it will load the config from the space, otherwise you need to at least specify
+ * the client ID of your OAuth App.
+ *
+ * @example
+ * ```ts
+ * import { oauthLoginUrl, oauthHandleRedirectIfPresent } from "@huggingface/hub";
+ *
+ * const oauthResult = await oauthHandleRedirectIfPresent();
+ *
+ * if (!oauthResult) {
+ *   // If the user is not logged in, redirect to the login page
+ *   window.location.href = await oauthLoginUrl();
+ * }
+ *
+ * // You can use oauthResult.accessToken, oauthResult.accessTokenExpiresAt and oauthResult.userInfo
+ * console.log(oauthResult);
+ * ```
+ *
+ * (Theoretically, this function could be used to authenticate a user for any OAuth provider supporting PKCE and OpenID Connect by changing `hubUrl`,
+ * but it is currently only tested with the Hugging Face Hub.)
+ */
+export async function oauthLoginUrl(opts?: {
+  /**
+   * OAuth client ID.
+   *
+   * For static Spaces, you can omit this and it will be loaded from the Space config, as long as `hf_oauth: true` is present in the README.md's metadata.
+   * For other Spaces, it is available to the backend in the OAUTH_CLIENT_ID environment variable, as long as `hf_oauth: true` is present in the README.md's metadata.
+   *
+   * You can also create a Developer Application at https://huggingface.co/settings/connected-applications and use its client ID.
+   */
+  clientId?: string;
+  hubUrl?: string;
+  /**
+   * OAuth scope, a list of space-separated scopes.
+   *
+   * For static Spaces, you can omit this and it will be loaded from the Space config, as long as `hf_oauth: true` is present in the README.md's metadata.
+   * For other Spaces, it is available to the backend in the OAUTH_SCOPES environment variable, as long as `hf_oauth: true` is present in the README.md's metadata.
+   *
+   * Defaults to "openid profile".
+   *
+   * You can also create a Developer Application at https://huggingface.co/settings/connected-applications and use its scopes.
+   *
+   * See https://huggingface.co/docs/hub/oauth for a list of available scopes.
+   */
+  scopes?: string;
+  /**
+   * Redirect URI, defaults to the current URL.
+   *
+   * For Spaces, any URL within the Space is allowed.
+   *
+   * For Developer Applications, you can add any URL you want to the list of allowed redirect URIs at https://huggingface.co/settings/connected-applications.
+   */
+  redirectUrl?: string;
+  /**
+   * State to pass to the OAuth provider, which will be returned in the call to `oauthLogin` after the redirect.
+   */
+  state?: string;
+  /**
+   * If provided, will be filled with the code verifier and nonce used for the OAuth flow,
+   * instead of using localStorage.
+   *
+   * When calling {@link `oauthHandleRedirectIfPresent`} or {@link `oauthHandleRedirect`} you will need to provide the same values.
+   */
+  localStorage?: {
+    codeVerifier?: string;
+    nonce?: string;
+  };
+}): Promise<string> {
+  if (typeof window === "undefined" && (!opts?.redirectUrl || !opts?.clientId)) {
+    throw new Error("oauthLogin is only available in the browser, unless you provide clientId and redirectUrl");
+  }
+  if (typeof localStorage === "undefined" && !opts?.localStorage) {
+    throw new Error(
+      "oauthLogin requires localStorage to be available in the context, unless you provide a localStorage empty object as argument"
+    );
+  }
+
+  const hubUrl = opts?.hubUrl || HUB_URL;
+  const openidConfigUrl = `${new URL(hubUrl).origin}/.well-known/openid-configuration`;
+  const openidConfigRes = await fetch(openidConfigUrl, {
+    headers: {
+      Accept: "application/json",
+    },
+  });
+
+  if (!openidConfigRes.ok) {
+    throw await createApiError(openidConfigRes);
+  }
+
+  const openidConfig: {
+    authorization_endpoint: string;
+    token_endpoint: string;
+    userinfo_endpoint: string;
+  } = await openidConfigRes.json();
+
+  const newNonce = globalThis.crypto.randomUUID();
+  // Two random UUIDs concatenated together, because min length is 43 and max length is 128
+  const newCodeVerifier = globalThis.crypto.randomUUID() + globalThis.crypto.randomUUID();
+
+  if (opts?.localStorage) {
+    if (opts.localStorage.codeVerifier !== undefined && opts.localStorage.codeVerifier !== null) {
+      throw new Error(
+        "localStorage.codeVerifier must be initially set to null or undefined, and will be filled by oauthLoginUrl"
+      );
+    }
+    if (opts.localStorage.nonce !== undefined && opts.localStorage.nonce !== null) {
+      throw new Error(
+        "localStorage.nonce must be initially set to null or undefined, and will be filled by oauthLoginUrl"
+      );
+    }
+    opts.localStorage.codeVerifier = newCodeVerifier;
+    opts.localStorage.nonce = newNonce;
+  } else {
+    localStorage.setItem("huggingface.co:oauth:nonce", newNonce);
+    localStorage.setItem("huggingface.co:oauth:code_verifier", newCodeVerifier);
+  }
+
+  const redirectUri = opts?.redirectUrl || (typeof window !== "undefined" ? window.location.href : undefined);
+  if (!redirectUri) {
+    throw new Error("Missing redirectUrl");
+  }
+  const state = JSON.stringify({
+    nonce: newNonce,
+    redirectUri,
+    state: opts?.state,
+  });
+
+  const variables: Record<string, string> | null =
+    // @ts-expect-error window.huggingface is defined inside static Spaces.
+    typeof window !== "undefined" ? window.huggingface?.variables ?? null : null;
+
+  const clientId = opts?.clientId || variables?.OAUTH_CLIENT_ID;
+
+  if (!clientId) {
+    if (variables) {
+      throw new Error("Missing clientId, please add hf_oauth: true to the README.md's metadata in your static Space");
+    }
+    throw new Error("Missing clientId");
+  }
+
+  const challenge = base64FromBytes(
+    new Uint8Array(await globalThis.crypto.subtle.digest("SHA-256", new TextEncoder().encode(newCodeVerifier)))
+  )
+    .replace(/[+]/g, "-")
+    .replace(/[/]/g, "_")
+    .replace(/=/g, "");
+
+  return `${openidConfig.authorization_endpoint}?${new URLSearchParams({
+    client_id: clientId,
+    scope: opts?.scopes || variables?.OAUTH_SCOPES || "openid profile",
+    response_type: "code",
+    redirect_uri: redirectUri,
+    state,
+    code_challenge: challenge,
+    code_challenge_method: "S256",
+  }).toString()}`;
+}
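Outside the browser, the `localStorage` parameter lets the caller hold the PKCE state itself, the same pattern the spec for `oauth-handle-redirect` uses. A minimal sketch under that assumption; the client ID and redirect URL are hypothetical placeholders.

```ts
import { oauthLoginUrl, oauthHandleRedirect } from "@huggingface/hub";

// Pass an empty object so the nonce and code verifier are captured for you
// instead of being written to window.localStorage.
const stash: { nonce?: string; codeVerifier?: string } = {};
const url = await oauthLoginUrl({
  clientId: "your-client-id", // hypothetical client ID
  redirectUrl: "http://localhost:3000/callback",
  localStorage: stash,
});
// ...send the user to `url`, then on the callback, pass the same values back:
// await oauthHandleRedirect({ redirectedUrl, nonce: stash.nonce, codeVerifier: stash.codeVerifier });
```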
lib/parse-safetensors-metadata.spec.ts ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import { assert, it, describe } from "vitest";
+ import { parseSafetensorsMetadata, parseSafetensorsShardFilename } from "./parse-safetensors-metadata";
+ import { sum } from "../utils/sum";
+
+ describe("parseSafetensorsMetadata", () => {
+   it("fetch info for single-file (with the default conventional filename)", async () => {
+     const parse = await parseSafetensorsMetadata({
+       repo: "bert-base-uncased",
+       computeParametersCount: true,
+       revision: "86b5e0934494bd15c9632b12f734a8a67f723594",
+     });
+
+     assert(!parse.sharded);
+     assert.deepStrictEqual(parse.header.__metadata__, { format: "pt" });
+
+     // Example of one tensor (the header contains many tensors)
+     assert.deepStrictEqual(parse.header["bert.embeddings.LayerNorm.beta"], {
+       dtype: "F32",
+       shape: [768],
+       data_offsets: [0, 3072],
+     });
+
+     assert.deepStrictEqual(parse.parameterCount, { F32: 110_106_428 });
+     assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 110_106_428);
+     // total params = 110m
+   });
+
+   it("fetch info for sharded (with the default conventional filename)", async () => {
+     const parse = await parseSafetensorsMetadata({
+       repo: "bigscience/bloom",
+       computeParametersCount: true,
+       revision: "053d9cd9fbe814e091294f67fcfedb3397b954bb",
+     });
+
+     assert(parse.sharded);
+
+     assert.strictEqual(Object.keys(parse.headers).length, 72);
+     // This model has 72 shards!
+
+     // Example of one tensor inside one file
+     assert.deepStrictEqual(parse.headers["model_00012-of-00072.safetensors"]["h.10.input_layernorm.weight"], {
+       dtype: "BF16",
+       shape: [14336],
+       data_offsets: [3288649728, 3288678400],
+     });
+
+     assert.deepStrictEqual(parse.parameterCount, { BF16: 176_247_271_424 });
+     assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 176_247_271_424);
+     // total params = 176B
+   });
+
+   it("fetch info for single-file with multiple dtypes", async () => {
+     const parse = await parseSafetensorsMetadata({
+       repo: "roberta-base",
+       computeParametersCount: true,
+       revision: "e2da8e2f811d1448a5b465c236feacd80ffbac7b",
+     });
+
+     assert(!parse.sharded);
+
+     assert.deepStrictEqual(parse.parameterCount, { F32: 124_697_433, I64: 514 });
+     assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 124_697_947);
+     // total params = 124m
+   });
+
+   it("fetch info for single-file with file path", async () => {
+     const parse = await parseSafetensorsMetadata({
+       repo: "CompVis/stable-diffusion-v1-4",
+       computeParametersCount: true,
+       path: "unet/diffusion_pytorch_model.safetensors",
+       revision: "133a221b8aa7292a167afc5127cb63fb5005638b",
+     });
+
+     assert(!parse.sharded);
+     assert.deepStrictEqual(parse.header.__metadata__, { format: "pt" });
+
+     // Example of one tensor (the header contains many tensors)
+     assert.deepStrictEqual(parse.header["up_blocks.3.resnets.0.norm2.bias"], {
+       dtype: "F32",
+       shape: [320],
+       data_offsets: [3_409_382_416, 3_409_383_696],
+     });
+
+     assert.deepStrictEqual(parse.parameterCount, { F32: 859_520_964 });
+     assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 859_520_964);
+   });
+
+   it("fetch info for sharded (with the default conventional filename) with file path", async () => {
+     const parse = await parseSafetensorsMetadata({
+       repo: "Alignment-Lab-AI/ALAI-gemma-7b",
+       computeParametersCount: true,
+       path: "7b/1/model.safetensors.index.json",
+       revision: "37e307261fe97bbf8b2463d61dbdd1a10daa264c",
+     });
+
+     assert(parse.sharded);
+
+     assert.strictEqual(Object.keys(parse.headers).length, 4);
+
+     assert.deepStrictEqual(parse.headers["model-00004-of-00004.safetensors"]["model.layers.24.mlp.up_proj.weight"], {
+       dtype: "BF16",
+       shape: [24576, 3072],
+       data_offsets: [301996032, 452990976],
+     });
+
+     assert.deepStrictEqual(parse.parameterCount, { BF16: 8_537_680_896 });
+     assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896);
+   });
+
+   it("should detect sharded safetensors filename", async () => {
+     const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors
+     const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename);
+
+     assert.strictEqual(safetensorsShardFileInfo?.prefix, "model_");
+     assert.strictEqual(safetensorsShardFileInfo?.basePrefix, "model");
+     assert.strictEqual(safetensorsShardFileInfo?.shard, "00005");
+     assert.strictEqual(safetensorsShardFileInfo?.total, "00072");
+   });
+ });
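The `sum` helper imported from `../utils/sum` is not part of this excerpt; a minimal sketch of what it presumably looks like, assuming it simply totals an array of numbers:

```ts
// Hypothetical reconstruction of ../utils/sum (not shown in this commit view):
// reduce an array of numbers to their total, starting from 0 so that sum([]) === 0.
export function sum(arr: number[]): number {
  return arr.reduce((a, b) => a + b, 0);
}
```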
lib/parse-safetensors-metadata.ts ADDED
@@ -0,0 +1,274 @@
+ import type { CredentialsParams, RepoDesignation } from "../types/public";
+ import { omit } from "../utils/omit";
+ import { toRepoId } from "../utils/toRepoId";
+ import { typedEntries } from "../utils/typedEntries";
+ import { downloadFile } from "./download-file";
+ import { fileExists } from "./file-exists";
+ import { promisesQueue } from "../utils/promisesQueue";
+ import type { SetRequired } from "../vendor/type-fest/set-required";
+
+ export const SAFETENSORS_FILE = "model.safetensors";
+ export const SAFETENSORS_INDEX_FILE = "model.safetensors.index.json";
+ /// We advise model/library authors to use the filenames above as a convention inside model repos,
+ /// but in some situations safetensors weights have different filenames.
+ export const RE_SAFETENSORS_FILE = /\.safetensors$/;
+ export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/;
+ export const RE_SAFETENSORS_SHARD_FILE =
+   /^(?<prefix>(?<basePrefix>.*?)[_-])(?<shard>\d{5})-of-(?<total>\d{5})\.safetensors$/;
+ export interface SafetensorsShardFileInfo {
+   prefix: string;
+   basePrefix: string;
+   shard: string;
+   total: string;
+ }
+ export function parseSafetensorsShardFilename(filename: string): SafetensorsShardFileInfo | null {
+   const match = RE_SAFETENSORS_SHARD_FILE.exec(filename);
+   if (match && match.groups) {
+     return {
+       prefix: match.groups["prefix"],
+       basePrefix: match.groups["basePrefix"],
+       shard: match.groups["shard"],
+       total: match.groups["total"],
+     };
+   }
+   return null;
+ }
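For illustration, the regex accepts either `_` or `-` as the separator before the five-digit shard counter, and anything else yields `null`:

```ts
parseSafetensorsShardFilename("model_00005-of-00072.safetensors");
// => { prefix: "model_", basePrefix: "model", shard: "00005", total: "00072" }

parseSafetensorsShardFilename("model-00001-of-00004.safetensors");
// => { prefix: "model-", basePrefix: "model", shard: "00001", total: "00004" }

parseSafetensorsShardFilename("model.safetensors");
// => null (a single-file checkpoint, not a shard)
```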
+
+ const PARALLEL_DOWNLOADS = 20;
+ const MAX_HEADER_LENGTH = 25_000_000;
+
+ class SafetensorParseError extends Error {}
+
+ type FileName = string;
+
+ export type TensorName = string;
+ export type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL";
+
+ export interface TensorInfo {
+   dtype: Dtype;
+   shape: number[];
+   data_offsets: [number, number];
+ }
+
+ export type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
+   __metadata__: Record<string, string>;
+ };
+
+ export interface SafetensorsIndexJson {
+   dtype?: string;
+   /// ^there's sometimes a dtype, but it looks inconsistent.
+   metadata?: Record<string, string>;
+   /// ^ why the naming inconsistency?
+   weight_map: Record<TensorName, FileName>;
+ }
+
+ export type SafetensorsShardedHeaders = Record<FileName, SafetensorsFileHeader>;
+
+ export type SafetensorsParseFromRepo =
+   | {
+       sharded: false;
+       header: SafetensorsFileHeader;
+       parameterCount?: Partial<Record<Dtype, number>>;
+     }
+   | {
+       sharded: true;
+       index: SafetensorsIndexJson;
+       headers: SafetensorsShardedHeaders;
+       parameterCount?: Partial<Record<Dtype, number>>;
+     };
+
+ async function parseSingleFile(
+   path: string,
+   params: {
+     repo: RepoDesignation;
+     revision?: string;
+     hubUrl?: string;
+     /**
+      * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+      */
+     fetch?: typeof fetch;
+   } & Partial<CredentialsParams>
+ ): Promise<SafetensorsFileHeader> {
+   const blob = await downloadFile({ ...params, path });
+
+   if (!blob) {
+     throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors header length.`);
+   }
+
+   const bufLengthOfHeaderLE = await blob.slice(0, 8).arrayBuffer();
+   const lengthOfHeader = new DataView(bufLengthOfHeaderLE).getBigUint64(0, true);
+   // ^little-endian
+   if (lengthOfHeader <= 0) {
+     throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is malformed.`);
+   }
+   if (lengthOfHeader > MAX_HEADER_LENGTH) {
+     throw new SafetensorParseError(
+       `Failed to parse file ${path}: safetensors header is too big. Maximum supported size is ${MAX_HEADER_LENGTH} bytes.`
+     );
+   }
+
+   try {
+     // no validation for now, we assume it's a valid FileHeader.
+     const header: SafetensorsFileHeader = JSON.parse(await blob.slice(8, 8 + Number(lengthOfHeader)).text());
+     return header;
+   } catch (err) {
+     throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is not valid JSON.`);
+   }
+ }
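The byte layout `parseSingleFile` relies on is the standard safetensors framing: an unsigned 64-bit little-endian header length, then that many bytes of JSON, then the raw tensor data. A self-contained sketch with a fabricated one-tensor payload:

```ts
// Build a tiny in-memory safetensors payload and read its header back,
// using the same 8-byte little-endian length prefix the parser expects.
const headerJson = JSON.stringify({
  __metadata__: { format: "pt" },
  "layer.weight": { dtype: "F32", shape: [2, 2], data_offsets: [0, 16] },
});
const headerBytes = new TextEncoder().encode(headerJson);

const file = new Uint8Array(8 + headerBytes.length + 16); // 16 bytes of F32 tensor data
new DataView(file.buffer).setBigUint64(0, BigInt(headerBytes.length), true); // little-endian
file.set(headerBytes, 8);

// Reading it back mirrors the logic in parseSingleFile:
const length = new DataView(file.buffer).getBigUint64(0, true);
const header = JSON.parse(new TextDecoder().decode(file.slice(8, 8 + Number(length))));
console.log(header["layer.weight"].shape); // [2, 2]
```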
+
+ async function parseShardedIndex(
+   path: string,
+   params: {
+     repo: RepoDesignation;
+     revision?: string;
+     hubUrl?: string;
+     /**
+      * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
+      */
+     fetch?: typeof fetch;
+   } & Partial<CredentialsParams>
+ ): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> {
+   const indexBlob = await downloadFile({
+     ...params,
+     path,
+   });
+
+   if (!indexBlob) {
+     throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`);
+   }
+
+   // no validation for now, we assume it's a valid IndexJson.
+   let index: SafetensorsIndexJson;
+   try {
+     index = JSON.parse(await indexBlob.slice(0, 10_000_000).text());
+   } catch (error) {
+     throw new SafetensorParseError(`Failed to parse file ${path}: not valid JSON.`);
+   }
+
+   const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);
+   const filenames = [...new Set(Object.values(index.weight_map))];
+   const shardedMap: SafetensorsShardedHeaders = Object.fromEntries(
+     await promisesQueue(
+       filenames.map(
+         (filename) => async () =>
+           [filename, await parseSingleFile(pathPrefix + filename, params)] satisfies [string, SafetensorsFileHeader]
+       ),
+       PARALLEL_DOWNLOADS
+     )
+   );
+   return { index, headers: shardedMap };
+ }
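`promisesQueue` (imported from `../utils/promisesQueue`, not shown in this commit view) runs the shard-header fetches with bounded concurrency, keeping at most `PARALLEL_DOWNLOADS` (20) requests in flight. A hypothetical minimal version of such a helper:

```ts
// Hypothetical sketch of a concurrency-limited queue: run the task factories
// with at most `concurrency` promises pending at once, and resolve with the
// results in their original order.
async function promisesQueueSketch<T>(factories: Array<() => Promise<T>>, concurrency: number): Promise<T[]> {
  const results: T[] = new Array(factories.length);
  let next = 0;

  async function worker(): Promise<void> {
    while (next < factories.length) {
      const index = next++; // safe: JS is single-threaded between awaits
      results[index] = await factories[index]();
    }
  }

  // Start up to `concurrency` workers that drain the shared queue.
  await Promise.all(Array.from({ length: Math.min(concurrency, factories.length) }, worker));
  return results;
}
```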
161
+
162
+ /**
163
+ * Analyze model.safetensors.index.json or model.safetensors from a model hosted
164
+ * on Hugging Face using smart range requests to extract its metadata.
165
+ */
166
+ export async function parseSafetensorsMetadata(
167
+ params: {
168
+ /** Only models are supported */
169
+ repo: RepoDesignation;
170
+ /**
171
+ * Relative file path to safetensors file inside `repo`. Defaults to `SAFETENSORS_FILE` or `SAFETENSORS_INDEX_FILE` (whichever one exists).
172
+ */
173
+ path?: string;
174
+ /**
175
+ * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
176
+ *
177
+ * @default false
178
+ */
179
+ computeParametersCount: true;
180
+ hubUrl?: string;
181
+ revision?: string;
182
+ /**
183
+ * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
184
+ */
185
+ fetch?: typeof fetch;
186
+ } & Partial<CredentialsParams>
187
+ ): Promise<SetRequired<SafetensorsParseFromRepo, "parameterCount">>;
188
+ export async function parseSafetensorsMetadata(
189
+ params: {
190
+ /** Only models are supported */
191
+ repo: RepoDesignation;
192
+ /**
193
+ * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
194
+ *
195
+ * @default false
196
+ */
197
+ path?: string;
198
+ computeParametersCount?: boolean;
199
+ hubUrl?: string;
200
+ revision?: string;
201
+ /**
202
+ * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
203
+ */
204
+ fetch?: typeof fetch;
205
+ } & Partial<CredentialsParams>
206
+ ): Promise<SafetensorsParseFromRepo>;
207
+ export async function parseSafetensorsMetadata(
208
+ params: {
209
+ repo: RepoDesignation;
210
+ path?: string;
211
+ computeParametersCount?: boolean;
212
+ hubUrl?: string;
213
+ revision?: string;
214
+ /**
215
+ * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
216
+ */
217
+ fetch?: typeof fetch;
218
+ } & Partial<CredentialsParams>
219
+ ): Promise<SafetensorsParseFromRepo> {
220
+ const repoId = toRepoId(params.repo);
221
+
222
+ if (repoId.type !== "model") {
223
+ throw new TypeError("Only model repos should contain safetensors files.");
224
+ }
225
+
226
+ if (RE_SAFETENSORS_FILE.test(params.path ?? "") || (await fileExists({ ...params, path: SAFETENSORS_FILE }))) {
227
+ const header = await parseSingleFile(params.path ?? SAFETENSORS_FILE, params);
228
+ return {
229
+ sharded: false,
230
+ header,
231
+ ...(params.computeParametersCount && {
232
+ parameterCount: computeNumOfParamsByDtypeSingleFile(header),
233
+ }),
234
+ };
235
+ } else if (
236
+ RE_SAFETENSORS_INDEX_FILE.test(params.path ?? "") ||
237
+ (await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE }))
238
+ ) {
239
+ const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params);
240
+ return {
241
+ sharded: true,
242
+ index,
243
+ headers,
244
+ ...(params.computeParametersCount && {
245
+ parameterCount: computeNumOfParamsByDtypeSharded(headers),
246
+ }),
247
+ };
248
+ } else {
249
+ throw new Error("model id does not seem to contain safetensors weights");
250
+ }
251
+ }
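A usage sketch, mirroring the calls in the spec file above (the repo name and parameter total come from those tests; the import path assumes the published `@huggingface/hub` package):

```ts
import { parseSafetensorsMetadata } from "@huggingface/hub";

const parse = await parseSafetensorsMetadata({
  repo: "bert-base-uncased",
  computeParametersCount: true,
});

if (!parse.sharded) {
  // Single-file checkpoint: one header mapping tensor names to dtype/shape/offsets.
  console.log(Object.keys(parse.header).length, "tensors");
}
console.log(parse.parameterCount); // e.g. { F32: 110_106_428 } for bert-base-uncased
```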
+
+ function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Partial<Record<Dtype, number>> {
+   const counter: Partial<Record<Dtype, number>> = {};
+   const tensors = omit(header, "__metadata__");
+
+   for (const [, v] of typedEntries(tensors)) {
+     if (v.shape.length === 0) {
+       continue;
+     }
+     counter[v.dtype] = (counter[v.dtype] ?? 0) + v.shape.reduce((a, b) => a * b);
+   }
+   return counter;
+ }
+
+ function computeNumOfParamsByDtypeSharded(shardedMap: SafetensorsShardedHeaders): Partial<Record<Dtype, number>> {
+   const counter: Partial<Record<Dtype, number>> = {};
+   for (const header of Object.values(shardedMap)) {
+     for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header))) {
+       counter[k] = (counter[k] ?? 0) + (v ?? 0);
+     }
+   }
+   return counter;
+ }
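A tensor's parameter count is simply the product of its shape dimensions, accumulated per dtype. A worked example using two tensor shapes from the tests above, applying the same reduction as `computeNumOfParamsByDtypeSingleFile`:

```ts
// [768] contributes 768 F32 params; [24576, 3072] contributes
// 24576 * 3072 = 75_497_472 BF16 params.
const header: Record<string, { dtype: string; shape: number[] }> = {
  "bert.embeddings.LayerNorm.beta": { dtype: "F32", shape: [768] },
  "model.layers.24.mlp.up_proj.weight": { dtype: "BF16", shape: [24576, 3072] },
};

const counts: Record<string, number> = {};
for (const { dtype, shape } of Object.values(header)) {
  counts[dtype] = (counts[dtype] ?? 0) + shape.reduce((a, b) => a * b, 1);
}
console.log(counts); // { F32: 768, BF16: 75497472 }
```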