Add 1 files
This view is limited to 50 files because it contains too many changes.
- consts.ts +1 -0
- error.ts +49 -0
- index.ts +25 -0
- lib/cache-management.spec.ts +137 -0
- lib/cache-management.ts +265 -0
- lib/check-repo-access.spec.ts +34 -0
- lib/check-repo-access.ts +32 -0
- lib/commit.spec.ts +271 -0
- lib/commit.ts +609 -0
- lib/count-commits.spec.ts +16 -0
- lib/count-commits.ts +35 -0
- lib/create-branch.spec.ts +159 -0
- lib/create-branch.ts +54 -0
- lib/create-repo.spec.ts +103 -0
- lib/create-repo.ts +78 -0
- lib/dataset-info.spec.ts +56 -0
- lib/dataset-info.ts +61 -0
- lib/delete-branch.spec.ts +43 -0
- lib/delete-branch.ts +32 -0
- lib/delete-file.spec.ts +64 -0
- lib/delete-file.ts +35 -0
- lib/delete-files.spec.ts +81 -0
- lib/delete-files.ts +33 -0
- lib/delete-repo.ts +37 -0
- lib/download-file-to-cache-dir.spec.ts +306 -0
- lib/download-file-to-cache-dir.ts +138 -0
- lib/download-file.spec.ts +82 -0
- lib/download-file.ts +77 -0
- lib/file-download-info.spec.ts +59 -0
- lib/file-download-info.ts +151 -0
- lib/file-exists.spec.ts +30 -0
- lib/file-exists.ts +41 -0
- lib/index.ts +32 -0
- lib/list-commits.spec.ts +117 -0
- lib/list-commits.ts +70 -0
- lib/list-datasets.spec.ts +47 -0
- lib/list-datasets.ts +121 -0
- lib/list-files.spec.ts +173 -0
- lib/list-files.ts +94 -0
- lib/list-models.spec.ts +118 -0
- lib/list-models.ts +139 -0
- lib/list-spaces.spec.ts +40 -0
- lib/list-spaces.ts +111 -0
- lib/model-info.spec.ts +59 -0
- lib/model-info.ts +62 -0
- lib/oauth-handle-redirect.spec.ts +60 -0
- lib/oauth-handle-redirect.ts +334 -0
- lib/oauth-login-url.ts +166 -0
- lib/parse-safetensors-metadata.spec.ts +122 -0
- lib/parse-safetensors-metadata.ts +274 -0
consts.ts
ADDED
@@ -0,0 +1 @@
+export const HUB_URL = "https://huggingface.co";
error.ts
ADDED
@@ -0,0 +1,49 @@
+import type { JsonObject } from "./vendor/type-fest/basic";
+
+export async function createApiError(
+  response: Response,
+  opts?: { requestId?: string; message?: string }
+): Promise<never> {
+  const error = new HubApiError(response.url, response.status, response.headers.get("X-Request-Id") ?? opts?.requestId);
+
+  error.message = `Api error with status ${error.statusCode}${opts?.message ? `. ${opts.message}` : ""}`;
+
+  const trailer = [`URL: ${error.url}`, error.requestId ? `Request ID: ${error.requestId}` : undefined]
+    .filter(Boolean)
+    .join(". ");
+
+  if (response.headers.get("Content-Type")?.startsWith("application/json")) {
+    const json = await response.json();
+    error.message = json.error || json.message || error.message;
+    if (json.error_description) {
+      error.message = error.message ? error.message + `: ${json.error_description}` : json.error_description;
+    }
+    error.data = json;
+  } else {
+    error.data = { message: await response.text() };
+  }
+
+  error.message += `. ${trailer}`;
+
+  throw error;
+}
+
+/**
+ * Error thrown when an API call to the Hugging Face Hub fails.
+ */
+export class HubApiError extends Error {
+  statusCode: number;
+  url: string;
+  requestId?: string;
+  data?: JsonObject;
+
+  constructor(url: string, statusCode: number, requestId?: string, message?: string) {
+    super(message);
+
+    this.statusCode = statusCode;
+    this.requestId = requestId;
+    this.url = url;
+  }
+}
+
+export class InvalidApiResponseFormatError extends Error {}
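
Not part of the commit: a minimal usage sketch of the error helpers above, assuming they are imported from a sibling module. The fetched URL is a placeholder chosen for illustration.

import { createApiError, HubApiError } from "./error";

// Fetch a Hub API endpoint and convert non-2xx responses into HubApiError.
async function fetchJson(url: string): Promise<unknown> {
  const res = await fetch(url);
  if (!res.ok) {
    // createApiError always throws (Promise<never>), so this never returns normally.
    throw await createApiError(res);
  }
  return res.json();
}

fetchJson("https://huggingface.co/api/models/openai-community/gpt2").catch((err) => {
  if (err instanceof HubApiError) {
    // statusCode, url, requestId and data are filled in by createApiError.
    console.error(err.statusCode, err.url, err.requestId);
  }
});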
index.ts
ADDED
@@ -0,0 +1,25 @@
+export * from "./lib";
+// Typescript 5 will add 'export type *'
+export type {
+  AccessToken,
+  AccessTokenRole,
+  AuthType,
+  Credentials,
+  PipelineType,
+  RepoDesignation,
+  RepoFullName,
+  RepoId,
+  RepoType,
+  SpaceHardwareFlavor,
+  SpaceResourceConfig,
+  SpaceResourceRequirement,
+  SpaceRuntime,
+  SpaceSdk,
+  SpaceStage,
+} from "./types/public";
+export { HubApiError, InvalidApiResponseFormatError } from "./error";
+/**
+ * Only exported for E2Es convenience
+ */
+export { sha256 as __internal_sha256 } from "./utils/sha256";
+export { XetBlob as __internal_XetBlob } from "./utils/XetBlob";
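
Not part of the commit: a short sketch of what consumers get from this entry point. It assumes lib/index re-exports the functions added elsewhere in this commit (its contents are not shown in this view).

import { HubApiError, type RepoId } from "./index";

// Public types such as RepoId come from types/public; HubApiError is re-exported from ./error.
const repo: RepoId = { name: "openai-community/gpt2", type: "model" };

function isHubApiError(err: unknown): err is HubApiError {
  return err instanceof HubApiError;
}

console.log(repo.name, isHubApiError(new Error("not a hub error")));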
lib/cache-management.spec.ts
ADDED
@@ -0,0 +1,137 @@
+import { describe, test, expect, vi, beforeEach } from "vitest";
+import {
+  scanCacheDir,
+  scanCachedRepo,
+  scanSnapshotDir,
+  parseRepoType,
+  getBlobStat,
+  type CachedFileInfo,
+} from "./cache-management";
+import { stat, readdir, realpath, lstat } from "node:fs/promises";
+import type { Dirent, Stats } from "node:fs";
+import { join } from "node:path";
+
+// Mocks
+vi.mock("node:fs/promises");
+
+beforeEach(() => {
+  vi.resetAllMocks();
+  vi.restoreAllMocks();
+});
+
+describe("scanCacheDir", () => {
+  test("should throw an error if cacheDir is not a directory", async () => {
+    vi.mocked(stat).mockResolvedValueOnce({
+      isDirectory: () => false,
+    } as Stats);
+
+    await expect(scanCacheDir("/fake/dir")).rejects.toThrow("Scan cache expects a directory");
+  });
+
+  test("empty directory should return an empty set of repository and no warnings", async () => {
+    vi.mocked(stat).mockResolvedValueOnce({
+      isDirectory: () => true,
+    } as Stats);
+
+    // mock empty cache folder
+    vi.mocked(readdir).mockResolvedValue([]);
+
+    const result = await scanCacheDir("/fake/dir");
+
+    // cacheDir must have been read
+    expect(readdir).toHaveBeenCalledWith("/fake/dir");
+
+    expect(result.warnings.length).toBe(0);
+    expect(result.repos).toHaveLength(0);
+    expect(result.size).toBe(0);
+  });
+});
+
+describe("scanCachedRepo", () => {
+  test("should throw an error for invalid repo path", async () => {
+    await expect(() => {
+      return scanCachedRepo("/fake/repo_path");
+    }).rejects.toThrow("Repo path is not a valid HuggingFace cache directory");
+  });
+
+  test("should throw an error if the snapshot folder does not exist", async () => {
+    vi.mocked(readdir).mockResolvedValue([]);
+    vi.mocked(stat).mockResolvedValue({
+      isDirectory: () => false,
+    } as Stats);
+
+    await expect(() => {
+      return scanCachedRepo("/fake/cacheDir/models--hello-world--name");
+    }).rejects.toThrow("Snapshots dir doesn't exist in cached repo");
+  });
+
+  test("should properly parse the repository name", async () => {
+    const repoPath = "/fake/cacheDir/models--hello-world--name";
+    vi.mocked(readdir).mockResolvedValue([]);
+    vi.mocked(stat).mockResolvedValue({
+      isDirectory: () => true,
+    } as Stats);
+
+    const result = await scanCachedRepo(repoPath);
+    expect(readdir).toHaveBeenCalledWith(join(repoPath, "refs"), {
+      withFileTypes: true,
+    });
+
+    expect(result.id.name).toBe("hello-world/name");
+    expect(result.id.type).toBe("model");
+  });
+});
+
+describe("scanSnapshotDir", () => {
+  test("should scan a valid snapshot directory", async () => {
+    const cachedFiles: CachedFileInfo[] = [];
+    const blobStats = new Map<string, Stats>();
+    vi.mocked(readdir).mockResolvedValueOnce([{ name: "file1", isDirectory: () => false } as Dirent]);
+
+    vi.mocked(realpath).mockResolvedValueOnce("/fake/realpath");
+    vi.mocked(lstat).mockResolvedValueOnce({ size: 1024, atimeMs: Date.now(), mtimeMs: Date.now() } as Stats);
+
+    await scanSnapshotDir("/fake/revision", cachedFiles, blobStats);
+
+    expect(cachedFiles).toHaveLength(1);
+    expect(blobStats.size).toBe(1);
+  });
+});
+
+describe("getBlobStat", () => {
+  test("should retrieve blob stat if already cached", async () => {
+    const blobStats = new Map<string, Stats>([["/fake/blob", { size: 1024 } as Stats]]);
+    const result = await getBlobStat("/fake/blob", blobStats);
+
+    expect(lstat).not.toHaveBeenCalled();
+    expect(result.size).toBe(1024);
+  });
+
+  test("should fetch and cache blob stat if not cached", async () => {
+    const blobStats = new Map();
+    vi.mocked(lstat).mockResolvedValueOnce({ size: 2048 } as Stats);
+
+    const result = await getBlobStat("/fake/blob", blobStats);
+
+    expect(result.size).toBe(2048);
+    expect(blobStats.size).toBe(1);
+  });
+});
+
+describe("parseRepoType", () => {
+  test("should parse models repo type", () => {
+    expect(parseRepoType("models")).toBe("model");
+  });
+
+  test("should parse dataset repo type", () => {
+    expect(parseRepoType("datasets")).toBe("dataset");
+  });
+
+  test("should parse space repo type", () => {
+    expect(parseRepoType("spaces")).toBe("space");
+  });
+
+  test("should throw an error for invalid repo type", () => {
+    expect(() => parseRepoType("invalid")).toThrowError("Invalid repo type: invalid");
+  });
+});
lib/cache-management.ts
ADDED
@@ -0,0 +1,265 @@
+import { homedir } from "node:os";
+import { join, basename } from "node:path";
+import { stat, readdir, readFile, realpath, lstat } from "node:fs/promises";
+import type { Stats } from "node:fs";
+import type { RepoType, RepoId } from "../types/public";
+
+function getDefaultHome(): string {
+  return join(homedir(), ".cache");
+}
+
+function getDefaultCachePath(): string {
+  return join(process.env["HF_HOME"] ?? join(process.env["XDG_CACHE_HOME"] ?? getDefaultHome(), "huggingface"), "hub");
+}
+
+function getHuggingFaceHubCache(): string {
+  return process.env["HUGGINGFACE_HUB_CACHE"] ?? getDefaultCachePath();
+}
+
+export function getHFHubCachePath(): string {
+  return process.env["HF_HUB_CACHE"] ?? getHuggingFaceHubCache();
+}
+
+const FILES_TO_IGNORE: string[] = [".DS_Store"];
+
+export const REPO_ID_SEPARATOR: string = "--";
+
+export function getRepoFolderName({ name, type }: RepoId): string {
+  const parts = [`${type}s`, ...name.split("/")];
+  return parts.join(REPO_ID_SEPARATOR);
+}
+
+export interface CachedFileInfo {
+  path: string;
+  /**
+   * Underlying file - which `path` is symlinked to
+   */
+  blob: {
+    size: number;
+    path: string;
+    lastModifiedAt: Date;
+    lastAccessedAt: Date;
+  };
+}
+
+export interface CachedRevisionInfo {
+  commitOid: string;
+  path: string;
+  size: number;
+  files: CachedFileInfo[];
+  refs: string[];
+
+  lastModifiedAt: Date;
+}
+
+export interface CachedRepoInfo {
+  id: RepoId;
+  path: string;
+  size: number;
+  filesCount: number;
+  revisions: CachedRevisionInfo[];
+
+  lastAccessedAt: Date;
+  lastModifiedAt: Date;
+}
+
+export interface HFCacheInfo {
+  size: number;
+  repos: CachedRepoInfo[];
+  warnings: Error[];
+}
+
+export async function scanCacheDir(cacheDir: string | undefined = undefined): Promise<HFCacheInfo> {
+  if (!cacheDir) cacheDir = getHFHubCachePath();
+
+  const s = await stat(cacheDir);
+  if (!s.isDirectory()) {
+    throw new Error(
+      `Scan cache expects a directory but found a file: ${cacheDir}. Please use \`cacheDir\` argument or set \`HF_HUB_CACHE\` environment variable.`
+    );
+  }
+
+  const repos: CachedRepoInfo[] = [];
+  const warnings: Error[] = [];
+
+  const directories = await readdir(cacheDir);
+  for (const repo of directories) {
+    // skip .locks folder
+    if (repo === ".locks") continue;
+
+    // get the absolute path of the repo
+    const absolute = join(cacheDir, repo);
+
+    // ignore non-directory element
+    const s = await stat(absolute);
+    if (!s.isDirectory()) {
+      continue;
+    }
+
+    try {
+      const cached = await scanCachedRepo(absolute);
+      repos.push(cached);
+    } catch (err: unknown) {
+      warnings.push(err as Error);
+    }
+  }
+
+  return {
+    repos: repos,
+    size: [...repos.values()].reduce((sum, repo) => sum + repo.size, 0),
+    warnings: warnings,
+  };
+}
+
+export async function scanCachedRepo(repoPath: string): Promise<CachedRepoInfo> {
+  // get the directory name
+  const name = basename(repoPath);
+  if (!name.includes(REPO_ID_SEPARATOR)) {
+    throw new Error(`Repo path is not a valid HuggingFace cache directory: ${name}`);
+  }
+
+  // parse the repoId from directory name
+  const [type, ...remaining] = name.split(REPO_ID_SEPARATOR);
+  const repoType = parseRepoType(type);
+  const repoId = remaining.join("/");
+
+  const snapshotsPath = join(repoPath, "snapshots");
+  const refsPath = join(repoPath, "refs");
+
+  const snapshotStat = await stat(snapshotsPath);
+  if (!snapshotStat.isDirectory()) {
+    throw new Error(`Snapshots dir doesn't exist in cached repo ${snapshotsPath}`);
+  }
+
+  // Check if the refs directory exists and scan it
+  const refsByHash: Map<string, string[]> = new Map();
+  const refsStat = await stat(refsPath);
+  if (refsStat.isDirectory()) {
+    await scanRefsDir(refsPath, refsByHash);
+  }
+
+  // Scan snapshots directory and collect cached revision information
+  const cachedRevisions: CachedRevisionInfo[] = [];
+  const blobStats: Map<string, Stats> = new Map(); // Store blob stats
+
+  const snapshotDirs = await readdir(snapshotsPath);
+  for (const dir of snapshotDirs) {
+    if (FILES_TO_IGNORE.includes(dir)) continue; // Ignore unwanted files
+
+    const revisionPath = join(snapshotsPath, dir);
+    const revisionStat = await stat(revisionPath);
+    if (!revisionStat.isDirectory()) {
+      throw new Error(`Snapshots folder corrupted. Found a file: ${revisionPath}`);
+    }
+
+    const cachedFiles: CachedFileInfo[] = [];
+    await scanSnapshotDir(revisionPath, cachedFiles, blobStats);
+
+    const revisionLastModified =
+      cachedFiles.length > 0
+        ? Math.max(...[...cachedFiles].map((file) => file.blob.lastModifiedAt.getTime()))
+        : revisionStat.mtimeMs;
+
+    cachedRevisions.push({
+      commitOid: dir,
+      files: cachedFiles,
+      refs: refsByHash.get(dir) || [],
+      size: [...cachedFiles].reduce((sum, file) => sum + file.blob.size, 0),
+      path: revisionPath,
+      lastModifiedAt: new Date(revisionLastModified),
+    });
+
+    refsByHash.delete(dir);
+  }
+
+  // Verify that all refs refer to a valid revision
+  if (refsByHash.size > 0) {
+    throw new Error(
+      `Reference(s) refer to missing commit hashes: ${JSON.stringify(Object.fromEntries(refsByHash))} (${repoPath})`
+    );
+  }
+
+  const repoStats = await stat(repoPath);
+  const repoLastAccessed =
+    blobStats.size > 0 ? Math.max(...[...blobStats.values()].map((stat) => stat.atimeMs)) : repoStats.atimeMs;
+
+  const repoLastModified =
+    blobStats.size > 0 ? Math.max(...[...blobStats.values()].map((stat) => stat.mtimeMs)) : repoStats.mtimeMs;
+
+  // Return the constructed CachedRepoInfo object
+  return {
+    id: {
+      name: repoId,
+      type: repoType,
+    },
+    path: repoPath,
+    filesCount: blobStats.size,
+    revisions: cachedRevisions,
+    size: [...blobStats.values()].reduce((sum, stat) => sum + stat.size, 0),
+    lastAccessedAt: new Date(repoLastAccessed),
+    lastModifiedAt: new Date(repoLastModified),
+  };
+}
+
+export async function scanRefsDir(refsPath: string, refsByHash: Map<string, string[]>): Promise<void> {
+  const refFiles = await readdir(refsPath, { withFileTypes: true });
+  for (const refFile of refFiles) {
+    const refFilePath = join(refsPath, refFile.name);
+    if (refFile.isDirectory()) continue; // Skip directories
+
+    const commitHash = await readFile(refFilePath, "utf-8");
+    const refName = refFile.name;
+    if (!refsByHash.has(commitHash)) {
+      refsByHash.set(commitHash, []);
+    }
+    refsByHash.get(commitHash)?.push(refName);
+  }
+}
+
+export async function scanSnapshotDir(
+  revisionPath: string,
+  cachedFiles: CachedFileInfo[],
+  blobStats: Map<string, Stats>
+): Promise<void> {
+  const files = await readdir(revisionPath, { withFileTypes: true });
+  for (const file of files) {
+    if (file.isDirectory()) continue; // Skip directories
+
+    const filePath = join(revisionPath, file.name);
+    const blobPath = await realpath(filePath);
+    const blobStat = await getBlobStat(blobPath, blobStats);
+
+    cachedFiles.push({
+      path: filePath,
+      blob: {
+        path: blobPath,
+        size: blobStat.size,
+        lastAccessedAt: new Date(blobStat.atimeMs),
+        lastModifiedAt: new Date(blobStat.mtimeMs),
+      },
+    });
+  }
+}
+
+export async function getBlobStat(blobPath: string, blobStats: Map<string, Stats>): Promise<Stats> {
+  const blob = blobStats.get(blobPath);
+  if (!blob) {
+    const statResult = await lstat(blobPath);
+    blobStats.set(blobPath, statResult);
+    return statResult;
+  }
+  return blob;
+}
+
+export function parseRepoType(type: string): RepoType {
+  switch (type) {
+    case "models":
+      return "model";
+    case "datasets":
+      return "dataset";
+    case "spaces":
+      return "space";
+    default:
+      throw new TypeError(`Invalid repo type: ${type}`);
+  }
+}
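
Not part of the commit: a usage sketch for the cache scanner above. With no argument, scanCacheDir falls back to HF_HUB_CACHE / HUGGINGFACE_HUB_CACHE / ~/.cache/huggingface/hub as implemented in getHFHubCachePath; the import path assumes the example lives at the package root.

import { scanCacheDir } from "./lib/cache-management";

async function printCacheSummary(): Promise<void> {
  const info = await scanCacheDir();
  console.log(`${info.repos.length} cached repo(s), ${info.size} bytes total`);
  for (const repo of info.repos) {
    console.log(`- [${repo.id.type}] ${repo.id.name}: ${repo.revisions.length} revision(s), ${repo.filesCount} file(s)`);
  }
  // Per-repo scan failures are collected as warnings instead of aborting the whole scan.
  for (const warning of info.warnings) {
    console.warn(warning.message);
  }
}

printCacheSummary().catch(console.error);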
lib/check-repo-access.spec.ts
ADDED
@@ -0,0 +1,34 @@
+import { assert, describe, expect, it } from "vitest";
+import { checkRepoAccess } from "./check-repo-access";
+import { HubApiError } from "../error";
+import { TEST_ACCESS_TOKEN, TEST_HUB_URL } from "../test/consts";
+
+describe("checkRepoAccess", () => {
+  it("should throw 401 when accessing unexisting repo unauthenticated", async () => {
+    try {
+      await checkRepoAccess({ repo: { name: "i--d/dont", type: "model" } });
+      assert(false, "should have thrown");
+    } catch (err) {
+      expect(err).toBeInstanceOf(HubApiError);
+      expect((err as HubApiError).statusCode).toBe(401);
+    }
+  });
+
+  it("should throw 404 when accessing unexisting repo authenticated", async () => {
+    try {
+      await checkRepoAccess({
+        repo: { name: "i--d/dont", type: "model" },
+        hubUrl: TEST_HUB_URL,
+        accessToken: TEST_ACCESS_TOKEN,
+      });
+      assert(false, "should have thrown");
+    } catch (err) {
+      expect(err).toBeInstanceOf(HubApiError);
+      expect((err as HubApiError).statusCode).toBe(404);
+    }
+  });
+
+  it("should not throw when accessing public repo", async () => {
+    await checkRepoAccess({ repo: { name: "openai-community/gpt2", type: "model" } });
+  });
+});
lib/check-repo-access.ts
ADDED
@@ -0,0 +1,32 @@
+import { HUB_URL } from "../consts";
+// eslint-disable-next-line @typescript-eslint/no-unused-vars
+import { createApiError, type HubApiError } from "../error";
+import type { CredentialsParams, RepoDesignation } from "../types/public";
+import { checkCredentials } from "../utils/checkCredentials";
+import { toRepoId } from "../utils/toRepoId";
+
+/**
+ * Check if we have read access to a repository.
+ *
+ * Throw a {@link HubApiError} error if we don't have access. HubApiError.statusCode will be 401, 403 or 404.
+ */
+export async function checkRepoAccess(
+  params: {
+    repo: RepoDesignation;
+    hubUrl?: string;
+    fetch?: typeof fetch;
+  } & Partial<CredentialsParams>
+): Promise<void> {
+  const accessToken = params && checkCredentials(params);
+  const repoId = toRepoId(params.repo);
+
+  const response = await (params.fetch || fetch)(`${params?.hubUrl || HUB_URL}/api/${repoId.type}s/${repoId.name}`, {
+    headers: {
+      ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
+    },
+  });
+
+  if (!response.ok) {
+    throw await createApiError(response);
+  }
+}
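
Not part of the commit: a usage sketch for checkRepoAccess. Anonymous access is used here; an accessToken can be passed for private repos, as in the spec above.

import { checkRepoAccess } from "./lib/check-repo-access";
import { HubApiError } from "./error";

async function verifyAccess(): Promise<void> {
  try {
    await checkRepoAccess({ repo: { name: "openai-community/gpt2", type: "model" } });
    console.log("read access OK");
  } catch (err) {
    if (err instanceof HubApiError) {
      // 401 (unauthenticated), 403 (forbidden) or 404 (not found), as documented above.
      console.error(`no access: HTTP ${err.statusCode}`);
    } else {
      throw err;
    }
  }
}

verifyAccess().catch(console.error);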
lib/commit.spec.ts
ADDED
@@ -0,0 +1,271 @@
+import { assert, it, describe } from "vitest";
+
+import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts";
+import type { RepoId } from "../types/public";
+import type { CommitFile } from "./commit";
+import { commit } from "./commit";
+import { createRepo } from "./create-repo";
+import { deleteRepo } from "./delete-repo";
+import { downloadFile } from "./download-file";
+import { fileDownloadInfo } from "./file-download-info";
+import { insecureRandomString } from "../utils/insecureRandomString";
+import { isFrontend } from "../utils/isFrontend";
+
+const lfsContent = "O123456789".repeat(100_000);
+
+describe("commit", () => {
+  it("should commit to a repo with blobs", async function () {
+    const tokenizerJsonUrl = new URL(
+      "https://huggingface.co/spaces/aschen/push-model-from-web/raw/main/mobilenet/model.json"
+    );
+    const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
+    const repo: RepoId = {
+      name: repoName,
+      type: "model",
+    };
+
+    await createRepo({
+      accessToken: TEST_ACCESS_TOKEN,
+      hubUrl: TEST_HUB_URL,
+      repo,
+      license: "mit",
+    });
+
+    try {
+      const readme1 = await downloadFile({ repo, path: "README.md", hubUrl: TEST_HUB_URL });
+      assert(readme1, "Readme doesn't exist");
+
+      const nodeOperation: CommitFile[] = isFrontend
+        ? []
+        : [
+            {
+              operation: "addOrUpdate",
+              path: "tsconfig.json",
+              content: (await import("node:url")).pathToFileURL("./tsconfig.json") as URL,
+            },
+          ];
+
+      await commit({
+        repo,
+        title: "Some commit",
+        accessToken: TEST_ACCESS_TOKEN,
+        hubUrl: TEST_HUB_URL,
+        operations: [
+          {
+            operation: "addOrUpdate",
+            content: new Blob(["This is me"]),
+            path: "test.txt",
+          },
+          {
+            operation: "addOrUpdate",
+            content: new Blob([lfsContent]),
+            path: "test.lfs.txt",
+          },
+          ...nodeOperation,
+          {
+            operation: "addOrUpdate",
+            content: tokenizerJsonUrl,
+            path: "lamaral.json",
+          },
+          {
+            operation: "delete",
+            path: "README.md",
+          },
+        ],
+        // To test web workers in the front-end
+        useWebWorkers: { minSize: 5_000 },
+      });
+
+      const fileContent = await downloadFile({ repo, path: "test.txt", hubUrl: TEST_HUB_URL });
+      assert.strictEqual(await fileContent?.text(), "This is me");
+
+      const lfsFileContent = await downloadFile({ repo, path: "test.lfs.txt", hubUrl: TEST_HUB_URL });
+      assert.strictEqual(await lfsFileContent?.text(), lfsContent);
+
+      const lfsFileUrl = `${TEST_HUB_URL}/${repoName}/raw/main/test.lfs.txt`;
+      const lfsFilePointer = await fetch(lfsFileUrl);
+      assert.strictEqual(lfsFilePointer.status, 200);
+      assert.strictEqual(
+        (await lfsFilePointer.text()).trim(),
+        `
+version https://git-lfs.github.com/spec/v1
+oid sha256:a3bbce7ee1df7233d85b5f4d60faa3755f93f537804f8b540c72b0739239ddf8
+size ${lfsContent.length}
+        `.trim()
+      );
+
+      if (!isFrontend) {
+        const fileUrlContent = await downloadFile({ repo, path: "tsconfig.json", hubUrl: TEST_HUB_URL });
+        assert.strictEqual(
+          await fileUrlContent?.text(),
+          (await import("node:fs")).readFileSync("./tsconfig.json", "utf-8")
+        );
+      }
+
+      const webResourceContent = await downloadFile({ repo, path: "lamaral.json", hubUrl: TEST_HUB_URL });
+      assert.strictEqual(await webResourceContent?.text(), await (await fetch(tokenizerJsonUrl)).text());
+
+      const readme2 = await downloadFile({ repo, path: "README.md", hubUrl: TEST_HUB_URL });
+      assert.strictEqual(readme2, null);
+    } finally {
+      await deleteRepo({
+        repo: {
+          name: repoName,
+          type: "model",
+        },
+        hubUrl: TEST_HUB_URL,
+        credentials: { accessToken: TEST_ACCESS_TOKEN },
+      });
+    }
+  }, 60_000);
+
+  it("should commit a full repo from HF with web urls", async function () {
+    const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
+    const repo: RepoId = {
+      name: repoName,
+      type: "model",
+    };
+
+    await createRepo({
+      accessToken: TEST_ACCESS_TOKEN,
+      repo,
+      hubUrl: TEST_HUB_URL,
+    });
+
+    try {
+      const FILES_TO_UPLOAD = [
+        `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/model.json`,
+        `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/group1-shard1of2`,
+        `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/group1-shard2of2`,
+        `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/coffee.jpg`,
+        `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/README.md`,
+      ];
+
+      const operations: CommitFile[] = await Promise.all(
+        FILES_TO_UPLOAD.map(async (file) => {
+          return {
+            operation: "addOrUpdate",
+            path: file.slice(file.indexOf("main/") + "main/".length),
+            // upload remote file
+            content: new URL(file),
+          };
+        })
+      );
+      await commit({
+        repo,
+        accessToken: TEST_ACCESS_TOKEN,
+        hubUrl: TEST_HUB_URL,
+        title: "upload model",
+        operations,
+      });
+
+      const LFSSize = (await fileDownloadInfo({ repo, path: "mobilenet/group1-shard1of2", hubUrl: TEST_HUB_URL }))
+        ?.size;
+
+      assert.strictEqual(LFSSize, 4_194_304);
+
+      const pointerFile = await downloadFile({
+        repo,
+        path: "mobilenet/group1-shard1of2",
+        raw: true,
+        hubUrl: TEST_HUB_URL,
+      });
+
+      // Make sure SHA is computed properly as well
+      assert.strictEqual(
+        (await pointerFile?.text())?.trim(),
+        `
+version https://git-lfs.github.com/spec/v1
+oid sha256:3fb621eb9b37478239504ee083042d5b18699e8b8618e569478b03b119a85a69
+size 4194304
+        `.trim()
+      );
+    } finally {
+      await deleteRepo({
+        repo: {
+          name: repoName,
+          type: "model",
+        },
+        hubUrl: TEST_HUB_URL,
+        credentials: { accessToken: TEST_ACCESS_TOKEN },
+      });
+    }
+    // https://huggingfacejs-push-model-from-web.hf.space/
+  }, 60_000);
+
+  it("should be able to create a PR and then commit to it", async function () {
+    const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
+    const repo: RepoId = {
+      name: repoName,
+      type: "model",
+    };
+
+    await createRepo({
+      credentials: {
+        accessToken: TEST_ACCESS_TOKEN,
+      },
+      repo,
+      hubUrl: TEST_HUB_URL,
+    });
+
+    try {
+      const pr = await commit({
+        repo,
+        credentials: {
+          accessToken: TEST_ACCESS_TOKEN,
+        },
+        hubUrl: TEST_HUB_URL,
+        title: "Create PR",
+        isPullRequest: true,
+        operations: [
+          {
+            operation: "addOrUpdate",
+            content: new Blob(["This is me"]),
+            path: "test.txt",
+          },
+        ],
+      });
+
+      if (!pr) {
+        throw new Error("PR creation failed");
+      }
+
+      if (!pr.pullRequestUrl) {
+        throw new Error("No pull request url");
+      }
+
+      const prNumber = pr.pullRequestUrl.split("/").pop();
+      const prRef = `refs/pr/${prNumber}`;
+
+      await commit({
+        repo,
+        credentials: {
+          accessToken: TEST_ACCESS_TOKEN,
+        },
+        hubUrl: TEST_HUB_URL,
+        branch: prRef,
+        title: "Some commit",
+        operations: [
+          {
+            operation: "addOrUpdate",
+            content: new URL(
+              `https://huggingface.co/spaces/huggingfacejs/push-model-from-web/resolve/main/mobilenet/group1-shard1of2`
+            ),
+            path: "mobilenet/group1-shard1of2",
+          },
+        ],
+      });
+
+      assert(commit, "PR commit failed");
+    } finally {
+      await deleteRepo({
+        repo: {
+          name: repoName,
+          type: "model",
+        },
+        hubUrl: TEST_HUB_URL,
+        credentials: { accessToken: TEST_ACCESS_TOKEN },
+      });
+    }
+  }, 60_000);
+});
lib/commit.ts
ADDED
@@ -0,0 +1,609 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import { HUB_URL } from "../consts";
|
2 |
+
import { HubApiError, createApiError, InvalidApiResponseFormatError } from "../error";
|
3 |
+
import type {
|
4 |
+
ApiCommitHeader,
|
5 |
+
ApiCommitLfsFile,
|
6 |
+
ApiCommitOperation,
|
7 |
+
ApiLfsBatchRequest,
|
8 |
+
ApiLfsBatchResponse,
|
9 |
+
ApiLfsCompleteMultipartRequest,
|
10 |
+
ApiPreuploadRequest,
|
11 |
+
ApiPreuploadResponse,
|
12 |
+
} from "../types/api/api-commit";
|
13 |
+
import type { CredentialsParams, RepoDesignation } from "../types/public";
|
14 |
+
import { checkCredentials } from "../utils/checkCredentials";
|
15 |
+
import { chunk } from "../utils/chunk";
|
16 |
+
import { promisesQueue } from "../utils/promisesQueue";
|
17 |
+
import { promisesQueueStreaming } from "../utils/promisesQueueStreaming";
|
18 |
+
import { sha256 } from "../utils/sha256";
|
19 |
+
import { toRepoId } from "../utils/toRepoId";
|
20 |
+
import { WebBlob } from "../utils/WebBlob";
|
21 |
+
import { eventToGenerator } from "../utils/eventToGenerator";
|
22 |
+
import { base64FromBytes } from "../utils/base64FromBytes";
|
23 |
+
import { isFrontend } from "../utils/isFrontend";
|
24 |
+
import { createBlobs } from "../utils/createBlobs";
|
25 |
+
|
26 |
+
const CONCURRENT_SHAS = 5;
|
27 |
+
const CONCURRENT_LFS_UPLOADS = 5;
|
28 |
+
const MULTIPART_PARALLEL_UPLOAD = 5;
|
29 |
+
|
30 |
+
export interface CommitDeletedEntry {
|
31 |
+
operation: "delete";
|
32 |
+
path: string;
|
33 |
+
}
|
34 |
+
|
35 |
+
export type ContentSource = Blob | URL;
|
36 |
+
|
37 |
+
export interface CommitFile {
|
38 |
+
operation: "addOrUpdate";
|
39 |
+
path: string;
|
40 |
+
content: ContentSource;
|
41 |
+
// forceLfs?: boolean
|
42 |
+
}
|
43 |
+
|
44 |
+
type CommitBlob = Omit<CommitFile, "content"> & { content: Blob };
|
45 |
+
|
46 |
+
// TODO: find a nice way to handle LFS & non-LFS files in an uniform manner, see https://github.com/huggingface/moon-landing/issues/4370
|
47 |
+
// export type CommitRenameFile = {
|
48 |
+
// operation: "rename";
|
49 |
+
// path: string;
|
50 |
+
// oldPath: string;
|
51 |
+
// content?: ContentSource;
|
52 |
+
// };
|
53 |
+
|
54 |
+
export type CommitOperation = CommitDeletedEntry | CommitFile /* | CommitRenameFile */;
|
55 |
+
type CommitBlobOperation = Exclude<CommitOperation, CommitFile> | CommitBlob;
|
56 |
+
|
57 |
+
export type CommitParams = {
|
58 |
+
title: string;
|
59 |
+
description?: string;
|
60 |
+
repo: RepoDesignation;
|
61 |
+
operations: CommitOperation[];
|
62 |
+
/** @default "main" */
|
63 |
+
branch?: string;
|
64 |
+
/**
|
65 |
+
* Parent commit. Optional
|
66 |
+
*
|
67 |
+
* - When opening a PR: will use parentCommit as the parent commit
|
68 |
+
* - When committing on a branch: Will make sure that there were no intermediate commits
|
69 |
+
*/
|
70 |
+
parentCommit?: string;
|
71 |
+
isPullRequest?: boolean;
|
72 |
+
hubUrl?: string;
|
73 |
+
/**
|
74 |
+
* Whether to use web workers to compute SHA256 hashes.
|
75 |
+
*
|
76 |
+
* @default false
|
77 |
+
*/
|
78 |
+
useWebWorkers?: boolean | { minSize?: number; poolSize?: number };
|
79 |
+
/**
|
80 |
+
* Maximum depth of folders to upload. Files deeper than this will be ignored
|
81 |
+
*
|
82 |
+
* @default 5
|
83 |
+
*/
|
84 |
+
maxFolderDepth?: number;
|
85 |
+
/**
|
86 |
+
* Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
|
87 |
+
*/
|
88 |
+
fetch?: typeof fetch;
|
89 |
+
abortSignal?: AbortSignal;
|
90 |
+
// Credentials are optional due to custom fetch functions or cookie auth
|
91 |
+
} & Partial<CredentialsParams>;
|
92 |
+
|
93 |
+
export interface CommitOutput {
|
94 |
+
pullRequestUrl?: string;
|
95 |
+
commit: {
|
96 |
+
oid: string;
|
97 |
+
url: string;
|
98 |
+
};
|
99 |
+
hookOutput: string;
|
100 |
+
}
|
101 |
+
|
102 |
+
function isFileOperation(op: CommitOperation): op is CommitBlob {
|
103 |
+
const ret = op.operation === "addOrUpdate";
|
104 |
+
|
105 |
+
if (ret && !(op.content instanceof Blob)) {
|
106 |
+
throw new TypeError("Precondition failed: op.content should be a Blob");
|
107 |
+
}
|
108 |
+
|
109 |
+
return ret;
|
110 |
+
}
|
111 |
+
|
112 |
+
export type CommitProgressEvent =
|
113 |
+
| {
|
114 |
+
event: "phase";
|
115 |
+
phase: "preuploading" | "uploadingLargeFiles" | "committing";
|
116 |
+
}
|
117 |
+
| {
|
118 |
+
event: "fileProgress";
|
119 |
+
path: string;
|
120 |
+
progress: number;
|
121 |
+
state: "hashing" | "uploading";
|
122 |
+
};
|
123 |
+
|
124 |
+
/**
|
125 |
+
* Internal function for now, used by commit.
|
126 |
+
*
|
127 |
+
* Can be exposed later to offer fine-tuned progress info
|
128 |
+
*/
|
129 |
+
export async function* commitIter(params: CommitParams): AsyncGenerator<CommitProgressEvent, CommitOutput> {
|
130 |
+
const accessToken = checkCredentials(params);
|
131 |
+
const repoId = toRepoId(params.repo);
|
132 |
+
yield { event: "phase", phase: "preuploading" };
|
133 |
+
|
134 |
+
const lfsShas = new Map<string, string | null>();
|
135 |
+
|
136 |
+
const abortController = new AbortController();
|
137 |
+
const abortSignal = abortController.signal;
|
138 |
+
|
139 |
+
// Polyfill see https://discuss.huggingface.co/t/why-cant-i-upload-a-parquet-file-to-my-dataset-error-o-throwifaborted-is-not-a-function/62245
|
140 |
+
if (!abortSignal.throwIfAborted) {
|
141 |
+
abortSignal.throwIfAborted = () => {
|
142 |
+
if (abortSignal.aborted) {
|
143 |
+
throw new DOMException("Aborted", "AbortError");
|
144 |
+
}
|
145 |
+
};
|
146 |
+
}
|
147 |
+
|
148 |
+
if (params.abortSignal) {
|
149 |
+
params.abortSignal.addEventListener("abort", () => abortController.abort());
|
150 |
+
}
|
151 |
+
|
152 |
+
try {
|
153 |
+
const allOperations = (
|
154 |
+
await Promise.all(
|
155 |
+
params.operations.map(async (operation) => {
|
156 |
+
if (operation.operation !== "addOrUpdate") {
|
157 |
+
return operation;
|
158 |
+
}
|
159 |
+
|
160 |
+
if (!(operation.content instanceof URL)) {
|
161 |
+
/** TS trick to enforce `content` to be a `Blob` */
|
162 |
+
return { ...operation, content: operation.content };
|
163 |
+
}
|
164 |
+
|
165 |
+
const lazyBlobs = await createBlobs(operation.content, operation.path, {
|
166 |
+
fetch: params.fetch,
|
167 |
+
maxFolderDepth: params.maxFolderDepth,
|
168 |
+
});
|
169 |
+
|
170 |
+
abortSignal?.throwIfAborted();
|
171 |
+
|
172 |
+
return lazyBlobs.map((blob) => ({
|
173 |
+
...operation,
|
174 |
+
content: blob.blob,
|
175 |
+
path: blob.path,
|
176 |
+
}));
|
177 |
+
})
|
178 |
+
)
|
179 |
+
).flat(1);
|
180 |
+
|
181 |
+
const gitAttributes = allOperations.filter(isFileOperation).find((op) => op.path === ".gitattributes")?.content;
|
182 |
+
|
183 |
+
for (const operations of chunk(allOperations.filter(isFileOperation), 100)) {
|
184 |
+
const payload: ApiPreuploadRequest = {
|
185 |
+
gitAttributes: gitAttributes && (await gitAttributes.text()),
|
186 |
+
files: await Promise.all(
|
187 |
+
operations.map(async (operation) => ({
|
188 |
+
path: operation.path,
|
189 |
+
size: operation.content.size,
|
190 |
+
sample: base64FromBytes(new Uint8Array(await operation.content.slice(0, 512).arrayBuffer())),
|
191 |
+
}))
|
192 |
+
),
|
193 |
+
};
|
194 |
+
|
195 |
+
abortSignal?.throwIfAborted();
|
196 |
+
|
197 |
+
const res = await (params.fetch ?? fetch)(
|
198 |
+
`${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/preupload/${encodeURIComponent(
|
199 |
+
params.branch ?? "main"
|
200 |
+
)}` + (params.isPullRequest ? "?create_pr=1" : ""),
|
201 |
+
{
|
202 |
+
method: "POST",
|
203 |
+
headers: {
|
204 |
+
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
|
205 |
+
"Content-Type": "application/json",
|
206 |
+
},
|
207 |
+
body: JSON.stringify(payload),
|
208 |
+
signal: abortSignal,
|
209 |
+
}
|
210 |
+
);
|
211 |
+
|
212 |
+
if (!res.ok) {
|
213 |
+
throw await createApiError(res);
|
214 |
+
}
|
215 |
+
|
216 |
+
const json: ApiPreuploadResponse = await res.json();
|
217 |
+
|
218 |
+
for (const file of json.files) {
|
219 |
+
if (file.uploadMode === "lfs") {
|
220 |
+
lfsShas.set(file.path, null);
|
221 |
+
}
|
222 |
+
}
|
223 |
+
}
|
224 |
+
|
225 |
+
yield { event: "phase", phase: "uploadingLargeFiles" };
|
226 |
+
|
227 |
+
for (const operations of chunk(
|
228 |
+
allOperations.filter(isFileOperation).filter((op) => lfsShas.has(op.path)),
|
229 |
+
100
|
230 |
+
)) {
|
231 |
+
const shas = yield* eventToGenerator<
|
232 |
+
{ event: "fileProgress"; state: "hashing"; path: string; progress: number },
|
233 |
+
string[]
|
234 |
+
>((yieldCallback, returnCallback, rejectCallack) => {
|
235 |
+
return promisesQueue(
|
236 |
+
operations.map((op) => async () => {
|
237 |
+
const iterator = sha256(op.content, { useWebWorker: params.useWebWorkers, abortSignal: abortSignal });
|
238 |
+
let res: IteratorResult<number, string>;
|
239 |
+
do {
|
240 |
+
res = await iterator.next();
|
241 |
+
if (!res.done) {
|
242 |
+
yieldCallback({ event: "fileProgress", path: op.path, progress: res.value, state: "hashing" });
|
243 |
+
}
|
244 |
+
} while (!res.done);
|
245 |
+
const sha = res.value;
|
246 |
+
lfsShas.set(op.path, res.value);
|
247 |
+
return sha;
|
248 |
+
}),
|
249 |
+
CONCURRENT_SHAS
|
250 |
+
).then(returnCallback, rejectCallack);
|
251 |
+
});
|
252 |
+
|
253 |
+
abortSignal?.throwIfAborted();
|
254 |
+
|
255 |
+
const payload: ApiLfsBatchRequest = {
|
256 |
+
operation: "upload",
|
257 |
+
// multipart is a custom protocol for HF
|
258 |
+
transfers: ["basic", "multipart"],
|
259 |
+
hash_algo: "sha_256",
|
260 |
+
...(!params.isPullRequest && {
|
261 |
+
ref: {
|
262 |
+
name: params.branch ?? "main",
|
263 |
+
},
|
264 |
+
}),
|
265 |
+
objects: operations.map((op, i) => ({
|
266 |
+
oid: shas[i],
|
267 |
+
size: op.content.size,
|
268 |
+
})),
|
269 |
+
};
|
270 |
+
|
271 |
+
const res = await (params.fetch ?? fetch)(
|
272 |
+
`${params.hubUrl ?? HUB_URL}/${repoId.type === "model" ? "" : repoId.type + "s/"}${
|
273 |
+
repoId.name
|
274 |
+
}.git/info/lfs/objects/batch`,
|
275 |
+
{
|
276 |
+
method: "POST",
|
277 |
+
headers: {
|
278 |
+
...(accessToken && { Authorization: `Bearer ${accessToken}` }),
|
279 |
+
Accept: "application/vnd.git-lfs+json",
|
280 |
+
"Content-Type": "application/vnd.git-lfs+json",
|
281 |
+
},
|
282 |
+
body: JSON.stringify(payload),
|
283 |
+
signal: abortSignal,
|
284 |
+
}
|
285 |
+
);
|
286 |
+
|
287 |
+
if (!res.ok) {
|
288 |
+
throw await createApiError(res);
|
289 |
+
}
|
290 |
+
|
291 |
+
const json: ApiLfsBatchResponse = await res.json();
|
292 |
+
const batchRequestId = res.headers.get("X-Request-Id") || undefined;
|
293 |
+
|
294 |
+
const shaToOperation = new Map(operations.map((op, i) => [shas[i], op]));
|
295 |
+
|
296 |
+
yield* eventToGenerator<CommitProgressEvent, void>((yieldCallback, returnCallback, rejectCallback) => {
|
297 |
+
return promisesQueueStreaming(
|
298 |
+
json.objects.map((obj) => async () => {
|
299 |
+
const op = shaToOperation.get(obj.oid);
|
300 |
+
|
301 |
+
if (!op) {
|
302 |
+
throw new InvalidApiResponseFormatError("Unrequested object ID in response");
|
303 |
+
}
|
304 |
+
|
305 |
+
abortSignal?.throwIfAborted();
|
306 |
+
|
307 |
+
if (obj.error) {
|
308 |
+
const errorMessage = `Error while doing LFS batch call for ${operations[shas.indexOf(obj.oid)].path}: ${
|
309 |
+
obj.error.message
|
310 |
+
}${batchRequestId ? ` - Request ID: ${batchRequestId}` : ""}`;
|
311 |
+
throw new HubApiError(res.url, obj.error.code, batchRequestId, errorMessage);
|
312 |
+
}
|
313 |
+
if (!obj.actions?.upload) {
|
314 |
+
// Already uploaded
|
315 |
+
yieldCallback({
|
316 |
+
event: "fileProgress",
|
317 |
+
path: op.path,
|
318 |
+
progress: 1,
|
319 |
+
state: "uploading",
|
320 |
+
});
|
321 |
+
return;
|
322 |
+
}
|
323 |
+
yieldCallback({
|
324 |
+
event: "fileProgress",
|
325 |
+
path: op.path,
|
326 |
+
progress: 0,
|
327 |
+
state: "uploading",
|
328 |
+
});
|
329 |
+
const content = op.content;
|
330 |
+
const header = obj.actions.upload.header;
|
331 |
+
if (header?.chunk_size) {
|
332 |
+
const chunkSize = parseInt(header.chunk_size);
|
333 |
+
|
334 |
+
// multipart upload
|
335 |
+
// parts are in upload.header['00001'] to upload.header['99999']
|
336 |
+
|
337 |
+
const completionUrl = obj.actions.upload.href;
|
338 |
+
const parts = Object.keys(header).filter((key) => /^[0-9]+$/.test(key));
|
339 |
+
|
340 |
+
if (parts.length !== Math.ceil(content.size / chunkSize)) {
|
341 |
+
throw new Error("Invalid server response to upload large LFS file, wrong number of parts");
|
342 |
+
}
|
343 |
+
|
344 |
+
const completeReq: ApiLfsCompleteMultipartRequest = {
|
345 |
+
oid: obj.oid,
|
346 |
+
parts: parts.map((part) => ({
|
347 |
+
partNumber: +part,
|
348 |
+
etag: "",
|
349 |
+
})),
|
350 |
+
};
|
351 |
+
|
352 |
+
// Defined here so that it's not redefined at each iteration (and the caller can tell it's for the same file)
|
353 |
+
const progressCallback = (progress: number) =>
|
354 |
+
yieldCallback({ event: "fileProgress", path: op.path, progress, state: "uploading" });
|
355 |
+
|
356 |
+
await promisesQueueStreaming(
|
357 |
+
parts.map((part) => async () => {
|
358 |
+
abortSignal?.throwIfAborted();
|
359 |
+
|
360 |
+
const index = parseInt(part) - 1;
|
361 |
+
const slice = content.slice(index * chunkSize, (index + 1) * chunkSize);
|
362 |
+
|
363 |
+
const res = await (params.fetch ?? fetch)(header[part], {
|
364 |
+
method: "PUT",
|
365 |
+
/** Unfortunately, browsers don't support our inherited version of Blob in fetch calls */
|
366 |
+
body: slice instanceof WebBlob && isFrontend ? await slice.arrayBuffer() : slice,
|
367 |
+
signal: abortSignal,
|
368 |
+
...({
|
369 |
+
progressHint: {
|
370 |
+
path: op.path,
|
371 |
+
part: index,
|
372 |
+
numParts: parts.length,
|
373 |
+
progressCallback,
|
374 |
+
},
|
375 |
+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
376 |
+
} as any),
|
377 |
+
});
|
378 |
+
|
379 |
+
if (!res.ok) {
|
380 |
+
throw await createApiError(res, {
|
381 |
+
requestId: batchRequestId,
|
382 |
+
message: `Error while uploading part ${part} of ${
|
383 |
+
operations[shas.indexOf(obj.oid)].path
|
384 |
+
} to LFS storage`,
|
385 |
+
});
|
386 |
+
}
|
387 |
+
|
388 |
+
const eTag = res.headers.get("ETag");
|
389 |
+
|
390 |
+
if (!eTag) {
|
391 |
+
throw new Error("Cannot get ETag of part during multipart upload");
|
392 |
+
}
|
393 |
+
|
394 |
+
completeReq.parts[Number(part) - 1].etag = eTag;
|
395 |
+
}),
|
396 |
+
MULTIPART_PARALLEL_UPLOAD
|
397 |
+
);
|
398 |
+
|
399 |
+
abortSignal?.throwIfAborted();
|
400 |
+
|
401 |
+
const res = await (params.fetch ?? fetch)(completionUrl, {
|
402 |
+
method: "POST",
|
403 |
+
body: JSON.stringify(completeReq),
|
404 |
+
headers: {
|
405 |
+
Accept: "application/vnd.git-lfs+json",
|
406 |
+
"Content-Type": "application/vnd.git-lfs+json",
|
407 |
+
},
|
408 |
+
signal: abortSignal,
|
409 |
+
});
|
410 |
+
|
411 |
+
if (!res.ok) {
|
412 |
+
throw await createApiError(res, {
|
413 |
+
requestId: batchRequestId,
|
414 |
+
message: `Error completing multipart upload of ${
|
415 |
+
operations[shas.indexOf(obj.oid)].path
|
416 |
+
} to LFS storage`,
|
417 |
+
});
|
418 |
+
}
|
419 |
+
|
420 |
+
yieldCallback({
|
421 |
+
event: "fileProgress",
|
422 |
+
path: op.path,
|
423 |
+
progress: 1,
|
424 |
+
state: "uploading",
|
425 |
+
});
|
426 |
+
} else {
|
427 |
+
const res = await (params.fetch ?? fetch)(obj.actions.upload.href, {
|
428 |
+
method: "PUT",
|
429 |
+
headers: {
|
430 |
+
...(batchRequestId ? { "X-Request-Id": batchRequestId } : undefined),
|
431 |
+
},
|
432 |
+
/** Unfortunately, browsers don't support our inherited version of Blob in fetch calls */
|
433 |
+
body: content instanceof WebBlob && isFrontend ? await content.arrayBuffer() : content,
|
434 |
+
signal: abortSignal,
|
435 |
+
...({
|
436 |
+
progressHint: {
|
437 |
+
path: op.path,
|
438 |
+
progressCallback: (progress: number) =>
|
439 |
+
yieldCallback({
|
440 |
+
event: "fileProgress",
|
441 |
+
path: op.path,
|
442 |
+
progress,
|
443 |
+
state: "uploading",
|
444 |
+
}),
|
445 |
+
},
|
446 |
+
// eslint-disable-next-line @typescript-eslint/no-explicit-any
|
447 |
+
} as any),
|
448 |
+
});
|
449 |
+
|
450 |
+
if (!res.ok) {
|
              throw await createApiError(res, {
                requestId: batchRequestId,
                message: `Error while uploading ${operations[shas.indexOf(obj.oid)].path} to LFS storage`,
              });
            }

            yieldCallback({
              event: "fileProgress",
              path: op.path,
              progress: 1,
              state: "uploading",
            });
          }
        }),
        CONCURRENT_LFS_UPLOADS
      ).then(returnCallback, rejectCallback);
    });
  }

  abortSignal?.throwIfAborted();

  yield { event: "phase", phase: "committing" };

  return yield* eventToGenerator<CommitProgressEvent, CommitOutput>(
    async (yieldCallback, returnCallback, rejectCallback) =>
      (params.fetch ?? fetch)(
        `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commit/${encodeURIComponent(
          params.branch ?? "main"
        )}` + (params.isPullRequest ? "?create_pr=1" : ""),
        {
          method: "POST",
          headers: {
            ...(accessToken && { Authorization: `Bearer ${accessToken}` }),
            "Content-Type": "application/x-ndjson",
          },
          body: [
            {
              key: "header",
              value: {
                summary: params.title,
                description: params.description,
                parentCommit: params.parentCommit,
              } satisfies ApiCommitHeader,
            },
            ...((await Promise.all(
              allOperations.map((operation) => {
                if (isFileOperation(operation)) {
                  const sha = lfsShas.get(operation.path);
                  if (sha) {
                    return {
                      key: "lfsFile",
                      value: {
                        path: operation.path,
                        algo: "sha256",
                        size: operation.content.size,
                        oid: sha,
                      } satisfies ApiCommitLfsFile,
                    };
                  }
                }

                return convertOperationToNdJson(operation);
              })
            )) satisfies ApiCommitOperation[]),
          ]
            .map((x) => JSON.stringify(x))
            .join("\n"),
          signal: abortSignal,
          ...({
            progressHint: {
              progressCallback: (progress: number) => {
                // For now, we display equal progress for all files
                // We could compute the progress based on the size of `convertOperationToNdJson` for each of the files instead
                for (const op of allOperations) {
                  if (isFileOperation(op) && !lfsShas.has(op.path)) {
                    yieldCallback({
                      event: "fileProgress",
                      path: op.path,
                      progress,
                      state: "uploading",
                    });
                  }
                }
              },
            },
            // eslint-disable-next-line @typescript-eslint/no-explicit-any
          } as any),
        }
      )
        .then(async (res) => {
          if (!res.ok) {
            throw await createApiError(res);
          }

          const json = await res.json();

          returnCallback({
            pullRequestUrl: json.pullRequestUrl,
            commit: {
              oid: json.commitOid,
              url: json.commitUrl,
            },
            hookOutput: json.hookOutput,
          });
        })
        .catch(rejectCallback)
    );
  } catch (err) {
    // For parallel requests, cancel them all if one fails
    abortController.abort();
    throw err;
  }
}

export async function commit(params: CommitParams): Promise<CommitOutput> {
  const iterator = commitIter(params);
  let res = await iterator.next();
  while (!res.done) {
    res = await iterator.next();
  }
  return res.value;
}

async function convertOperationToNdJson(operation: CommitBlobOperation): Promise<ApiCommitOperation> {
  switch (operation.operation) {
    case "addOrUpdate": {
      // todo: handle LFS
      return {
        key: "file",
        value: {
          content: base64FromBytes(new Uint8Array(await operation.content.arrayBuffer())),
          path: operation.path,
          encoding: "base64",
        },
      };
    }
    // case "rename": {
    //   // todo: detect when remote file is already LFS, and in that case rename as LFS
    //   return {
    //     key: "file",
    //     value: {
    //       content: operation.content,
    //       path: operation.path,
    //       oldPath: operation.oldPath
    //     }
    //   };
    // }
    case "delete": {
      return {
        key: "deletedFile",
        value: {
          path: operation.path,
        },
      };
    }
    default:
      throw new TypeError("Unknown operation: " + (operation as { operation: string }).operation);
  }
}
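Usage note (not part of the diff): a minimal sketch of calling `commit` with the two operation shapes handled by `convertOperationToNdJson` above ("addOrUpdate" with a Blob content, "delete" with a path). The repo name and access token are placeholders, and the snippet assumes it lives alongside the lib files so the relative import resolves.

import { commit } from "./commit";

// Add one file and delete another in a single commit (names and token are illustrative).
const output = await commit({
  repo: { type: "model", name: "my-user/my-model" },
  accessToken: "hf_...",
  title: "Update README and drop old weights",
  operations: [
    { operation: "addOrUpdate", path: "README.md", content: new Blob(["# My model"]) },
    { operation: "delete", path: "old-weights.bin" },
  ],
});
console.log(output.commit.url);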
lib/count-commits.spec.ts
ADDED
@@ -0,0 +1,16 @@
import { assert, it, describe } from "vitest";
import { countCommits } from "./count-commits";

describe("countCommits", () => {
  it("should fetch paginated commits from the repo", async () => {
    const count = await countCommits({
      repo: {
        name: "openai-community/gpt2",
        type: "model",
      },
      revision: "607a30d783dfa663caf39e06633721c8d4cfcd7e",
    });

    assert.equal(count, 26);
  });
});
lib/count-commits.ts
ADDED
@@ -0,0 +1,35 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

export async function countCommits(
  params: {
    repo: RepoDesignation;
    /**
     * Revision to list commits from. Defaults to the default branch.
     */
    revision?: string;
    hubUrl?: string;
    fetch?: typeof fetch;
  } & Partial<CredentialsParams>
): Promise<number> {
  const accessToken = checkCredentials(params);
  const repoId = toRepoId(params.repo);

  // Could upgrade to 1000 commits per page
  const url: string | undefined = `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commits/${
    params.revision ?? "main"
  }?limit=1`;

  const res: Response = await (params.fetch ?? fetch)(url, {
    headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
  });

  if (!res.ok) {
    throw await createApiError(res);
  }

  return parseInt(res.headers.get("x-total-count") ?? "0", 10);
}
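Usage note (not part of the diff): `countCommits` only fetches one commit and reads the `x-total-count` header, so it is cheap to call. A minimal sketch, assuming the snippet sits next to the lib files; the repo is the same public model used in the spec above.

import { countCommits } from "./count-commits";

// Count commits on the default branch of a public model repo.
const count = await countCommits({
  repo: { type: "model", name: "openai-community/gpt2" },
});
console.log(`The repo has ${count} commits on its default branch.`);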
lib/create-branch.spec.ts
ADDED
@@ -0,0 +1,159 @@
import { assert, it, describe } from "vitest";
import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts";
import type { RepoId } from "../types/public";
import { insecureRandomString } from "../utils/insecureRandomString";
import { createRepo } from "./create-repo";
import { deleteRepo } from "./delete-repo";
import { createBranch } from "./create-branch";
import { uploadFile } from "./upload-file";
import { downloadFile } from "./download-file";

describe("createBranch", () => {
  it("should create a new branch from the default branch", async () => {
    const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
    const repo = { type: "model", name: repoName } satisfies RepoId;

    try {
      await createRepo({
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
        repo,
      });

      await uploadFile({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
        file: {
          path: "file.txt",
          content: new Blob(["file content"]),
        },
      });

      await createBranch({
        repo,
        branch: "new-branch",
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });

      const content = await downloadFile({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
        path: "file.txt",
        revision: "new-branch",
      });

      assert.equal(await content?.text(), "file content");
    } finally {
      await deleteRepo({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });
    }
  });

  it("should create an empty branch", async () => {
    const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
    const repo = { type: "model", name: repoName } satisfies RepoId;

    try {
      await createRepo({
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
        repo,
      });

      await uploadFile({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
        file: {
          path: "file.txt",
          content: new Blob(["file content"]),
        },
      });

      await createBranch({
        repo,
        branch: "empty-branch",
        empty: true,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });

      const content = await downloadFile({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
        path: "file.txt",
        revision: "empty-branch",
      });

      assert.equal(content, null);
    } finally {
      await deleteRepo({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });
    }
  });

  it("should overwrite an existing branch", async () => {
    const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
    const repo = { type: "model", name: repoName } satisfies RepoId;

    try {
      await createRepo({
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
        repo,
      });

      await uploadFile({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
        file: {
          path: "file.txt",
          content: new Blob(["file content"]),
        },
      });

      await createBranch({
        repo,
        branch: "overwrite-branch",
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });

      await createBranch({
        repo,
        branch: "overwrite-branch",
        overwrite: true,
        empty: true,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });

      const content = await downloadFile({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
        path: "file.txt",
        revision: "overwrite-branch",
      });

      assert.equal(content, null);
    } finally {
      await deleteRepo({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });
    }
  });
});
lib/create-branch.ts
ADDED
@@ -0,0 +1,54 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { AccessToken, RepoDesignation } from "../types/public";
import { toRepoId } from "../utils/toRepoId";

export async function createBranch(params: {
  repo: RepoDesignation;
  /**
   * Revision to create the branch from. Defaults to the default branch.
   *
   * Use empty: true to create an empty branch.
   */
  revision?: string;
  hubUrl?: string;
  accessToken?: AccessToken;
  fetch?: typeof fetch;
  /**
   * The name of the branch to create
   */
  branch: string;
  /**
   * Use this to create an empty branch, with no commits.
   */
  empty?: boolean;
  /**
   * Use this to overwrite the branch if it already exists.
   *
   * If you only specify `overwrite` and no `revision`/`empty`, and the branch already exists, it will be a no-op.
   */
  overwrite?: boolean;
}): Promise<void> {
  const repoId = toRepoId(params.repo);
  const res = await (params.fetch ?? fetch)(
    `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/branch/${encodeURIComponent(params.branch)}`,
    {
      method: "POST",
      headers: {
        "Content-Type": "application/json",
        ...(params.accessToken && {
          Authorization: `Bearer ${params.accessToken}`,
        }),
      },
      body: JSON.stringify({
        startingPoint: params.revision,
        ...(params.empty && { emptyBranch: true }),
        overwrite: params.overwrite,
      }),
    }
  );

  if (!res.ok) {
    throw await createApiError(res);
  }
}
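Usage note (not part of the diff): a minimal sketch of `createBranch`, following the parameter list above; the repo name and token are placeholders and the relative import assumes the snippet sits next to the lib files.

import { createBranch } from "./create-branch";

// Create a "release" branch from the default branch, overwriting it if it already exists.
await createBranch({
  repo: { type: "model", name: "my-user/my-model" }, // placeholder repo
  branch: "release",
  overwrite: true,
  accessToken: "hf_...", // placeholder token
});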
lib/create-repo.spec.ts
ADDED
@@ -0,0 +1,103 @@
import { assert, it, describe, expect } from "vitest";

import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts";
import { insecureRandomString } from "../utils/insecureRandomString";
import { createRepo } from "./create-repo";
import { deleteRepo } from "./delete-repo";
import { downloadFile } from "./download-file";

describe("createRepo", () => {
  it("should create a repo", async () => {
    const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;

    const result = await createRepo({
      accessToken: TEST_ACCESS_TOKEN,
      repo: {
        name: repoName,
        type: "model",
      },
      hubUrl: TEST_HUB_URL,
      files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
    });

    assert.deepStrictEqual(result, {
      repoUrl: `${TEST_HUB_URL}/${repoName}`,
    });

    const content = await downloadFile({
      repo: {
        name: repoName,
        type: "model",
      },
      path: ".gitattributes",
      hubUrl: TEST_HUB_URL,
    });

    assert(content);
    assert.strictEqual(await content.text(), "*.html filter=lfs diff=lfs merge=lfs -text");

    await deleteRepo({
      repo: {
        name: repoName,
        type: "model",
      },
      credentials: { accessToken: TEST_ACCESS_TOKEN },
      hubUrl: TEST_HUB_URL,
    });
  });

  it("should throw a client error when trying to create a repo without a fully-qualified name", async () => {
    const tryCreate = createRepo({
      repo: { name: "canonical", type: "model" },
      credentials: { accessToken: TEST_ACCESS_TOKEN },
      hubUrl: TEST_HUB_URL,
    });

    await expect(tryCreate).rejects.toBeInstanceOf(TypeError);
  });

  it("should create a model with a string as name", async () => {
    const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;

    const result = await createRepo({
      accessToken: TEST_ACCESS_TOKEN,
      hubUrl: TEST_HUB_URL,
      repo: repoName,
      files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
    });

    assert.deepStrictEqual(result, {
      repoUrl: `${TEST_HUB_URL}/${repoName}`,
    });

    await deleteRepo({
      repo: {
        name: repoName,
        type: "model",
      },
      hubUrl: TEST_HUB_URL,
      credentials: { accessToken: TEST_ACCESS_TOKEN },
    });
  });

  it("should create a dataset with a string as name", async () => {
    const repoName = `datasets/${TEST_USER}/TEST-${insecureRandomString()}`;

    const result = await createRepo({
      accessToken: TEST_ACCESS_TOKEN,
      hubUrl: TEST_HUB_URL,
      repo: repoName,
      files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
    });

    assert.deepStrictEqual(result, {
      repoUrl: `${TEST_HUB_URL}/${repoName}`,
    });

    await deleteRepo({
      repo: repoName,
      hubUrl: TEST_HUB_URL,
      credentials: { accessToken: TEST_ACCESS_TOKEN },
    });
  });
});
lib/create-repo.ts
ADDED
@@ -0,0 +1,78 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiCreateRepoPayload } from "../types/api/api-create-repo";
import type { CredentialsParams, RepoDesignation, SpaceSdk } from "../types/public";
import { base64FromBytes } from "../utils/base64FromBytes";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

export async function createRepo(
  params: {
    repo: RepoDesignation;
    /**
     * If unset, will follow the organization's default setting. (typically public, except for some Enterprise organizations)
     */
    private?: boolean;
    license?: string;
    /**
     * Only a few lightweight files are supported at repo creation
     */
    files?: Array<{ content: ArrayBuffer | Blob; path: string }>;
    /** @required for when {@link repo.type} === "space" */
    sdk?: SpaceSdk;
    hubUrl?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & CredentialsParams
): Promise<{ repoUrl: string }> {
  const accessToken = checkCredentials(params);
  const repoId = toRepoId(params.repo);
  const [namespace, repoName] = repoId.name.split("/");

  if (!namespace || !repoName) {
    throw new TypeError(
      `"${repoId.name}" is not a fully qualified repo name. It should be of the form "{namespace}/{repoName}".`
    );
  }

  const res = await (params.fetch ?? fetch)(`${params.hubUrl ?? HUB_URL}/api/repos/create`, {
    method: "POST",
    body: JSON.stringify({
      name: repoName,
      private: params.private,
      organization: namespace,
      license: params.license,
      ...(repoId.type === "space"
        ? {
            type: "space",
            sdk: "static",
          }
        : {
            type: repoId.type,
          }),
      files: params.files
        ? await Promise.all(
            params.files.map(async (file) => ({
              encoding: "base64",
              path: file.path,
              content: base64FromBytes(
                new Uint8Array(file.content instanceof Blob ? await file.content.arrayBuffer() : file.content)
              ),
            }))
          )
        : undefined,
    } satisfies ApiCreateRepoPayload),
    headers: {
      Authorization: `Bearer ${accessToken}`,
      "Content-Type": "application/json",
    },
  });

  if (!res.ok) {
    throw await createApiError(res);
  }
  const output = await res.json();
  return { repoUrl: output.url };
}
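Usage note (not part of the diff): a minimal sketch of `createRepo`, mirroring the spec above but against the public Hub; the namespace, repo name, and token are placeholders.

import { createRepo } from "./create-repo";

// Create a private model repo seeded with a .gitattributes file (names and token are illustrative).
const { repoUrl } = await createRepo({
  repo: { type: "model", name: "my-user/my-new-model" },
  private: true,
  accessToken: "hf_...",
  files: [{ path: ".gitattributes", content: new Blob(["*.safetensors filter=lfs diff=lfs merge=lfs -text"]) }],
});
console.log(`Created ${repoUrl}`);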
lib/dataset-info.spec.ts
ADDED
@@ -0,0 +1,56 @@
import { describe, expect, it } from "vitest";
import { datasetInfo } from "./dataset-info";
import type { DatasetEntry } from "./list-datasets";
import type { ApiDatasetInfo } from "../types/api/api-dataset";

describe("datasetInfo", () => {
  it("should return the dataset info", async () => {
    const info = await datasetInfo({
      name: "nyu-mll/glue",
    });
    expect(info).toEqual({
      id: "621ffdd236468d709f181e3f",
      downloads: expect.any(Number),
      gated: false,
      name: "nyu-mll/glue",
      updatedAt: expect.any(Date),
      likes: expect.any(Number),
      private: false,
    });
  });

  it("should return the dataset info with author", async () => {
    const info: DatasetEntry & Pick<ApiDatasetInfo, "author"> = await datasetInfo({
      name: "nyu-mll/glue",
      additionalFields: ["author"],
    });
    expect(info).toEqual({
      id: "621ffdd236468d709f181e3f",
      downloads: expect.any(Number),
      gated: false,
      name: "nyu-mll/glue",
      updatedAt: expect.any(Date),
      likes: expect.any(Number),
      private: false,
      author: "nyu-mll",
    });
  });

  it("should return the dataset info for a specific revision", async () => {
    const info: DatasetEntry & Pick<ApiDatasetInfo, "sha"> = await datasetInfo({
      name: "nyu-mll/glue",
      revision: "cb2099c76426ff97da7aa591cbd317d91fb5fcb7",
      additionalFields: ["sha"],
    });
    expect(info).toEqual({
      id: "621ffdd236468d709f181e3f",
      downloads: expect.any(Number),
      gated: false,
      name: "nyu-mll/glue",
      updatedAt: expect.any(Date),
      likes: expect.any(Number),
      private: false,
      sha: "cb2099c76426ff97da7aa591cbd317d91fb5fcb7",
    });
  });
});
lib/dataset-info.ts
ADDED
@@ -0,0 +1,61 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiDatasetInfo } from "../types/api/api-dataset";
import type { CredentialsParams } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { pick } from "../utils/pick";
import { type DATASET_EXPANDABLE_KEYS, DATASET_EXPAND_KEYS, type DatasetEntry } from "./list-datasets";

export async function datasetInfo<
  const T extends Exclude<(typeof DATASET_EXPANDABLE_KEYS)[number], (typeof DATASET_EXPAND_KEYS)[number]> = never,
>(
  params: {
    name: string;
    hubUrl?: string;
    additionalFields?: T[];
    /**
     * An optional Git revision id which can be a branch name, a tag, or a commit hash.
     */
    revision?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & Partial<CredentialsParams>
): Promise<DatasetEntry & Pick<ApiDatasetInfo, T>> {
  const accessToken = params && checkCredentials(params);

  const search = new URLSearchParams([
    ...DATASET_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
    ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
  ]).toString();

  const response = await (params.fetch || fetch)(
    `${params?.hubUrl || HUB_URL}/api/datasets/${params.name}/revision/${encodeURIComponent(
      params.revision ?? "HEAD"
    )}?${search.toString()}`,
    {
      headers: {
        ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
        Accepts: "application/json",
      },
    }
  );

  if (!response.ok) {
    throw await createApiError(response);
  }

  const data = await response.json();

  return {
    ...(params?.additionalFields && pick(data, params.additionalFields)),
    id: data._id,
    name: data.id,
    private: data.private,
    downloads: data.downloads,
    likes: data.likes,
    gated: data.gated,
    updatedAt: new Date(data.lastModified),
  } as DatasetEntry & Pick<ApiDatasetInfo, T>;
}
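Usage note (not part of the diff): a minimal sketch of `datasetInfo` requesting one extra expandable field, as in the spec above; `nyu-mll/glue` is the same public dataset used in the tests.

import { datasetInfo } from "./dataset-info";

// Fetch the base dataset entry plus the expandable "author" field.
const info = await datasetInfo({
  name: "nyu-mll/glue",
  additionalFields: ["author"],
});
console.log(info.name, info.downloads, info.author);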
lib/delete-branch.spec.ts
ADDED
@@ -0,0 +1,43 @@
import { it, describe } from "vitest";
import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts";
import type { RepoId } from "../types/public";
import { insecureRandomString } from "../utils/insecureRandomString";
import { createRepo } from "./create-repo";
import { deleteRepo } from "./delete-repo";
import { createBranch } from "./create-branch";
import { deleteBranch } from "./delete-branch";

describe("deleteBranch", () => {
  it("should delete an existing branch", async () => {
    const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
    const repo = { type: "model", name: repoName } satisfies RepoId;

    try {
      await createRepo({
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
        repo,
      });

      await createBranch({
        repo,
        branch: "branch-to-delete",
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });

      await deleteBranch({
        repo,
        branch: "branch-to-delete",
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });
    } finally {
      await deleteRepo({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });
    }
  });
});
lib/delete-branch.ts
ADDED
@@ -0,0 +1,32 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { AccessToken, RepoDesignation } from "../types/public";
import { toRepoId } from "../utils/toRepoId";

export async function deleteBranch(params: {
  repo: RepoDesignation;
  /**
   * The name of the branch to delete
   */
  branch: string;
  hubUrl?: string;
  accessToken?: AccessToken;
  fetch?: typeof fetch;
}): Promise<void> {
  const repoId = toRepoId(params.repo);
  const res = await (params.fetch ?? fetch)(
    `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/branch/${encodeURIComponent(params.branch)}`,
    {
      method: "DELETE",
      headers: {
        ...(params.accessToken && {
          Authorization: `Bearer ${params.accessToken}`,
        }),
      },
    }
  );

  if (!res.ok) {
    throw await createApiError(res);
  }
}
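Usage note (not part of the diff): the counterpart to `createBranch`; a minimal sketch with placeholder repo name and token.

import { deleteBranch } from "./delete-branch";

// Remove a branch that is no longer needed.
await deleteBranch({
  repo: { type: "model", name: "my-user/my-model" }, // placeholder repo
  branch: "stale-branch",
  accessToken: "hf_...", // placeholder token
});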
lib/delete-file.spec.ts
ADDED
@@ -0,0 +1,64 @@
import { assert, it, describe } from "vitest";

import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts";
import type { RepoId } from "../types/public";
import { insecureRandomString } from "../utils/insecureRandomString";
import { createRepo } from "./create-repo";
import { deleteRepo } from "./delete-repo";
import { deleteFile } from "./delete-file";
import { downloadFile } from "./download-file";

describe("deleteFile", () => {
  it("should delete a file", async () => {
    const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
    const repo = { type: "model", name: repoName } satisfies RepoId;

    try {
      const result = await createRepo({
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
        repo,
        files: [
          { path: "file1", content: new Blob(["file1"]) },
          { path: "file2", content: new Blob(["file2"]) },
        ],
      });

      assert.deepStrictEqual(result, {
        repoUrl: `${TEST_HUB_URL}/${repoName}`,
      });

      let content = await downloadFile({
        hubUrl: TEST_HUB_URL,
        repo,
        path: "file1",
      });

      assert.strictEqual(await content?.text(), "file1");

      await deleteFile({ path: "file1", repo, accessToken: TEST_ACCESS_TOKEN, hubUrl: TEST_HUB_URL });

      content = await downloadFile({
        repo,
        path: "file1",
        hubUrl: TEST_HUB_URL,
      });

      assert.strictEqual(content, null);

      content = await downloadFile({
        repo,
        path: "file2",
        hubUrl: TEST_HUB_URL,
      });

      assert.strictEqual(await content?.text(), "file2");
    } finally {
      await deleteRepo({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });
    }
  });
});
lib/delete-file.ts
ADDED
@@ -0,0 +1,35 @@
import type { CredentialsParams } from "../types/public";
import type { CommitOutput, CommitParams } from "./commit";
import { commit } from "./commit";

export function deleteFile(
  params: {
    repo: CommitParams["repo"];
    path: string;
    commitTitle?: CommitParams["title"];
    commitDescription?: CommitParams["description"];
    hubUrl?: CommitParams["hubUrl"];
    fetch?: CommitParams["fetch"];
    branch?: CommitParams["branch"];
    isPullRequest?: CommitParams["isPullRequest"];
    parentCommit?: CommitParams["parentCommit"];
  } & CredentialsParams
): Promise<CommitOutput> {
  return commit({
    ...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }),
    repo: params.repo,
    operations: [
      {
        operation: "delete",
        path: params.path,
      },
    ],
    title: params.commitTitle ?? `Delete ${params.path}`,
    description: params.commitDescription,
    hubUrl: params.hubUrl,
    branch: params.branch,
    isPullRequest: params.isPullRequest,
    parentCommit: params.parentCommit,
    fetch: params.fetch,
  });
}
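Usage note (not part of the diff): `deleteFile` is a thin wrapper around `commit` with a single "delete" operation, so it produces one commit per call. A minimal sketch with placeholder names:

import { deleteFile } from "./delete-file";

// Remove a single file from the repo's default branch.
await deleteFile({
  repo: { type: "model", name: "my-user/my-model" }, // placeholder repo
  path: "old-config.json",
  accessToken: "hf_...", // placeholder token
});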
lib/delete-files.spec.ts
ADDED
@@ -0,0 +1,81 @@
import { assert, it, describe } from "vitest";

import { TEST_HUB_URL, TEST_ACCESS_TOKEN, TEST_USER } from "../test/consts";
import type { RepoId } from "../types/public";
import { insecureRandomString } from "../utils/insecureRandomString";
import { createRepo } from "./create-repo";
import { deleteRepo } from "./delete-repo";
import { deleteFiles } from "./delete-files";
import { downloadFile } from "./download-file";

describe("deleteFiles", () => {
  it("should delete multiple files", async () => {
    const repoName = `${TEST_USER}/TEST-${insecureRandomString()}`;
    const repo = { type: "model", name: repoName } satisfies RepoId;

    try {
      const result = await createRepo({
        accessToken: TEST_ACCESS_TOKEN,
        repo,
        files: [
          { path: "file1", content: new Blob(["file1"]) },
          { path: "file2", content: new Blob(["file2"]) },
          { path: "file3", content: new Blob(["file3"]) },
        ],
        hubUrl: TEST_HUB_URL,
      });

      assert.deepStrictEqual(result, {
        repoUrl: `${TEST_HUB_URL}/${repoName}`,
      });

      let content = await downloadFile({
        repo,
        path: "file1",
        hubUrl: TEST_HUB_URL,
      });

      assert.strictEqual(await content?.text(), "file1");

      content = await downloadFile({
        repo,
        path: "file2",
        hubUrl: TEST_HUB_URL,
      });

      assert.strictEqual(await content?.text(), "file2");

      await deleteFiles({ paths: ["file1", "file2"], repo, accessToken: TEST_ACCESS_TOKEN, hubUrl: TEST_HUB_URL });

      content = await downloadFile({
        repo,
        path: "file1",
        hubUrl: TEST_HUB_URL,
      });

      assert.strictEqual(content, null);

      content = await downloadFile({
        repo,
        path: "file2",
        hubUrl: TEST_HUB_URL,
      });

      assert.strictEqual(content, null);

      content = await downloadFile({
        repo,
        path: "file3",
        hubUrl: TEST_HUB_URL,
      });

      assert.strictEqual(await content?.text(), "file3");
    } finally {
      await deleteRepo({
        repo,
        accessToken: TEST_ACCESS_TOKEN,
        hubUrl: TEST_HUB_URL,
      });
    }
  });
});
lib/delete-files.ts
ADDED
@@ -0,0 +1,33 @@
import type { CredentialsParams } from "../types/public";
import type { CommitOutput, CommitParams } from "./commit";
import { commit } from "./commit";

export function deleteFiles(
  params: {
    repo: CommitParams["repo"];
    paths: string[];
    commitTitle?: CommitParams["title"];
    commitDescription?: CommitParams["description"];
    hubUrl?: CommitParams["hubUrl"];
    branch?: CommitParams["branch"];
    isPullRequest?: CommitParams["isPullRequest"];
    parentCommit?: CommitParams["parentCommit"];
    fetch?: CommitParams["fetch"];
  } & CredentialsParams
): Promise<CommitOutput> {
  return commit({
    ...(params.accessToken ? { accessToken: params.accessToken } : { credentials: params.credentials }),
    repo: params.repo,
    operations: params.paths.map((path) => ({
      operation: "delete",
      path,
    })),
    title: params.commitTitle ?? `Deletes ${params.paths.length} files`,
    description: params.commitDescription,
    hubUrl: params.hubUrl,
    branch: params.branch,
    isPullRequest: params.isPullRequest,
    parentCommit: params.parentCommit,
    fetch: params.fetch,
  });
}
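Usage note (not part of the diff): unlike `deleteFile`, this variant batches all deletions into one commit; `isPullRequest` comes straight from `CommitParams`. A minimal sketch with placeholder names:

import { deleteFiles } from "./delete-files";

// Remove several files at once, opening a pull request instead of committing directly.
await deleteFiles({
  repo: { type: "model", name: "my-user/my-model" }, // placeholder repo
  paths: ["checkpoints/step-1000.bin", "checkpoints/step-2000.bin"],
  isPullRequest: true,
  accessToken: "hf_...", // placeholder token
});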
lib/delete-repo.ts
ADDED
@@ -0,0 +1,37 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

export async function deleteRepo(
  params: {
    repo: RepoDesignation;
    hubUrl?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & CredentialsParams
): Promise<void> {
  const accessToken = checkCredentials(params);
  const repoId = toRepoId(params.repo);
  const [namespace, repoName] = repoId.name.split("/");

  const res = await (params.fetch ?? fetch)(`${params.hubUrl ?? HUB_URL}/api/repos/delete`, {
    method: "DELETE",
    body: JSON.stringify({
      name: repoName,
      organization: namespace,
      type: repoId.type,
    }),
    headers: {
      Authorization: `Bearer ${accessToken}`,
      "Content-Type": "application/json",
    },
  });

  if (!res.ok) {
    throw await createApiError(res);
  }
}
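Usage note (not part of the diff): a minimal sketch of `deleteRepo`; the deletion is permanent, and the repo name and token below are placeholders.

import { deleteRepo } from "./delete-repo";

// Permanently delete a repo you own.
await deleteRepo({
  repo: { type: "model", name: "my-user/obsolete-model" }, // placeholder repo
  accessToken: "hf_...", // placeholder token
});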
lib/download-file-to-cache-dir.spec.ts
ADDED
@@ -0,0 +1,306 @@
import { expect, test, describe, vi, beforeEach } from "vitest";
import type { RepoDesignation, RepoId } from "../types/public";
import { dirname, join } from "node:path";
import { lstat, mkdir, stat, symlink, rename } from "node:fs/promises";
import { pathsInfo } from "./paths-info";
import { createWriteStream, type Stats } from "node:fs";
import { getHFHubCachePath, getRepoFolderName } from "./cache-management";
import { toRepoId } from "../utils/toRepoId";
import { downloadFileToCacheDir } from "./download-file-to-cache-dir";
import { createSymlink } from "../utils/symlink";

vi.mock("node:fs/promises", () => ({
  rename: vi.fn(),
  symlink: vi.fn(),
  lstat: vi.fn(),
  mkdir: vi.fn(),
  stat: vi.fn(),
}));

vi.mock("node:fs", () => ({
  createWriteStream: vi.fn(),
}));

vi.mock("./paths-info", () => ({
  pathsInfo: vi.fn(),
}));

vi.mock("../utils/symlink", () => ({
  createSymlink: vi.fn(),
}));

const DUMMY_REPO: RepoId = {
  name: "hello-world",
  type: "model",
};

const DUMMY_ETAG = "dummy-etag";

// utility test method to get the blob file path
function _getBlobFile(params: {
  repo: RepoDesignation;
  etag: string;
  cacheDir?: string; // default to {@link getHFHubCache}
}) {
  return join(params.cacheDir ?? getHFHubCachePath(), getRepoFolderName(toRepoId(params.repo)), "blobs", params.etag);
}

// utility test method to get the snapshot file path
function _getSnapshotFile(params: {
  repo: RepoDesignation;
  path: string;
  revision: string;
  cacheDir?: string; // default to {@link getHFHubCache}
}) {
  return join(
    params.cacheDir ?? getHFHubCachePath(),
    getRepoFolderName(toRepoId(params.repo)),
    "snapshots",
    params.revision,
    params.path
  );
}

describe("downloadFileToCacheDir", () => {
  const fetchMock: typeof fetch = vi.fn();
  beforeEach(() => {
    vi.resetAllMocks();
    // mock a 200 response
    vi.mocked(fetchMock).mockResolvedValue(
      new Response("dummy-body", {
        status: 200,
        headers: {
          etag: DUMMY_ETAG,
          "Content-Range": "bytes 0-54/55",
        },
      })
    );

    // prevent cache hits by default
    vi.mocked(stat).mockRejectedValue(new Error("Does not exist"));
    vi.mocked(lstat).mockRejectedValue(new Error("Does not exist"));
  });

  test("should throw an error if fileDownloadInfo returns nothing", async () => {
    await expect(async () => {
      await downloadFileToCacheDir({
        repo: DUMMY_REPO,
        path: "/README.md",
        fetch: fetchMock,
      });
    }).rejects.toThrowError("cannot get path info for /README.md");

    expect(pathsInfo).toHaveBeenCalledWith(
      expect.objectContaining({
        repo: DUMMY_REPO,
        paths: ["/README.md"],
        fetch: fetchMock,
      })
    );
  });

  test("existing symlink and blob should not re-download the file", async () => {
    // <cache>/<repo>/<revision>/snapshots/README.md
    const expectPointer = _getSnapshotFile({
      repo: DUMMY_REPO,
      path: "/README.md",
      revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
    });
    // stat ensures the symlink and the pointed file exist
    vi.mocked(stat).mockResolvedValue({} as Stats); // prevent default mocked reject

    const output = await downloadFileToCacheDir({
      repo: DUMMY_REPO,
      path: "/README.md",
      fetch: fetchMock,
      revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
    });

    expect(stat).toHaveBeenCalledOnce();
    // Get call argument for stat
    const statArg = vi.mocked(stat).mock.calls[0][0];

    expect(statArg).toBe(expectPointer);
    expect(fetchMock).not.toHaveBeenCalledWith();

    expect(output).toBe(expectPointer);
  });

  test("existing symlink and blob with default revision should not re-download the file", async () => {
    // <cache>/<repo>/<revision>/snapshots/README.md
    const expectPointer = _getSnapshotFile({
      repo: DUMMY_REPO,
      path: "/README.md",
      revision: "main",
    });
    // stat ensures the symlink and the pointed file exist
    vi.mocked(stat).mockResolvedValue({} as Stats); // prevent default mocked reject
    vi.mocked(lstat).mockResolvedValue({} as Stats);
    vi.mocked(pathsInfo).mockResolvedValue([
      {
        oid: DUMMY_ETAG,
        size: 55,
        path: "README.md",
        type: "file",
        lastCommit: {
          date: new Date(),
          id: "main",
          title: "Commit msg",
        },
      },
    ]);

    const output = await downloadFileToCacheDir({
      repo: DUMMY_REPO,
      path: "/README.md",
      fetch: fetchMock,
    });

    expect(stat).toHaveBeenCalledOnce();
    expect(symlink).not.toHaveBeenCalledOnce();
    // Get call argument for stat
    const statArg = vi.mocked(stat).mock.calls[0][0];

    expect(statArg).toBe(expectPointer);
    expect(fetchMock).not.toHaveBeenCalledWith();

    expect(output).toBe(expectPointer);
  });

  test("existing blob should only create the symlink", async () => {
    // <cache>/<repo>/<revision>/snapshots/README.md
    const expectPointer = _getSnapshotFile({
      repo: DUMMY_REPO,
      path: "/README.md",
      revision: "dummy-commit-hash",
    });
    // <cache>/<repo>/blobs/<etag>
    const expectedBlob = _getBlobFile({
      repo: DUMMY_REPO,
      etag: DUMMY_ETAG,
    });

    // mock an existing blob only, no symlink
    vi.mocked(lstat).mockResolvedValue({} as Stats);
    // mock pathsInfo resolved content
    vi.mocked(pathsInfo).mockResolvedValue([
      {
        oid: DUMMY_ETAG,
        size: 55,
        path: "README.md",
        type: "file",
        lastCommit: {
          date: new Date(),
          id: "dummy-commit-hash",
          title: "Commit msg",
        },
      },
    ]);

    const output = await downloadFileToCacheDir({
      repo: DUMMY_REPO,
      path: "/README.md",
      fetch: fetchMock,
    });

    // should have checked for the blob
    expect(lstat).toHaveBeenCalled();
    expect(vi.mocked(lstat).mock.calls[0][0]).toBe(expectedBlob);

    // symlink should have been created
    expect(createSymlink).toHaveBeenCalledOnce();
    // no download done
    expect(fetchMock).not.toHaveBeenCalled();

    expect(output).toBe(expectPointer);
  });

  test("expect resolve value to be the pointer path of the downloaded file", async () => {
    // <cache>/<repo>/<revision>/snapshots/README.md
    const expectPointer = _getSnapshotFile({
      repo: DUMMY_REPO,
      path: "/README.md",
      revision: "dummy-commit-hash",
    });
    // <cache>/<repo>/blobs/<etag>
    const expectedBlob = _getBlobFile({
      repo: DUMMY_REPO,
      etag: DUMMY_ETAG,
    });

    vi.mocked(pathsInfo).mockResolvedValue([
      {
        oid: DUMMY_ETAG,
        size: 55,
        path: "README.md",
        type: "file",
        lastCommit: {
          date: new Date(),
          id: "dummy-commit-hash",
          title: "Commit msg",
        },
      },
    ]);

    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    vi.mocked(createWriteStream).mockReturnValue(async function* () {} as any);

    const output = await downloadFileToCacheDir({
      repo: DUMMY_REPO,
      path: "/README.md",
      fetch: fetchMock,
    });

    // expect the blobs and snapshots folders to have been created with mkdir
    expect(vi.mocked(mkdir).mock.calls[0][0]).toBe(dirname(expectedBlob));
    expect(vi.mocked(mkdir).mock.calls[1][0]).toBe(dirname(expectPointer));

    expect(output).toBe(expectPointer);
  });

  test("should write fetch response to blob", async () => {
    // <cache>/<repo>/<revision>/snapshots/README.md
    const expectPointer = _getSnapshotFile({
      repo: DUMMY_REPO,
      path: "/README.md",
      revision: "dummy-commit-hash",
    });
    // <cache>/<repo>/blobs/<etag>
    const expectedBlob = _getBlobFile({
      repo: DUMMY_REPO,
      etag: DUMMY_ETAG,
    });

    // mock pathsInfo resolved content
    vi.mocked(pathsInfo).mockResolvedValue([
      {
        oid: DUMMY_ETAG,
        size: 55,
        path: "README.md",
        type: "file",
        lastCommit: {
          date: new Date(),
          id: "dummy-commit-hash",
          title: "Commit msg",
        },
      },
    ]);

    // eslint-disable-next-line @typescript-eslint/no-explicit-any
    vi.mocked(createWriteStream).mockReturnValue(async function* () {} as any);

    await downloadFileToCacheDir({
      repo: DUMMY_REPO,
      path: "/README.md",
      fetch: fetchMock,
    });

    const incomplete = `${expectedBlob}.incomplete`;
    // 1. should write fetch#response#body to the incomplete file
    expect(createWriteStream).toHaveBeenCalledWith(incomplete);
    // 2. should rename the incomplete file to the expected blob name
    expect(rename).toHaveBeenCalledWith(incomplete, expectedBlob);
    // 3. should create a symlink pointing to the blob
    expect(createSymlink).toHaveBeenCalledWith({ sourcePath: expectedBlob, finalPath: expectPointer });
  });
});
lib/download-file-to-cache-dir.ts
ADDED
@@ -0,0 +1,138 @@
import { getHFHubCachePath, getRepoFolderName } from "./cache-management";
import { dirname, join } from "node:path";
import { rename, lstat, mkdir, stat } from "node:fs/promises";
import type { CommitInfo, PathInfo } from "./paths-info";
import { pathsInfo } from "./paths-info";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { toRepoId } from "../utils/toRepoId";
import { downloadFile } from "./download-file";
import { createSymlink } from "../utils/symlink";
import { Readable } from "node:stream";
import type { ReadableStream } from "node:stream/web";
import { pipeline } from "node:stream/promises";
import { createWriteStream } from "node:fs";

export const REGEX_COMMIT_HASH: RegExp = new RegExp("^[0-9a-f]{40}$");

function getFilePointer(storageFolder: string, revision: string, relativeFilename: string): string {
  const snapshotPath = join(storageFolder, "snapshots");
  return join(snapshotPath, revision, relativeFilename);
}

/**
 * Handy method to check whether a file exists, or whether the target of a symlink exists.
 * @param path
 * @param followSymlinks
 */
async function exists(path: string, followSymlinks?: boolean): Promise<boolean> {
  try {
    if (followSymlinks) {
      await stat(path);
    } else {
      await lstat(path);
    }
    return true;
  } catch (err: unknown) {
    return false;
  }
}

/**
 * Download a given file if it's not already present in the local cache.
 * @param params
 * @return the symlink to the blob object
 */
export async function downloadFileToCacheDir(
  params: {
    repo: RepoDesignation;
    path: string;
    /**
     * If true, will download the raw git file.
     *
     * For example, when calling on a file stored with Git LFS, the pointer file will be downloaded instead.
     */
    raw?: boolean;
    /**
     * An optional Git revision id which can be a branch name, a tag, or a commit hash.
     *
     * @default "main"
     */
    revision?: string;
    hubUrl?: string;
    cacheDir?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & Partial<CredentialsParams>
): Promise<string> {
  // get the provided revision or default to main
  const revision = params.revision ?? "main";
  const cacheDir = params.cacheDir ?? getHFHubCachePath();
  // get repo id
  const repoId = toRepoId(params.repo);
  // get storage folder
  const storageFolder = join(cacheDir, getRepoFolderName(repoId));

  let commitHash: string | undefined;

  // if the user provides a commitHash as revision, and they already have the file on disk, shortcut everything.
  if (REGEX_COMMIT_HASH.test(revision)) {
    commitHash = revision;
    const pointerPath = getFilePointer(storageFolder, revision, params.path);
    if (await exists(pointerPath, true)) return pointerPath;
  }

  const pathsInformation: (PathInfo & { lastCommit: CommitInfo })[] = await pathsInfo({
    ...params,
    paths: [params.path],
    revision: revision,
    expand: true,
  });
  if (!pathsInformation || pathsInformation.length !== 1) throw new Error(`cannot get path info for ${params.path}`);

  let etag: string;
  if (pathsInformation[0].lfs) {
    etag = pathsInformation[0].lfs.oid; // get the oid of the file the LFS pointer points to
  } else {
    etag = pathsInformation[0].oid; // get the repo file oid if not an LFS pointer
  }

  const pointerPath = getFilePointer(storageFolder, commitHash ?? pathsInformation[0].lastCommit.id, params.path);
  const blobPath = join(storageFolder, "blobs", etag);

  // if we have the pointer file, we can shortcut the download
  if (await exists(pointerPath, true)) return pointerPath;

  // mkdir the parent directories of the blob and pointer paths
  await mkdir(dirname(blobPath), { recursive: true });
  await mkdir(dirname(pointerPath), { recursive: true });

  // We might already have the blob but not the pointer
  // shortcut the download if needed
  if (await exists(blobPath)) {
    // create a symlink in the snapshot folder to the blob object
    await createSymlink({ sourcePath: blobPath, finalPath: pointerPath });
    return pointerPath;
  }

  const incomplete = `${blobPath}.incomplete`;
  console.debug(`Downloading ${params.path} to ${incomplete}`);

  const blob: Blob | null = await downloadFile({
    ...params,
    revision: commitHash,
  });

  if (!blob) {
    throw new Error(`invalid response for file ${params.path}`);
  }

  await pipeline(Readable.fromWeb(blob.stream() as ReadableStream), createWriteStream(incomplete));

  // rename the .incomplete file to the expected blob path
  await rename(incomplete, blobPath);
  // create a symlink in the snapshot folder to the blob object
  await createSymlink({ sourcePath: blobPath, finalPath: pointerPath });
  return pointerPath;
}
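Usage note (not part of the diff): a minimal sketch of `downloadFileToCacheDir` from Node.js; it resolves to the snapshot (symlink) path in the local cache, reusing the cached blob when it is already present. The repo and path below reuse the public model from the other specs.

import { downloadFileToCacheDir } from "./download-file-to-cache-dir";

// Fetch a file into the local HF cache (or reuse it) and get back the snapshot path.
const localPath = await downloadFileToCacheDir({
  repo: { type: "model", name: "openai-community/gpt2" },
  path: "config.json",
});
console.log(`File available at ${localPath}`);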
lib/download-file.spec.ts
ADDED
@@ -0,0 +1,82 @@
import { expect, test, describe, assert } from "vitest";
import { downloadFile } from "./download-file";
import { deleteRepo } from "./delete-repo";
import { createRepo } from "./create-repo";
import { TEST_ACCESS_TOKEN, TEST_HUB_URL, TEST_USER } from "../test/consts";
import { insecureRandomString } from "../utils/insecureRandomString";

describe("downloadFile", () => {
  test("should download regular file", async () => {
    const blob = await downloadFile({
      repo: {
        type: "model",
        name: "openai-community/gpt2",
      },
      path: "README.md",
    });

    const text = await blob?.slice(0, 1000).text();
    assert(
      text?.includes(`---
language: en
tags:
- exbert

license: mit
---


# GPT-2

Test the whole generation capabilities here: https://transformer.huggingface.co/doc/gpt2-large`)
    );
  });
  test("should download xet file", async () => {
    const blob = await downloadFile({
      repo: {
        type: "model",
        name: "celinah/xet-experiments",
      },
      path: "large_text.txt",
    });

    const text = await blob?.slice(0, 100).text();
    expect(text).toMatch("this is a text file.".repeat(10).slice(0, 100));
  });

  test("should download private file", async () => {
    const repoName = `datasets/${TEST_USER}/TEST-${insecureRandomString()}`;

    const result = await createRepo({
      accessToken: TEST_ACCESS_TOKEN,
      hubUrl: TEST_HUB_URL,
      private: true,
      repo: repoName,
      files: [{ path: ".gitattributes", content: new Blob(["*.html filter=lfs diff=lfs merge=lfs -text"]) }],
    });

    assert.deepStrictEqual(result, {
      repoUrl: `${TEST_HUB_URL}/${repoName}`,
    });

    try {
      const blob = await downloadFile({
        repo: repoName,
        path: ".gitattributes",
        hubUrl: TEST_HUB_URL,
        accessToken: TEST_ACCESS_TOKEN,
      });

      assert(blob, "File should be found");

      const text = await blob?.text();
      assert.strictEqual(text, "*.html filter=lfs diff=lfs merge=lfs -text");
    } finally {
      await deleteRepo({
        repo: repoName,
        hubUrl: TEST_HUB_URL,
        accessToken: TEST_ACCESS_TOKEN,
      });
    }
  });
});
lib/download-file.ts
ADDED
@@ -0,0 +1,77 @@
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { WebBlob } from "../utils/WebBlob";
import { XetBlob } from "../utils/XetBlob";
import type { FileDownloadInfoOutput } from "./file-download-info";
import { fileDownloadInfo } from "./file-download-info";

/**
 * @returns null when the file doesn't exist
 */
export async function downloadFile(
  params: {
    repo: RepoDesignation;
    path: string;
    /**
     * If true, will download the raw git file.
     *
     * For example, when calling on a file stored with Git LFS, the pointer file will be downloaded instead.
     */
    raw?: boolean;
    /**
     * An optional Git revision id which can be a branch name, a tag, or a commit hash.
     *
     * @default "main"
     */
    revision?: string;
    hubUrl?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
    /**
     * Whether to use the xet protocol to download the file (if applicable).
     *
     * Currently there's experimental support for it, so it's not enabled by default.
     *
     * It will be enabled automatically in a future minor version.
     *
     * @default false
     */
    xet?: boolean;
    /**
     * Can save an http request if provided
     */
    downloadInfo?: FileDownloadInfoOutput;
  } & Partial<CredentialsParams>
): Promise<Blob | null> {
  const accessToken = checkCredentials(params);

  const info =
    params.downloadInfo ??
    (await fileDownloadInfo({
      accessToken,
      repo: params.repo,
      path: params.path,
      revision: params.revision,
      hubUrl: params.hubUrl,
      fetch: params.fetch,
      raw: params.raw,
    }));

  if (!info) {
    return null;
  }

  if (info.xet && params.xet) {
    return new XetBlob({
      refreshUrl: info.xet.refreshUrl.href,
      reconstructionUrl: info.xet.reconstructionUrl.href,
      fetch: params.fetch,
      accessToken,
      size: info.size,
    });
  }

  return new WebBlob(new URL(info.url), 0, info.size, "", true, params.fetch ?? fetch, accessToken);
}

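A minimal usage sketch of downloadFile as defined above (illustrative, not part of the commit; run inside an async context):

import { downloadFile } from "./download-file";

// Returns a lazy Blob (WebBlob or XetBlob); null means the file does not exist.
const blob = await downloadFile({
  repo: { type: "model", name: "openai-community/gpt2" },
  path: "config.json",
});
if (blob) {
  console.log(await blob.text());
}
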
lib/file-download-info.spec.ts
ADDED
@@ -0,0 +1,59 @@
import { assert, it, describe } from "vitest";
import { fileDownloadInfo } from "./file-download-info";

describe("fileDownloadInfo", () => {
  it("should fetch LFS file info", async () => {
    const info = await fileDownloadInfo({
      repo: {
        name: "bert-base-uncased",
        type: "model",
      },
      path: "tf_model.h5",
      revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
    });

    assert.strictEqual(info?.size, 536063208);
    assert.strictEqual(info?.etag, '"a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2"');
  });

  it("should fetch raw LFS pointer info", async () => {
    const info = await fileDownloadInfo({
      repo: {
        name: "bert-base-uncased",
        type: "model",
      },
      path: "tf_model.h5",
      revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
      raw: true,
    });

    assert.strictEqual(info?.size, 134);
    assert.strictEqual(info?.etag, '"9eb98c817f04b051b3bcca591bcd4e03cec88018"');
  });

  it("should fetch non-LFS file info", async () => {
    const info = await fileDownloadInfo({
      repo: {
        name: "bert-base-uncased",
        type: "model",
      },
      path: "tokenizer_config.json",
      revision: "1a7dd4986e3dab699c24ca19b2afd0f5e1a80f37",
    });

    assert.strictEqual(info?.size, 28);
    assert.strictEqual(info?.etag, '"a661b1a138dac6dc5590367402d100765010ffd6"');
  });

  it("should fetch xet file info", async () => {
    const info = await fileDownloadInfo({
      repo: {
        type: "model",
        name: "celinah/xet-experiments",
      },
      path: "large_text.txt",
    });
    assert.strictEqual(info?.size, 62914580);
    assert.strictEqual(info?.etag, '"c27f98578d9363b27db0bc1cbd9c692f8e6e90ae98c38cee7bc0a88829debd17"');
  });
});

lib/file-download-info.ts
ADDED
@@ -0,0 +1,151 @@
import { HUB_URL } from "../consts";
import { createApiError, InvalidApiResponseFormatError } from "../error";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { toRepoId } from "../utils/toRepoId";

export interface XetFileInfo {
  hash: string;
  refreshUrl: URL;
  /**
   * Can be directly used instead of the hash.
   */
  reconstructionUrl: URL;
}

export interface FileDownloadInfoOutput {
  size: number;
  etag: string;
  xet?: XetFileInfo;
  // URL to fetch (with the access token if private file)
  url: string;
}
/**
 * @returns null when the file doesn't exist
 */
export async function fileDownloadInfo(
  params: {
    repo: RepoDesignation;
    path: string;
    revision?: string;
    hubUrl?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
    /**
     * To get the raw pointer file behind a LFS file
     */
    raw?: boolean;
    /**
     * To avoid the content-disposition header in the `downloadLink` for LFS files
     *
     * So that on browsers you can use the URL in an iframe for example
     */
    noContentDisposition?: boolean;
  } & Partial<CredentialsParams>
): Promise<FileDownloadInfoOutput | null> {
  const accessToken = checkCredentials(params);
  const repoId = toRepoId(params.repo);

  const hubUrl = params.hubUrl ?? HUB_URL;
  const url =
    `${hubUrl}/${repoId.type === "model" ? "" : `${repoId.type}s/`}${repoId.name}/${
      params.raw ? "raw" : "resolve"
    }/${encodeURIComponent(params.revision ?? "main")}/${params.path}` +
    (params.noContentDisposition ? "?noContentDisposition=1" : "");

  const resp = await (params.fetch ?? fetch)(url, {
    method: "GET",
    headers: {
      ...(accessToken && {
        Authorization: `Bearer ${accessToken}`,
      }),
      Range: "bytes=0-0",
      Accept: "application/vnd.xet-fileinfo+json, */*",
    },
  });

  if (resp.status === 404 && resp.headers.get("X-Error-Code") === "EntryNotFound") {
    return null;
  }

  if (!resp.ok) {
    throw await createApiError(resp);
  }

  let size: number | undefined;
  let xetInfo: XetFileInfo | undefined;

  if (resp.headers.get("Content-Type")?.includes("application/vnd.xet-fileinfo+json")) {
    size = parseInt(resp.headers.get("X-Linked-Size") ?? "invalid");
    if (isNaN(size)) {
      throw new InvalidApiResponseFormatError("Invalid file size received in X-Linked-Size header");
    }

    const hash = resp.headers.get("X-Xet-Hash");
    const links = parseLinkHeader(resp.headers.get("Link") ?? "");

    const reconstructionUrl = (() => {
      try {
        return new URL(links["xet-reconstruction-info"]);
      } catch {
        return null;
      }
    })();
    const refreshUrl = (() => {
      try {
        return new URL(links["xet-auth"]);
      } catch {
        return null;
      }
    })();

    if (!hash) {
      throw new InvalidApiResponseFormatError("No hash received in X-Xet-Hash header");
    }

    if (!reconstructionUrl || !refreshUrl) {
      throw new InvalidApiResponseFormatError("No xet-reconstruction-info or xet-auth link header");
    }
    xetInfo = {
      hash,
      refreshUrl,
      reconstructionUrl,
    };
  }

  if (size === undefined || isNaN(size)) {
    const contentRangeHeader = resp.headers.get("content-range");

    if (!contentRangeHeader) {
      throw new InvalidApiResponseFormatError("Expected size information");
    }

    const [, parsedSize] = contentRangeHeader.split("/");
    size = parseInt(parsedSize);

    if (isNaN(size)) {
      throw new InvalidApiResponseFormatError("Invalid file size received");
    }
  }

  const etag = resp.headers.get("X-Linked-ETag") ?? resp.headers.get("ETag") ?? undefined;

  if (!etag) {
    throw new InvalidApiResponseFormatError("Expected ETag");
  }

  return {
    etag,
    size,
    xet: xetInfo,
    // Cannot use resp.url in case it's a S3 url and the user adds an Authorization header to it.
    url:
      resp.url &&
      (new URL(resp.url).origin === new URL(hubUrl).origin || resp.headers.get("X-Cache")?.endsWith(" cloudfront"))
        ? resp.url
        : url,
  };
}

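Since downloadFile accepts a downloadInfo parameter, the metadata returned here can be reused to avoid a second HTTP round-trip. An illustrative sketch (not part of the commit):

import { fileDownloadInfo } from "./file-download-info";
import { downloadFile } from "./download-file";

const info = await fileDownloadInfo({
  repo: { type: "model", name: "bert-base-uncased" },
  path: "config.json",
});
if (info) {
  console.log(info.size, info.etag);
  // Reuse the already-fetched metadata instead of letting downloadFile re-fetch it.
  const blob = await downloadFile({
    repo: { type: "model", name: "bert-base-uncased" },
    path: "config.json",
    downloadInfo: info,
  });
}
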
lib/file-exists.spec.ts
ADDED
@@ -0,0 +1,30 @@
import { assert, it, describe } from "vitest";
import { fileExists } from "./file-exists";

describe("fileExists", () => {
  it("should return true for file that exists", async () => {
    const info = await fileExists({
      repo: {
        name: "bert-base-uncased",
        type: "model",
      },
      path: "tf_model.h5",
      revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
    });

    assert(info, "file should exist");
  });

  it("should return false for file that does not exist", async () => {
    const info = await fileExists({
      repo: {
        name: "bert-base-uncased",
        type: "model",
      },
      path: "tf_model.h5dadazdzazd",
      revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
    });

    assert(!info, "file should not exist");
  });
});

lib/file-exists.ts
ADDED
@@ -0,0 +1,41 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { toRepoId } from "../utils/toRepoId";

export async function fileExists(
  params: {
    repo: RepoDesignation;
    path: string;
    revision?: string;
    hubUrl?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & Partial<CredentialsParams>
): Promise<boolean> {
  const accessToken = checkCredentials(params);
  const repoId = toRepoId(params.repo);

  const hubUrl = params.hubUrl ?? HUB_URL;
  const url = `${hubUrl}/${repoId.type === "model" ? "" : `${repoId.type}s/`}${repoId.name}/raw/${encodeURIComponent(
    params.revision ?? "main"
  )}/${params.path}`;

  const resp = await (params.fetch ?? fetch)(url, {
    method: "HEAD",
    headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
  });

  if (resp.status === 404) {
    return false;
  }

  if (!resp.ok) {
    throw await createApiError(resp);
  }

  return true;
}

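A short usage sketch of fileExists (illustrative, not part of the commit):

import { fileExists } from "./file-exists";

const exists = await fileExists({
  repo: { type: "model", name: "bert-base-uncased" },
  path: "config.json",
});
console.log(exists ? "found" : "missing");
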
lib/index.ts
ADDED
@@ -0,0 +1,32 @@
export * from "./cache-management";
export * from "./check-repo-access";
export * from "./commit";
export * from "./count-commits";
export * from "./create-repo";
export * from "./create-branch";
export * from "./dataset-info";
export * from "./delete-branch";
export * from "./delete-file";
export * from "./delete-files";
export * from "./delete-repo";
export * from "./download-file";
export * from "./download-file-to-cache-dir";
export * from "./file-download-info";
export * from "./file-exists";
export * from "./list-commits";
export * from "./list-datasets";
export * from "./list-files";
export * from "./list-models";
export * from "./list-spaces";
export * from "./model-info";
export * from "./oauth-handle-redirect";
export * from "./oauth-login-url";
export * from "./parse-safetensors-metadata";
export * from "./paths-info";
export * from "./repo-exists";
export * from "./snapshot-download";
export * from "./space-info";
export * from "./upload-file";
export * from "./upload-files";
export * from "./upload-files-with-progress";
export * from "./who-am-i";

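Because everything is re-exported from this barrel file, consumers can import the helpers from the package root. A sketch, assuming the package is published under the "@huggingface/hub" name:

import { downloadFile, listModels, fileExists } from "@huggingface/hub";
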
lib/list-commits.spec.ts
ADDED
@@ -0,0 +1,117 @@
import { assert, it, describe } from "vitest";
import type { CommitData } from "./list-commits";
import { listCommits } from "./list-commits";

describe("listCommits", () => {
  it("should fetch paginated commits from the repo", async () => {
    const commits: CommitData[] = [];
    for await (const commit of listCommits({
      repo: {
        name: "openai-community/gpt2",
        type: "model",
      },
      revision: "607a30d783dfa663caf39e06633721c8d4cfcd7e",
      batchSize: 5,
    })) {
      commits.push(commit);
    }

    assert.equal(commits.length, 26);
    assert.deepEqual(commits.slice(0, 6), [
      {
        oid: "607a30d783dfa663caf39e06633721c8d4cfcd7e",
        title: "Adds the tokenizer configuration file (#80)",
        message: "\n\n\n- Adds tokenizer_config.json file (db6d57930088fb63e52c010bd9ac77c955ac55e7)\n\n",
        authors: [
          {
            username: "lysandre",
            avatarUrl:
              "https://cdn-avatars.huggingface.co/v1/production/uploads/5e3aec01f55e2b62848a5217/PMKS0NNB4MJQlTSFzh918.jpeg",
          },
        ],
        date: new Date("2024-02-19T10:57:45.000Z"),
      },
      {
        oid: "11c5a3d5811f50298f278a704980280950aedb10",
        title: "Adding ONNX file of this model (#60)",
        message: "\n\n\n- Adding ONNX file of this model (9411f419c589519e1a46c94ac7789ea20fd7c322)\n\n",
        authors: [
          {
            username: "fxmarty",
            avatarUrl:
              "https://cdn-avatars.huggingface.co/v1/production/uploads/1651743336129-624c60cba8ec93a7ac188b56.png",
          },
        ],
        date: new Date("2023-06-30T02:19:43.000Z"),
      },
      {
        oid: "e7da7f221d5bf496a48136c0cd264e630fe9fcc8",
        title: "Update generation_config.json",
        message: "",
        authors: [
          {
            username: "joaogante",
            avatarUrl: "https://cdn-avatars.huggingface.co/v1/production/uploads/1641203017724-noauth.png",
          },
        ],
        date: new Date("2022-12-16T15:44:21.000Z"),
      },
      {
        oid: "f27b190eeac4c2302d24068eabf5e9d6044389ae",
        title: "Add note that this is the smallest version of the model (#18)",
        message:
          "\n\n\n- Add note that this is the smallest version of the model (611838ef095a5bb35bf2027d05e1194b7c9d37ac)\n\n\nCo-authored-by: helen <mathemakitten@users.noreply.huggingface.co>\n",
        authors: [
          {
            username: "sgugger",
            avatarUrl:
              "https://cdn-avatars.huggingface.co/v1/production/uploads/1593126474392-5ef50182b71947201082a4e5.jpeg",
          },
          {
            username: "mathemakitten",
            avatarUrl:
              "https://cdn-avatars.huggingface.co/v1/production/uploads/1658248499901-6079afe2d2cd8c150e6ae05e.jpeg",
          },
        ],
        date: new Date("2022-11-23T12:55:26.000Z"),
      },
      {
        oid: "0dd7bcc7a64e4350d8859c9a2813132fbf6ae591",
        title: "Our very first generation_config.json (#17)",
        message:
          "\n\n\n- Our very first generation_config.json (671851b7e9d56ef062890732065d7bd5f4628bd6)\n\n\nCo-authored-by: Joao Gante <joaogante@users.noreply.huggingface.co>\n",
        authors: [
          {
            username: "sgugger",
            avatarUrl:
              "https://cdn-avatars.huggingface.co/v1/production/uploads/1593126474392-5ef50182b71947201082a4e5.jpeg",
          },
          {
            username: "joaogante",
            avatarUrl: "https://cdn-avatars.huggingface.co/v1/production/uploads/1641203017724-noauth.png",
          },
        ],
        date: new Date("2022-11-18T18:19:30.000Z"),
      },
      {
        oid: "75e09b43581151bd1d9ef6700faa605df408979f",
        title: "Upload model.safetensors with huggingface_hub (#12)",
        message:
          "\n\n\n- Upload model.safetensors with huggingface_hub (ba2f794b2e4ea09ef932a6628fa0815dfaf09661)\n\n\nCo-authored-by: Nicolas Patry <Narsil@users.noreply.huggingface.co>\n",
        authors: [
          {
            username: "julien-c",
            avatarUrl:
              "https://cdn-avatars.huggingface.co/v1/production/uploads/5dd96eb166059660ed1ee413/NQtzmrDdbG0H8qkZvRyGk.jpeg",
          },
          {
            username: "Narsil",
            avatarUrl:
              "https://cdn-avatars.huggingface.co/v1/production/uploads/1608285816082-5e2967b819407e3277369b95.png",
          },
        ],
        date: new Date("2022-10-20T09:34:54.000Z"),
      },
    ]);
  });
});

lib/list-commits.ts
ADDED
@@ -0,0 +1,70 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiCommitData } from "../types/api/api-commit";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { toRepoId } from "../utils/toRepoId";

export interface CommitData {
  oid: string;
  title: string;
  message: string;
  authors: Array<{ username: string; avatarUrl: string }>;
  date: Date;
}

export async function* listCommits(
  params: {
    repo: RepoDesignation;
    /**
     * Revision to list commits from. Defaults to the default branch.
     */
    revision?: string;
    hubUrl?: string;
    /**
     * Number of commits to fetch from the hub on each http call. Defaults to 100. Can be set to 1000.
     */
    batchSize?: number;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & Partial<CredentialsParams>
): AsyncGenerator<CommitData> {
  const accessToken = checkCredentials(params);
  const repoId = toRepoId(params.repo);

  // Could upgrade to 1000 commits per page
  let url: string | undefined = `${params.hubUrl ?? HUB_URL}/api/${repoId.type}s/${repoId.name}/commits/${
    params.revision ?? "main"
  }?limit=${params.batchSize ?? 100}`;

  while (url) {
    const res: Response = await (params.fetch ?? fetch)(url, {
      headers: accessToken ? { Authorization: `Bearer ${accessToken}` } : {},
    });

    if (!res.ok) {
      throw await createApiError(res);
    }

    const resJson: ApiCommitData[] = await res.json();
    for (const commit of resJson) {
      yield {
        oid: commit.id,
        title: commit.title,
        message: commit.message,
        authors: commit.authors.map((author) => ({
          username: author.user,
          avatarUrl: author.avatar,
        })),
        date: new Date(commit.date),
      };
    }

    const linkHeader = res.headers.get("Link");

    url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
  }
}

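listCommits is an async generator that follows the Link header for pagination, so consuming it is a simple for-await loop. An illustrative sketch (not part of the commit):

import { listCommits } from "./list-commits";

for await (const commit of listCommits({
  repo: { type: "model", name: "openai-community/gpt2" },
  batchSize: 100,
})) {
  console.log(commit.oid, commit.title, commit.date.toISOString());
}
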
lib/list-datasets.spec.ts
ADDED
@@ -0,0 +1,47 @@
import { describe, expect, it } from "vitest";
import type { DatasetEntry } from "./list-datasets";
import { listDatasets } from "./list-datasets";

describe("listDatasets", () => {
  it("should list datasets from hf-doc-builder", async () => {
    const results: DatasetEntry[] = [];

    for await (const entry of listDatasets({ search: { owner: "hf-doc-build" } })) {
      if (entry.name === "hf-doc-build/doc-build-dev-test") {
        continue;
      }
      if (typeof entry.downloads === "number") {
        entry.downloads = 0;
      }
      if (typeof entry.likes === "number") {
        entry.likes = 0;
      }
      if (entry.updatedAt instanceof Date && !isNaN(entry.updatedAt.getTime())) {
        entry.updatedAt = new Date(0);
      }

      results.push(entry);
    }

    expect(results).deep.equal([
      {
        id: "6356b19985da6f13863228bd",
        name: "hf-doc-build/doc-build",
        private: false,
        gated: false,
        downloads: 0,
        likes: 0,
        updatedAt: new Date(0),
      },
      {
        id: "636a1b69f2f9ec4289c4c19e",
        name: "hf-doc-build/doc-build-dev",
        gated: false,
        private: false,
        downloads: 0,
        likes: 0,
        updatedAt: new Date(0),
      },
    ]);
  });
});

lib/list-datasets.ts
ADDED
@@ -0,0 +1,121 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiDatasetInfo } from "../types/api/api-dataset";
import type { CredentialsParams } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";

export const DATASET_EXPAND_KEYS = [
  "private",
  "downloads",
  "gated",
  "likes",
  "lastModified",
] as const satisfies readonly (keyof ApiDatasetInfo)[];

export const DATASET_EXPANDABLE_KEYS = [
  "author",
  "cardData",
  "citation",
  "createdAt",
  "disabled",
  "description",
  "downloads",
  "downloadsAllTime",
  "gated",
  "gitalyUid",
  "lastModified",
  "likes",
  "paperswithcode_id",
  "private",
  // "siblings",
  "sha",
  "tags",
] as const satisfies readonly (keyof ApiDatasetInfo)[];

export interface DatasetEntry {
  id: string;
  name: string;
  private: boolean;
  downloads: number;
  gated: false | "auto" | "manual";
  likes: number;
  updatedAt: Date;
}

export async function* listDatasets<
  const T extends Exclude<(typeof DATASET_EXPANDABLE_KEYS)[number], (typeof DATASET_EXPAND_KEYS)[number]> = never,
>(
  params?: {
    search?: {
      /**
       * Will search in the dataset name for matches
       */
      query?: string;
      owner?: string;
      tags?: string[];
    };
    hubUrl?: string;
    additionalFields?: T[];
    /**
     * Set to limit the number of datasets returned.
     */
    limit?: number;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & Partial<CredentialsParams>
): AsyncGenerator<DatasetEntry & Pick<ApiDatasetInfo, T>> {
  const accessToken = params && checkCredentials(params);
  let totalToFetch = params?.limit ?? Infinity;
  const search = new URLSearchParams([
    ...Object.entries({
      limit: String(Math.min(totalToFetch, 500)),
      ...(params?.search?.owner ? { author: params.search.owner } : undefined),
      ...(params?.search?.query ? { search: params.search.query } : undefined),
    }),
    ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
    ...DATASET_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
    ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
  ]).toString();
  let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/datasets` + (search ? "?" + search : "");

  while (url) {
    const res: Response = await (params?.fetch ?? fetch)(url, {
      headers: {
        accept: "application/json",
        ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
      },
    });

    if (!res.ok) {
      throw await createApiError(res);
    }

    const items: ApiDatasetInfo[] = await res.json();

    for (const item of items) {
      yield {
        ...(params?.additionalFields && pick(item, params.additionalFields)),
        id: item._id,
        name: item.id,
        private: item.private,
        downloads: item.downloads,
        likes: item.likes,
        gated: item.gated,
        updatedAt: new Date(item.lastModified),
      } as DatasetEntry & Pick<ApiDatasetInfo, T>;
      totalToFetch--;
      if (totalToFetch <= 0) {
        return;
      }
    }

    const linkHeader = res.headers.get("Link");

    url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
    // Could update limit in url to fetch less items if not all items of next page are needed.
  }
}

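A usage sketch for listDatasets (illustrative, not part of the commit):

import { listDatasets } from "./list-datasets";

for await (const dataset of listDatasets({
  search: { owner: "hf-doc-build" },
  limit: 10,
})) {
  console.log(dataset.name, dataset.downloads, dataset.likes);
}
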
lib/list-files.spec.ts
ADDED
@@ -0,0 +1,173 @@
import { assert, it, describe } from "vitest";
import type { ListFileEntry } from "./list-files";
import { listFiles } from "./list-files";

describe("listFiles", () => {
  it("should fetch the list of files from the repo", async () => {
    const cursor = listFiles({
      repo: {
        name: "bert-base-uncased",
        type: "model",
      },
      revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
    });

    const files: ListFileEntry[] = [];

    for await (const entry of cursor) {
      files.push(entry);
    }

    assert.deepStrictEqual(files, [
      {
        oid: "dc08351d4dc0732d9c8af04070ced089b201ce2f",
        path: ".gitattributes",
        size: 345,
        type: "file",
      },
      {
        oid: "fca794a5f07ff8f963fe8b61e3694b0fb7f955df",
        path: "config.json",
        size: 313,
        type: "file",
      },
      {
        lfs: {
          oid: "097417381d6c7230bd9e3557456d726de6e83245ec8b24f529f60198a67b203a",
          size: 440473133,
          pointerSize: 134,
        },
        xetHash: "2d8408d3a894d02517d04956e2f7546ff08362594072f3527ce144b5212a3296",
        oid: "ba5d19791be1dd7992e33bd61f20207b0f7f50a5",
        path: "pytorch_model.bin",
        size: 440473133,
        type: "file",
      },
      {
        lfs: {
          oid: "a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2",
          size: 536063208,
          pointerSize: 134,
        },
        xetHash: "879c5715c18a0b7f051dd33f70f0a5c8dd1522e0a43f6f75520f16167f29279b",
        oid: "9eb98c817f04b051b3bcca591bcd4e03cec88018",
        path: "tf_model.h5",
        size: 536063208,
        type: "file",
      },
      {
        oid: "fb140275c155a9c7c5a3b3e0e77a9e839594a938",
        path: "vocab.txt",
        size: 231508,
        type: "file",
      },
    ]);
  });

  it("should fetch the list of files from the repo, including last commit", async () => {
    const cursor = listFiles({
      repo: {
        name: "bert-base-uncased",
        type: "model",
      },
      revision: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
      expand: true,
    });

    const files: ListFileEntry[] = [];

    for await (const entry of cursor) {
      delete entry.securityFileStatus; // flaky
      files.push(entry);
    }

    assert.deepStrictEqual(files, [
      {
        lastCommit: {
          date: "2018-11-14T23:35:08.000Z",
          id: "504939aa53e8ce310dba3dd2296dbe266c575de4",
          title: "initial commit",
        },
        oid: "dc08351d4dc0732d9c8af04070ced089b201ce2f",
        path: ".gitattributes",
        size: 345,
        type: "file",
      },
      {
        lastCommit: {
          date: "2019-06-18T09:06:51.000Z",
          id: "bb3c1c3256d2598217df9889a14a2e811587891d",
          title: "Update config.json",
        },
        oid: "fca794a5f07ff8f963fe8b61e3694b0fb7f955df",
        path: "config.json",
        size: 313,
        type: "file",
      },
      {
        lastCommit: {
          date: "2019-06-18T09:06:34.000Z",
          id: "3d2477d72b675a999d1b13ca822aaaf4908634ad",
          title: "Update pytorch_model.bin",
        },
        lfs: {
          oid: "097417381d6c7230bd9e3557456d726de6e83245ec8b24f529f60198a67b203a",
          size: 440473133,
          pointerSize: 134,
        },
        xetHash: "2d8408d3a894d02517d04956e2f7546ff08362594072f3527ce144b5212a3296",
        oid: "ba5d19791be1dd7992e33bd61f20207b0f7f50a5",
        path: "pytorch_model.bin",
        size: 440473133,
        type: "file",
      },
      {
        lastCommit: {
          date: "2019-09-23T19:48:44.000Z",
          id: "dd4bc8b21efa05ec961e3efc4ee5e3832a3679c7",
          title: "Update tf_model.h5",
        },
        lfs: {
          oid: "a7a17d6d844b5de815ccab5f42cad6d24496db3850a2a43d8258221018ce87d2",
          size: 536063208,
          pointerSize: 134,
        },
        xetHash: "879c5715c18a0b7f051dd33f70f0a5c8dd1522e0a43f6f75520f16167f29279b",
        oid: "9eb98c817f04b051b3bcca591bcd4e03cec88018",
        path: "tf_model.h5",
        size: 536063208,
        type: "file",
      },
      {
        lastCommit: {
          date: "2018-11-14T23:35:08.000Z",
          id: "2f07d813ca87c8c709147704c87210359ccf2309",
          title: "Update vocab.txt",
        },
        oid: "fb140275c155a9c7c5a3b3e0e77a9e839594a938",
        path: "vocab.txt",
        size: 231508,
        type: "file",
      },
    ]);
  });

  it("should fetch the list of files from the repo, including subfolders", async () => {
    const cursor = listFiles({
      repo: {
        name: "xsum",
        type: "dataset",
      },
      revision: "0f3ea2f2b55fcb11e71fb1e3aec6822e44ddcb0f",
      recursive: true,
    });

    const files: ListFileEntry[] = [];

    for await (const entry of cursor) {
      files.push(entry);
    }

    assert(files.some((file) => file.path === "data/XSUM-EMNLP18-Summary-Data-Original.tar.gz"));
  });
});

lib/list-files.ts
ADDED
@@ -0,0 +1,94 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiIndexTreeEntry } from "../types/api/api-index-tree";
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { toRepoId } from "../utils/toRepoId";

export interface ListFileEntry {
  type: "file" | "directory" | "unknown";
  size: number;
  path: string;
  oid: string;
  lfs?: {
    oid: string;
    size: number;
    /** Size of the raw pointer file, 100~200 bytes */
    pointerSize: number;
  };
  /**
   * Xet-backed hash, a new protocol replacing LFS for big files.
   */
  xetHash?: string;
  /**
   * Only fetched if `expand` is set to `true` in the `listFiles` call.
   */
  lastCommit?: {
    date: string;
    id: string;
    title: string;
  };
  /**
   * Only fetched if `expand` is set to `true` in the `listFiles` call.
   */
  securityFileStatus?: unknown;
}

/**
 * List files in a folder. To list ALL files in the directory, call it
 * with {@link params.recursive} set to `true`.
 */
export async function* listFiles(
  params: {
    repo: RepoDesignation;
    /**
     * Do we want to list files in subdirectories?
     */
    recursive?: boolean;
    /**
     * Eg 'data' for listing all files in the 'data' folder. Leave it empty to list all
     * files in the repo.
     */
    path?: string;
    /**
     * Fetch `lastCommit` and `securityFileStatus` for each file.
     */
    expand?: boolean;
    revision?: string;
    hubUrl?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & Partial<CredentialsParams>
): AsyncGenerator<ListFileEntry> {
  const accessToken = checkCredentials(params);
  const repoId = toRepoId(params.repo);
  let url: string | undefined = `${params.hubUrl || HUB_URL}/api/${repoId.type}s/${repoId.name}/tree/${
    params.revision || "main"
  }${params.path ? "/" + params.path : ""}?recursive=${!!params.recursive}&expand=${!!params.expand}`;

  while (url) {
    const res: Response = await (params.fetch ?? fetch)(url, {
      headers: {
        accept: "application/json",
        ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
      },
    });

    if (!res.ok) {
      throw await createApiError(res);
    }

    const items: ApiIndexTreeEntry[] = await res.json();

    for (const item of items) {
      yield item;
    }

    const linkHeader = res.headers.get("Link");

    url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
  }
}

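A usage sketch for listFiles (illustrative, not part of the commit):

import { listFiles } from "./list-files";

for await (const file of listFiles({
  repo: { type: "dataset", name: "xsum" },
  recursive: true,
})) {
  console.log(file.type, file.path, file.size);
}
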
lib/list-models.spec.ts
ADDED
@@ -0,0 +1,118 @@
import { describe, expect, it } from "vitest";
import type { ModelEntry } from "./list-models";
import { listModels } from "./list-models";

describe("listModels", () => {
  it("should list models for depth estimation", async () => {
    const results: ModelEntry[] = [];

    for await (const entry of listModels({
      search: { owner: "Intel", task: "depth-estimation" },
    })) {
      if (typeof entry.downloads === "number") {
        entry.downloads = 0;
      }
      if (typeof entry.likes === "number") {
        entry.likes = 0;
      }
      if (entry.updatedAt instanceof Date && !isNaN(entry.updatedAt.getTime())) {
        entry.updatedAt = new Date(0);
      }

      if (!["Intel/dpt-large", "Intel/dpt-hybrid-midas"].includes(entry.name)) {
        expect(entry.task).to.equal("depth-estimation");
        continue;
      }

      results.push(entry);
    }

    results.sort((a, b) => a.id.localeCompare(b.id));

    expect(results).deep.equal([
      {
        id: "621ffdc136468d709f17e709",
        name: "Intel/dpt-large",
        private: false,
        gated: false,
        downloads: 0,
        likes: 0,
        task: "depth-estimation",
        updatedAt: new Date(0),
      },
      {
        id: "638f07977559bf9a2b2b04ac",
        name: "Intel/dpt-hybrid-midas",
        gated: false,
        private: false,
        downloads: 0,
        likes: 0,
        task: "depth-estimation",
        updatedAt: new Date(0),
      },
    ]);
  });

  it("should list indonesian models with gguf format", async () => {
    let count = 0;
    for await (const entry of listModels({
      search: { tags: ["gguf", "id"] },
      additionalFields: ["tags"],
      limit: 2,
    })) {
      count++;
      expect(entry.tags).to.include("gguf");
      expect(entry.tags).to.include("id");
    }

    expect(count).to.equal(2);
  });

  it("should search model by name", async () => {
    let count = 0;
    for await (const entry of listModels({
      search: { query: "t5" },
      limit: 10,
    })) {
      count++;
      expect(entry.name.toLocaleLowerCase()).to.include("t5");
    }

    expect(count).to.equal(10);
  });

  it("should search model by inference provider", async () => {
    let count = 0;
    for await (const entry of listModels({
      search: { inferenceProviders: ["together"] },
      additionalFields: ["inferenceProviderMapping"],
      limit: 10,
    })) {
      count++;
      if (Array.isArray(entry.inferenceProviderMapping)) {
        expect(entry.inferenceProviderMapping.map(({ provider }) => provider)).to.include("together");
      }
    }

    expect(count).to.equal(10);
  });

  it("should search model by several inference providers", async () => {
    let count = 0;
    const inferenceProviders = ["together", "replicate"];
    for await (const entry of listModels({
      search: { inferenceProviders },
      additionalFields: ["inferenceProviderMapping"],
      limit: 10,
    })) {
      count++;
      if (Array.isArray(entry.inferenceProviderMapping)) {
        expect(
          entry.inferenceProviderMapping.filter(({ provider }) => inferenceProviders.includes(provider)).length
        ).toBeGreaterThan(0);
      }
    }

    expect(count).to.equal(10);
  });
});

lib/list-models.ts
ADDED
@@ -0,0 +1,139 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiModelInfo } from "../types/api/api-model";
import type { CredentialsParams, PipelineType } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";

export const MODEL_EXPAND_KEYS = [
  "pipeline_tag",
  "private",
  "gated",
  "downloads",
  "likes",
  "lastModified",
] as const satisfies readonly (keyof ApiModelInfo)[];

export const MODEL_EXPANDABLE_KEYS = [
  "author",
  "cardData",
  "config",
  "createdAt",
  "disabled",
  "downloads",
  "downloadsAllTime",
  "gated",
  "gitalyUid",
  "inferenceProviderMapping",
  "lastModified",
  "library_name",
  "likes",
  "model-index",
  "pipeline_tag",
  "private",
  "safetensors",
  "sha",
  // "siblings",
  "spaces",
  "tags",
  "transformersInfo",
] as const satisfies readonly (keyof ApiModelInfo)[];

export interface ModelEntry {
  id: string;
  name: string;
  private: boolean;
  gated: false | "auto" | "manual";
  task?: PipelineType;
  likes: number;
  downloads: number;
  updatedAt: Date;
}

export async function* listModels<
  const T extends Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]> = never,
>(
  params?: {
    search?: {
      /**
       * Will search in the model name for matches
       */
      query?: string;
      owner?: string;
      task?: PipelineType;
      tags?: string[];
      /**
       * Will search for models that have one of the inference providers in the list.
       */
      inferenceProviders?: string[];
    };
    hubUrl?: string;
    additionalFields?: T[];
    /**
     * Set to limit the number of models returned.
     */
    limit?: number;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
  } & Partial<CredentialsParams>
): AsyncGenerator<ModelEntry & Pick<ApiModelInfo, T>> {
  const accessToken = params && checkCredentials(params);
  let totalToFetch = params?.limit ?? Infinity;
  const search = new URLSearchParams([
    ...Object.entries({
      limit: String(Math.min(totalToFetch, 500)),
      ...(params?.search?.owner ? { author: params.search.owner } : undefined),
      ...(params?.search?.task ? { pipeline_tag: params.search.task } : undefined),
      ...(params?.search?.query ? { search: params.search.query } : undefined),
      ...(params?.search?.inferenceProviders
        ? { inference_provider: params.search.inferenceProviders.join(",") }
        : undefined),
    }),
    ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
    ...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
    ...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
  ]).toString();
  let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/models?${search}`;

  while (url) {
    const res: Response = await (params?.fetch ?? fetch)(url, {
      headers: {
        accept: "application/json",
        ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
      },
    });

    if (!res.ok) {
      throw await createApiError(res);
    }

    const items: ApiModelInfo[] = await res.json();

    for (const item of items) {
      yield {
        ...(params?.additionalFields && pick(item, params.additionalFields)),
        id: item._id,
        name: item.id,
        private: item.private,
        task: item.pipeline_tag,
        downloads: item.downloads,
        gated: item.gated,
        likes: item.likes,
        updatedAt: new Date(item.lastModified),
      } as ModelEntry & Pick<ApiModelInfo, T>;
      totalToFetch--;

      if (totalToFetch <= 0) {
        return;
      }
    }

    const linkHeader = res.headers.get("Link");

    url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
    // Could update url to reduce the limit if we don't need the whole 500 of the next batch.
  }
}

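A usage sketch for listModels showing search filters and additionalFields (illustrative, not part of the commit):

import { listModels } from "./list-models";

for await (const model of listModels({
  search: { owner: "Intel", task: "depth-estimation" },
  additionalFields: ["library_name"],
  limit: 5,
})) {
  console.log(model.name, model.task, model.library_name);
}
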
lib/list-spaces.spec.ts
ADDED
@@ -0,0 +1,40 @@
import { describe, expect, it } from "vitest";
import type { SpaceEntry } from "./list-spaces";
import { listSpaces } from "./list-spaces";

describe("listSpaces", () => {
  it("should list spaces for Microsoft", async () => {
    const results: SpaceEntry[] = [];

    for await (const entry of listSpaces({
      search: { owner: "microsoft" },
      additionalFields: ["subdomain"],
    })) {
      if (entry.name !== "microsoft/visual_chatgpt") {
        continue;
      }
      if (typeof entry.likes === "number") {
        entry.likes = 0;
      }
      if (entry.updatedAt instanceof Date && !isNaN(entry.updatedAt.getTime())) {
        entry.updatedAt = new Date(0);
      }

      results.push(entry);
    }

    results.sort((a, b) => a.id.localeCompare(b.id));

    expect(results).deep.equal([
      {
        id: "6409a392bbc73d022c58c980",
        name: "microsoft/visual_chatgpt",
        private: false,
        likes: 0,
        sdk: "gradio",
        subdomain: "microsoft-visual-chatgpt",
        updatedAt: new Date(0),
      },
    ]);
  });
});

lib/list-spaces.ts
ADDED
@@ -0,0 +1,111 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiSpaceInfo } from "../types/api/api-space";
import type { CredentialsParams, SpaceSdk } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { parseLinkHeader } from "../utils/parseLinkHeader";
import { pick } from "../utils/pick";

export const SPACE_EXPAND_KEYS = [
  "sdk",
  "likes",
  "private",
  "lastModified",
] as const satisfies readonly (keyof ApiSpaceInfo)[];
export const SPACE_EXPANDABLE_KEYS = [
  "author",
  "cardData",
  "datasets",
  "disabled",
  "gitalyUid",
  "lastModified",
  "createdAt",
  "likes",
  "private",
  "runtime",
  "sdk",
  // "siblings",
  "sha",
  "subdomain",
  "tags",
  "models",
] as const satisfies readonly (keyof ApiSpaceInfo)[];

export interface SpaceEntry {
  id: string;
  name: string;
  sdk?: SpaceSdk;
  likes: number;
  private: boolean;
  updatedAt: Date;
  // Use additionalFields to fetch the fields from ApiSpaceInfo
}

export async function* listSpaces<
  const T extends Exclude<(typeof SPACE_EXPANDABLE_KEYS)[number], (typeof SPACE_EXPAND_KEYS)[number]> = never,
>(
  params?: {
    search?: {
      /**
       * Will search in the space name for matches
       */
      query?: string;
      owner?: string;
      tags?: string[];
    };
    hubUrl?: string;
    /**
     * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
     */
    fetch?: typeof fetch;
    /**
     * Additional fields to fetch from huggingface.co.
     */
    additionalFields?: T[];
  } & Partial<CredentialsParams>
): AsyncGenerator<SpaceEntry & Pick<ApiSpaceInfo, T>> {
  const accessToken = params && checkCredentials(params);
  const search = new URLSearchParams([
    ...Object.entries({
      limit: "500",
      ...(params?.search?.owner ? { author: params.search.owner } : undefined),
      ...(params?.search?.query ? { search: params.search.query } : undefined),
    }),
    ...(params?.search?.tags?.map((tag) => ["filter", tag]) ?? []),
    ...[...SPACE_EXPAND_KEYS, ...(params?.additionalFields ?? [])].map(
      (val) => ["expand", val] satisfies [string, string]
    ),
  ]).toString();
  let url: string | undefined = `${params?.hubUrl || HUB_URL}/api/spaces?${search}`;

  while (url) {
    const res: Response = await (params?.fetch ?? fetch)(url, {
      headers: {
        accept: "application/json",
        ...(accessToken ? { Authorization: `Bearer ${accessToken}` } : undefined),
      },
    });

    if (!res.ok) {
      throw await createApiError(res);
    }

    const items: ApiSpaceInfo[] = await res.json();

    for (const item of items) {
      yield {
        ...(params?.additionalFields && pick(item, params.additionalFields)),
        id: item._id,
        name: item.id,
        sdk: item.sdk,
        likes: item.likes,
        private: item.private,
        updatedAt: new Date(item.lastModified),
      } as SpaceEntry & Pick<ApiSpaceInfo, T>;
    }

    const linkHeader = res.headers.get("Link");

    url = linkHeader ? parseLinkHeader(linkHeader).next : undefined;
  }
}

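A usage sketch for listSpaces (illustrative, not part of the commit):

import { listSpaces } from "./list-spaces";

for await (const space of listSpaces({
  search: { owner: "microsoft" },
  additionalFields: ["subdomain"],
})) {
  console.log(space.name, space.sdk, space.subdomain);
}
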
lib/model-info.spec.ts
ADDED
@@ -0,0 +1,59 @@
import { describe, expect, it } from "vitest";
import { modelInfo } from "./model-info";
import type { ModelEntry } from "./list-models";
import type { ApiModelInfo } from "../types/api/api-model";

describe("modelInfo", () => {
  it("should return the model info", async () => {
    const info = await modelInfo({
      name: "openai-community/gpt2",
    });
    expect(info).toEqual({
      id: "621ffdc036468d709f17434d",
      downloads: expect.any(Number),
      gated: false,
      name: "openai-community/gpt2",
      updatedAt: expect.any(Date),
      likes: expect.any(Number),
      task: "text-generation",
      private: false,
    });
  });

  it("should return the model info with author", async () => {
    const info: ModelEntry & Pick<ApiModelInfo, "author"> = await modelInfo({
      name: "openai-community/gpt2",
      additionalFields: ["author"],
    });
    expect(info).toEqual({
      id: "621ffdc036468d709f17434d",
      downloads: expect.any(Number),
      author: "openai-community",
      gated: false,
      name: "openai-community/gpt2",
      updatedAt: expect.any(Date),
      likes: expect.any(Number),
      task: "text-generation",
      private: false,
    });
  });

  it("should return the model info for a specific revision", async () => {
    const info: ModelEntry & Pick<ApiModelInfo, "sha"> = await modelInfo({
      name: "openai-community/gpt2",
      additionalFields: ["sha"],
      revision: "f27b190eeac4c2302d24068eabf5e9d6044389ae",
    });
    expect(info).toEqual({
      id: "621ffdc036468d709f17434d",
      downloads: expect.any(Number),
      gated: false,
      name: "openai-community/gpt2",
      updatedAt: expect.any(Date),
      likes: expect.any(Number),
      task: "text-generation",
      private: false,
      sha: "f27b190eeac4c2302d24068eabf5e9d6044389ae",
    });
  });
});

lib/model-info.ts
ADDED
@@ -0,0 +1,62 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import type { ApiModelInfo } from "../types/api/api-model";
import type { CredentialsParams } from "../types/public";
import { checkCredentials } from "../utils/checkCredentials";
import { pick } from "../utils/pick";
import { MODEL_EXPAND_KEYS, type MODEL_EXPANDABLE_KEYS, type ModelEntry } from "./list-models";

export async function modelInfo<
	const T extends Exclude<(typeof MODEL_EXPANDABLE_KEYS)[number], (typeof MODEL_EXPAND_KEYS)[number]> = never,
>(
	params: {
		name: string;
		hubUrl?: string;
		additionalFields?: T[];
		/**
		 * An optional Git revision id which can be a branch name, a tag, or a commit hash.
		 */
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<ModelEntry & Pick<ApiModelInfo, T>> {
	const accessToken = params && checkCredentials(params);

	const search = new URLSearchParams([
		...MODEL_EXPAND_KEYS.map((val) => ["expand", val] satisfies [string, string]),
		...(params?.additionalFields?.map((val) => ["expand", val] satisfies [string, string]) ?? []),
	]).toString();

	const response = await (params.fetch || fetch)(
		`${params?.hubUrl || HUB_URL}/api/models/${params.name}/revision/${encodeURIComponent(
			params.revision ?? "HEAD"
		)}?${search.toString()}`,
		{
			headers: {
				...(accessToken ? { Authorization: `Bearer ${accessToken}` } : {}),
				Accept: "application/json",
			},
		}
	);

	if (!response.ok) {
		throw await createApiError(response);
	}

	const data = await response.json();

	return {
		...(params?.additionalFields && pick(data, params.additionalFields)),
		id: data._id,
		name: data.id,
		private: data.private,
		task: data.pipeline_tag,
		downloads: data.downloads,
		gated: data.gated,
		likes: data.likes,
		updatedAt: new Date(data.lastModified),
	} as ModelEntry & Pick<ApiModelInfo, T>;
}
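A short usage sketch for `modelInfo`, mirroring the spec file above. The `author` field is one of the expandable keys requested via `additionalFields` in the tests; the revision value here is illustrative.

```ts
import { modelInfo } from "@huggingface/hub";

// Fetch the default fields plus the expandable `author` field for a given revision.
const info = await modelInfo({
	name: "openai-community/gpt2",
	additionalFields: ["author"],
	revision: "main", // branch name, tag, or commit hash; defaults to HEAD
});

console.log(info.task, info.downloads, info.author);
```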
lib/oauth-handle-redirect.spec.ts
ADDED
@@ -0,0 +1,60 @@
import { describe, expect, it } from "vitest";
import { TEST_COOKIE, TEST_HUB_URL } from "../test/consts";
import { oauthLoginUrl } from "./oauth-login-url";
import { oauthHandleRedirect } from "./oauth-handle-redirect";

describe("oauthHandleRedirect", () => {
	it("should work", async () => {
		const localStorage = {
			nonce: undefined,
			codeVerifier: undefined,
		};
		const url = await oauthLoginUrl({
			clientId: "dummy-app",
			redirectUrl: "http://localhost:3000",
			localStorage,
			scopes: "openid profile email",
			hubUrl: TEST_HUB_URL,
		});
		const resp = await fetch(url, {
			method: "POST",
			headers: {
				Cookie: `token=${TEST_COOKIE}`,
			},
			redirect: "manual",
		});
		if (resp.status !== 303) {
			throw new Error(`Failed to fetch url ${url}: ${resp.status} ${resp.statusText}`);
		}
		const location = resp.headers.get("Location");
		if (!location) {
			throw new Error(`No location header in response`);
		}
		const result = await oauthHandleRedirect({
			redirectedUrl: location,
			codeVerifier: localStorage.codeVerifier,
			nonce: localStorage.nonce,
			hubUrl: TEST_HUB_URL,
		});

		if (!result) {
			throw new Error("Expected result to be defined");
		}
		expect(result.accessToken).toEqual(expect.any(String));
		expect(result.accessTokenExpiresAt).toBeInstanceOf(Date);
		expect(result.accessTokenExpiresAt.getTime()).toBeGreaterThan(Date.now());
		expect(result.scope).toEqual(expect.any(String));
		expect(result.userInfo).toEqual({
			sub: "62f264b9f3c90f4b6514a269",
			name: "@huggingface/hub CI bot",
			preferred_username: "hub.js",
			email_verified: true,
			email: "eliott@huggingface.co",
			isPro: false,
			picture: "https://hub-ci.huggingface.co/avatars/934b830e9fdaa879487852f79eef7165.svg",
			profile: "https://hub-ci.huggingface.co/hub.js",
			website: "https://github.com/huggingface/hub.js",
			orgs: [],
		});
	});
});
lib/oauth-handle-redirect.ts
ADDED
@@ -0,0 +1,334 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";

export interface UserInfo {
	/**
	 * OpenID Connect field. Unique identifier for the user, even in case of rename.
	 */
	sub: string;
	/**
	 * OpenID Connect field. The user's full name.
	 */
	name: string;
	/**
	 * OpenID Connect field. The user's username.
	 */
	preferred_username: string;
	/**
	 * OpenID Connect field, available if scope "email" was granted.
	 */
	email_verified?: boolean;
	/**
	 * OpenID Connect field, available if scope "email" was granted.
	 */
	email?: string;
	/**
	 * OpenID Connect field. The user's profile picture URL.
	 */
	picture: string;
	/**
	 * OpenID Connect field. The user's profile URL.
	 */
	profile: string;
	/**
	 * OpenID Connect field. The user's website URL.
	 */
	website?: string;

	/**
	 * Hugging Face field. Whether the user is a pro user.
	 */
	isPro: boolean;
	/**
	 * Hugging Face field. Whether the user has a payment method set up. Needs "read-billing" scope.
	 */
	canPay?: boolean;
	/**
	 * Hugging Face field. The user's orgs
	 */
	orgs?: Array<{
		/**
		 * OpenID Connect field. Unique identifier for the org.
		 */
		sub: string;
		/**
		 * OpenID Connect field. The org's full name.
		 */
		name: string;
		/**
		 * OpenID Connect field. The org's username.
		 */
		preferred_username: string;
		/**
		 * OpenID Connect field. The org's profile picture URL.
		 */
		picture: string;

		/**
		 * Hugging Face field. Whether the org is an enterprise org.
		 */
		isEnterprise: boolean;
		/**
		 * Hugging Face field. Whether the org has a payment method set up. Needs "read-billing" scope, and the user needs to approve access to the org in the OAuth page.
		 */
		canPay?: boolean;
		/**
		 * Hugging Face field. The user's role in the org. The user needs to approve access to the org in the OAuth page.
		 */
		roleInOrg?: string;
		/**
		 * HuggingFace field. When the user granted the oauth app access to the org, but didn't complete SSO.
		 *
		 * Should never happen directly after the oauth flow.
		 */
		pendingSSO?: boolean;
		/**
		 * HuggingFace field. When the user granted the oauth app access to the org, but didn't complete MFA.
		 *
		 * Should never happen directly after the oauth flow.
		 */
		missingMFA?: boolean;
	}>;
}

export interface OAuthResult {
	accessToken: string;
	accessTokenExpiresAt: Date;
	userInfo: UserInfo;
	/**
	 * State passed to the OAuth provider in the original request to the OAuth provider.
	 */
	state?: string;
	/**
	 * Granted scope
	 */
	scope: string;
}

/**
 * To call after the OAuth provider redirects back to the app.
 *
 * There is also a helper function {@link oauthHandleRedirectIfPresent}, which will call `oauthHandleRedirect` if the URL contains an oauth code
 * in the query parameters and return `false` otherwise.
 */
export async function oauthHandleRedirect(opts?: {
	/**
	 * The URL of the hub. Defaults to {@link HUB_URL}.
	 */
	hubUrl?: string;
	/**
	 * The URL to analyze.
	 *
	 * @default window.location.href
	 */
	redirectedUrl?: string;
	/**
	 * nonce generated by oauthLoginUrl
	 *
	 * @default localStorage.getItem("huggingface.co:oauth:nonce")
	 */
	nonce?: string;
	/**
	 * codeVerifier generated by oauthLoginUrl
	 *
	 * @default localStorage.getItem("huggingface.co:oauth:code_verifier")
	 */
	codeVerifier?: string;
}): Promise<OAuthResult> {
	if (typeof window === "undefined" && !opts?.redirectedUrl) {
		throw new Error("oauthHandleRedirect is only available in the browser, unless you provide redirectedUrl");
	}
	if (typeof localStorage === "undefined" && (!opts?.nonce || !opts?.codeVerifier)) {
		throw new Error(
			"oauthHandleRedirect requires localStorage to be available, unless you provide nonce and codeVerifier"
		);
	}

	const redirectedUrl = opts?.redirectedUrl ?? window.location.href;
	const searchParams = (() => {
		try {
			return new URL(redirectedUrl).searchParams;
		} catch (err) {
			throw new Error("Failed to parse redirected URL: " + redirectedUrl);
		}
	})();

	const [error, errorDescription] = [searchParams.get("error"), searchParams.get("error_description")];

	if (error) {
		throw new Error(`${error}: ${errorDescription}`);
	}

	const code = searchParams.get("code");
	const nonce = opts?.nonce ?? localStorage.getItem("huggingface.co:oauth:nonce");

	if (!code) {
		throw new Error("Missing oauth code from query parameters in redirected URL: " + redirectedUrl);
	}

	if (!nonce) {
		throw new Error("Missing oauth nonce from localStorage");
	}

	const codeVerifier = opts?.codeVerifier ?? localStorage.getItem("huggingface.co:oauth:code_verifier");

	if (!codeVerifier) {
		throw new Error("Missing oauth code_verifier from localStorage");
	}

	const state = searchParams.get("state");

	if (!state) {
		throw new Error("Missing oauth state from query parameters in redirected URL");
	}

	let parsedState: { nonce: string; redirectUri: string; state?: string };

	try {
		parsedState = JSON.parse(state);
	} catch {
		throw new Error("Invalid oauth state in redirected URL, unable to parse JSON: " + state);
	}

	if (parsedState.nonce !== nonce) {
		throw new Error("Invalid oauth state in redirected URL");
	}

	const hubUrl = opts?.hubUrl || HUB_URL;

	const openidConfigUrl = `${new URL(hubUrl).origin}/.well-known/openid-configuration`;
	const openidConfigRes = await fetch(openidConfigUrl, {
		headers: {
			Accept: "application/json",
		},
	});

	if (!openidConfigRes.ok) {
		throw await createApiError(openidConfigRes);
	}

	const openidConfig: {
		authorization_endpoint: string;
		token_endpoint: string;
		userinfo_endpoint: string;
	} = await openidConfigRes.json();

	const tokenRes = await fetch(openidConfig.token_endpoint, {
		method: "POST",
		headers: {
			"Content-Type": "application/x-www-form-urlencoded",
		},
		body: new URLSearchParams({
			grant_type: "authorization_code",
			code,
			redirect_uri: parsedState.redirectUri,
			code_verifier: codeVerifier,
		}).toString(),
	});

	if (!opts?.codeVerifier) {
		localStorage.removeItem("huggingface.co:oauth:code_verifier");
	}
	if (!opts?.nonce) {
		localStorage.removeItem("huggingface.co:oauth:nonce");
	}

	if (!tokenRes.ok) {
		throw await createApiError(tokenRes);
	}

	const token: {
		access_token: string;
		expires_in: number;
		id_token: string;
		// refresh_token: string;
		scope: string;
		token_type: string;
	} = await tokenRes.json();

	const accessTokenExpiresAt = new Date(Date.now() + token.expires_in * 1000);

	const userInfoRes = await fetch(openidConfig.userinfo_endpoint, {
		headers: {
			Authorization: `Bearer ${token.access_token}`,
		},
	});

	if (!userInfoRes.ok) {
		throw await createApiError(userInfoRes);
	}

	const userInfo: UserInfo = await userInfoRes.json();

	return {
		accessToken: token.access_token,
		accessTokenExpiresAt,
		userInfo: userInfo,
		state: parsedState.state,
		scope: token.scope,
	};
}

// if (code && !nonce) {
//   console.warn("Missing oauth nonce from localStorage");
// }

/**
 * To call after the OAuth provider redirects back to the app.
 *
 * It returns false if the URL does not contain an oauth code in the query parameters, otherwise
 * it calls {@link oauthHandleRedirect}.
 *
 * Depending on your app, you may want to call {@link oauthHandleRedirect} directly instead.
 */
export async function oauthHandleRedirectIfPresent(opts?: {
	/**
	 * The URL of the hub. Defaults to {@link HUB_URL}.
	 */
	hubUrl?: string;
	/**
	 * The URL to analyze.
	 *
	 * @default window.location.href
	 */
	redirectedUrl?: string;
	/**
	 * nonce generated by oauthLoginUrl
	 *
	 * @default localStorage.getItem("huggingface.co:oauth:nonce")
	 */
	nonce?: string;
	/**
	 * codeVerifier generated by oauthLoginUrl
	 *
	 * @default localStorage.getItem("huggingface.co:oauth:code_verifier")
	 */
	codeVerifier?: string;
}): Promise<OAuthResult | false> {
	if (typeof window === "undefined" && !opts?.redirectedUrl) {
		throw new Error("oauthHandleRedirect is only available in the browser, unless you provide redirectedUrl");
	}
	if (typeof localStorage === "undefined" && (!opts?.nonce || !opts?.codeVerifier)) {
		throw new Error(
			"oauthHandleRedirect requires localStorage to be available, unless you provide nonce and codeVerifier"
		);
	}
	const searchParams = new URLSearchParams(opts?.redirectedUrl ?? window.location.search);

	if (searchParams.has("error")) {
		return oauthHandleRedirect(opts);
	}

	if (searchParams.has("code")) {
		if (!localStorage.getItem("huggingface.co:oauth:nonce")) {
			console.warn(
				"Missing oauth nonce from localStorage. This can happen when the user refreshes the page after logging in, without changing the URL."
			);
			return false;
		}

		return oauthHandleRedirect(opts);
	}

	return false;
}
lib/oauth-login-url.ts
ADDED
@@ -0,0 +1,166 @@
import { HUB_URL } from "../consts";
import { createApiError } from "../error";
import { base64FromBytes } from "../utils/base64FromBytes";

/**
 * Use "Sign in with Hub" to authenticate a user, and get oauth user info / access token.
 *
 * Returns an url to redirect to. After the user is redirected back to your app, call `oauthHandleRedirect` to get the oauth user info / access token.
 *
 * When called from inside a static Space with OAuth enabled, it will load the config from the space, otherwise you need to at least specify
 * the client ID of your OAuth App.
 *
 * @example
 * ```ts
 * import { oauthLoginUrl, oauthHandleRedirectIfPresent } from "@huggingface/hub";
 *
 * const oauthResult = await oauthHandleRedirectIfPresent();
 *
 * if (!oauthResult) {
 *   // If the user is not logged in, redirect to the login page
 *   window.location.href = await oauthLoginUrl();
 * }
 *
 * // You can use oauthResult.accessToken, oauthResult.accessTokenExpiresAt and oauthResult.userInfo
 * console.log(oauthResult);
 * ```
 *
 * (Theoretically, this function could be used to authenticate a user for any OAuth provider supporting PKCE and OpenID Connect by changing `hubUrl`,
 * but it is currently only tested with the Hugging Face Hub.)
 */
export async function oauthLoginUrl(opts?: {
	/**
	 * OAuth client ID.
	 *
	 * For static Spaces, you can omit this and it will be loaded from the Space config, as long as `hf_oauth: true` is present in the README.md's metadata.
	 * For other Spaces, it is available to the backend in the OAUTH_CLIENT_ID environment variable, as long as `hf_oauth: true` is present in the README.md's metadata.
	 *
	 * You can also create a Developer Application at https://huggingface.co/settings/connected-applications and use its client ID.
	 */
	clientId?: string;
	hubUrl?: string;
	/**
	 * OAuth scope, a list of space-separated scopes.
	 *
	 * For static Spaces, you can omit this and it will be loaded from the Space config, as long as `hf_oauth: true` is present in the README.md's metadata.
	 * For other Spaces, it is available to the backend in the OAUTH_SCOPES environment variable, as long as `hf_oauth: true` is present in the README.md's metadata.
	 *
	 * Defaults to "openid profile".
	 *
	 * You can also create a Developer Application at https://huggingface.co/settings/connected-applications and use its scopes.
	 *
	 * See https://huggingface.co/docs/hub/oauth for a list of available scopes.
	 */
	scopes?: string;
	/**
	 * Redirect URI, defaults to the current URL.
	 *
	 * For Spaces, any URL within the Space is allowed.
	 *
	 * For Developer Applications, you can add any URL you want to the list of allowed redirect URIs at https://huggingface.co/settings/connected-applications.
	 */
	redirectUrl?: string;
	/**
	 * State to pass to the OAuth provider, which will be returned in the call to `oauthLogin` after the redirect.
	 */
	state?: string;
	/**
	 * If provided, will be filled with the code verifier and nonce used for the OAuth flow,
	 * instead of using localStorage.
	 *
	 * When calling {@link `oauthHandleRedirectIfPresent`} or {@link `oauthHandleRedirect`} you will need to provide the same values.
	 */
	localStorage?: {
		codeVerifier?: string;
		nonce?: string;
	};
}): Promise<string> {
	if (typeof window === "undefined" && (!opts?.redirectUrl || !opts?.clientId)) {
		throw new Error("oauthLogin is only available in the browser, unless you provide clientId and redirectUrl");
	}
	if (typeof localStorage === "undefined" && !opts?.localStorage) {
		throw new Error(
			"oauthLogin requires localStorage to be available in the context, unless you provide a localStorage empty object as argument"
		);
	}

	const hubUrl = opts?.hubUrl || HUB_URL;
	const openidConfigUrl = `${new URL(hubUrl).origin}/.well-known/openid-configuration`;
	const openidConfigRes = await fetch(openidConfigUrl, {
		headers: {
			Accept: "application/json",
		},
	});

	if (!openidConfigRes.ok) {
		throw await createApiError(openidConfigRes);
	}

	const openidConfig: {
		authorization_endpoint: string;
		token_endpoint: string;
		userinfo_endpoint: string;
	} = await openidConfigRes.json();

	const newNonce = globalThis.crypto.randomUUID();
	// Two random UUIDs concatenated together, because min length is 43 and max length is 128
	const newCodeVerifier = globalThis.crypto.randomUUID() + globalThis.crypto.randomUUID();

	if (opts?.localStorage) {
		if (opts.localStorage.codeVerifier !== undefined && opts.localStorage.codeVerifier !== null) {
			throw new Error(
				"localStorage.codeVerifier must be initially set to null or undefined, and will be filled by oauthLoginUrl"
			);
		}
		if (opts.localStorage.nonce !== undefined && opts.localStorage.nonce !== null) {
			throw new Error(
				"localStorage.nonce must be initially set to null or undefined, and will be filled by oauthLoginUrl"
			);
		}
		opts.localStorage.codeVerifier = newCodeVerifier;
		opts.localStorage.nonce = newNonce;
	} else {
		localStorage.setItem("huggingface.co:oauth:nonce", newNonce);
		localStorage.setItem("huggingface.co:oauth:code_verifier", newCodeVerifier);
	}

	const redirectUri = opts?.redirectUrl || (typeof window !== "undefined" ? window.location.href : undefined);
	if (!redirectUri) {
		throw new Error("Missing redirectUrl");
	}
	const state = JSON.stringify({
		nonce: newNonce,
		redirectUri,
		state: opts?.state,
	});

	const variables: Record<string, string> | null =
		// @ts-expect-error window.huggingface is defined inside static Spaces.
		typeof window !== "undefined" ? window.huggingface?.variables ?? null : null;

	const clientId = opts?.clientId || variables?.OAUTH_CLIENT_ID;

	if (!clientId) {
		if (variables) {
			throw new Error("Missing clientId, please add hf_oauth: true to the README.md's metadata in your static Space");
		}
		throw new Error("Missing clientId");
	}

	const challenge = base64FromBytes(
		new Uint8Array(await globalThis.crypto.subtle.digest("SHA-256", new TextEncoder().encode(newCodeVerifier)))
	)
		.replace(/[+]/g, "-")
		.replace(/[/]/g, "_")
		.replace(/=/g, "");

	return `${openidConfig.authorization_endpoint}?${new URLSearchParams({
		client_id: clientId,
		scope: opts?.scopes || variables?.OAUTH_SCOPES || "openid profile",
		response_type: "code",
		redirect_uri: redirectUri,
		state,
		code_challenge: challenge,
		code_challenge_method: "S256",
	}).toString()}`;
}
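A sketch of driving the PKCE flow without relying on browser globals, mirroring the spec file above: `clientId` and the URLs are placeholders, and a plain object is passed instead of `window.localStorage` to hold the nonce and code verifier between the two calls.

```ts
import { oauthLoginUrl, oauthHandleRedirect } from "@huggingface/hub";

// Holds the nonce and code verifier that oauthLoginUrl generates.
const store: { nonce?: string; codeVerifier?: string } = {};

const loginUrl = await oauthLoginUrl({
	clientId: "your-oauth-app-client-id", // placeholder
	redirectUrl: "http://localhost:3000", // must be an allowed redirect URI
	localStorage: store,
});

// Send the user to `loginUrl`; once the provider redirects back:
const result = await oauthHandleRedirect({
	redirectedUrl: "http://localhost:3000/?code=...&state=...", // the URL the provider redirected to
	nonce: store.nonce,
	codeVerifier: store.codeVerifier,
});

console.log(result.userInfo.preferred_username, result.accessTokenExpiresAt);
```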
lib/parse-safetensors-metadata.spec.ts
ADDED
@@ -0,0 +1,122 @@
import { assert, it, describe } from "vitest";
import { parseSafetensorsMetadata, parseSafetensorsShardFilename } from "./parse-safetensors-metadata";
import { sum } from "../utils/sum";

describe("parseSafetensorsMetadata", () => {
	it("fetch info for single-file (with the default conventional filename)", async () => {
		const parse = await parseSafetensorsMetadata({
			repo: "bert-base-uncased",
			computeParametersCount: true,
			revision: "86b5e0934494bd15c9632b12f734a8a67f723594",
		});

		assert(!parse.sharded);
		assert.deepStrictEqual(parse.header.__metadata__, { format: "pt" });

		// Example of one tensor (the header contains many tensors)

		assert.deepStrictEqual(parse.header["bert.embeddings.LayerNorm.beta"], {
			dtype: "F32",
			shape: [768],
			data_offsets: [0, 3072],
		});

		assert.deepStrictEqual(parse.parameterCount, { F32: 110_106_428 });
		assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 110_106_428);
		// total params = 110m
	});

	it("fetch info for sharded (with the default conventional filename)", async () => {
		const parse = await parseSafetensorsMetadata({
			repo: "bigscience/bloom",
			computeParametersCount: true,
			revision: "053d9cd9fbe814e091294f67fcfedb3397b954bb",
		});

		assert(parse.sharded);

		assert.strictEqual(Object.keys(parse.headers).length, 72);
		// This model has 72 shards!

		// Example of one tensor inside one file

		assert.deepStrictEqual(parse.headers["model_00012-of-00072.safetensors"]["h.10.input_layernorm.weight"], {
			dtype: "BF16",
			shape: [14336],
			data_offsets: [3288649728, 3288678400],
		});

		assert.deepStrictEqual(parse.parameterCount, { BF16: 176_247_271_424 });
		assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 176_247_271_424);
		// total params = 176B
	});

	it("fetch info for single-file with multiple dtypes", async () => {
		const parse = await parseSafetensorsMetadata({
			repo: "roberta-base",
			computeParametersCount: true,
			revision: "e2da8e2f811d1448a5b465c236feacd80ffbac7b",
		});

		assert(!parse.sharded);

		assert.deepStrictEqual(parse.parameterCount, { F32: 124_697_433, I64: 514 });
		assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 124_697_947);
		// total params = 124m
	});

	it("fetch info for single-file with file path", async () => {
		const parse = await parseSafetensorsMetadata({
			repo: "CompVis/stable-diffusion-v1-4",
			computeParametersCount: true,
			path: "unet/diffusion_pytorch_model.safetensors",
			revision: "133a221b8aa7292a167afc5127cb63fb5005638b",
		});

		assert(!parse.sharded);
		assert.deepStrictEqual(parse.header.__metadata__, { format: "pt" });

		// Example of one tensor (the header contains many tensors)

		assert.deepStrictEqual(parse.header["up_blocks.3.resnets.0.norm2.bias"], {
			dtype: "F32",
			shape: [320],
			data_offsets: [3_409_382_416, 3_409_383_696],
		});

		assert.deepStrictEqual(parse.parameterCount, { F32: 859_520_964 });
		assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 859_520_964);
	});

	it("fetch info for sharded (with the default conventional filename) with file path", async () => {
		const parse = await parseSafetensorsMetadata({
			repo: "Alignment-Lab-AI/ALAI-gemma-7b",
			computeParametersCount: true,
			path: "7b/1/model.safetensors.index.json",
			revision: "37e307261fe97bbf8b2463d61dbdd1a10daa264c",
		});

		assert(parse.sharded);

		assert.strictEqual(Object.keys(parse.headers).length, 4);

		assert.deepStrictEqual(parse.headers["model-00004-of-00004.safetensors"]["model.layers.24.mlp.up_proj.weight"], {
			dtype: "BF16",
			shape: [24576, 3072],
			data_offsets: [301996032, 452990976],
		});

		assert.deepStrictEqual(parse.parameterCount, { BF16: 8_537_680_896 });
		assert.deepStrictEqual(sum(Object.values(parse.parameterCount)), 8_537_680_896);
	});

	it("should detect sharded safetensors filename", async () => {
		const safetensorsFilename = "model_00005-of-00072.safetensors"; // https://huggingface.co/bigscience/bloom/blob/4d8e28c67403974b0f17a4ac5992e4ba0b0dbb6f/model_00005-of-00072.safetensors
		const safetensorsShardFileInfo = parseSafetensorsShardFilename(safetensorsFilename);

		assert.strictEqual(safetensorsShardFileInfo?.prefix, "model_");
		assert.strictEqual(safetensorsShardFileInfo?.basePrefix, "model");
		assert.strictEqual(safetensorsShardFileInfo?.shard, "00005");
		assert.strictEqual(safetensorsShardFileInfo?.total, "00072");
	});
});
lib/parse-safetensors-metadata.ts
ADDED
@@ -0,0 +1,274 @@
import type { CredentialsParams, RepoDesignation } from "../types/public";
import { omit } from "../utils/omit";
import { toRepoId } from "../utils/toRepoId";
import { typedEntries } from "../utils/typedEntries";
import { downloadFile } from "./download-file";
import { fileExists } from "./file-exists";
import { promisesQueue } from "../utils/promisesQueue";
import type { SetRequired } from "../vendor/type-fest/set-required";

export const SAFETENSORS_FILE = "model.safetensors";
export const SAFETENSORS_INDEX_FILE = "model.safetensors.index.json";
/// We advise model/library authors to use the filenames above for convention inside model repos,
/// but in some situations safetensors weights have different filenames.
export const RE_SAFETENSORS_FILE = /\.safetensors$/;
export const RE_SAFETENSORS_INDEX_FILE = /\.safetensors\.index\.json$/;
export const RE_SAFETENSORS_SHARD_FILE =
	/^(?<prefix>(?<basePrefix>.*?)[_-])(?<shard>\d{5})-of-(?<total>\d{5})\.safetensors$/;
export interface SafetensorsShardFileInfo {
	prefix: string;
	basePrefix: string;
	shard: string;
	total: string;
}
export function parseSafetensorsShardFilename(filename: string): SafetensorsShardFileInfo | null {
	const match = RE_SAFETENSORS_SHARD_FILE.exec(filename);
	if (match && match.groups) {
		return {
			prefix: match.groups["prefix"],
			basePrefix: match.groups["basePrefix"],
			shard: match.groups["shard"],
			total: match.groups["total"],
		};
	}
	return null;
}

const PARALLEL_DOWNLOADS = 20;
const MAX_HEADER_LENGTH = 25_000_000;

class SafetensorParseError extends Error {}

type FileName = string;

export type TensorName = string;
export type Dtype = "F64" | "F32" | "F16" | "BF16" | "I64" | "I32" | "I16" | "I8" | "U8" | "BOOL";

export interface TensorInfo {
	dtype: Dtype;
	shape: number[];
	data_offsets: [number, number];
}

export type SafetensorsFileHeader = Record<TensorName, TensorInfo> & {
	__metadata__: Record<string, string>;
};

export interface SafetensorsIndexJson {
	dtype?: string;
	/// ^there's sometimes a dtype but it looks inconsistent.
	metadata?: Record<string, string>;
	/// ^ why the naming inconsistency?
	weight_map: Record<TensorName, FileName>;
}

export type SafetensorsShardedHeaders = Record<FileName, SafetensorsFileHeader>;

export type SafetensorsParseFromRepo =
	| {
			sharded: false;
			header: SafetensorsFileHeader;
			parameterCount?: Partial<Record<Dtype, number>>;
	  }
	| {
			sharded: true;
			index: SafetensorsIndexJson;
			headers: SafetensorsShardedHeaders;
			parameterCount?: Partial<Record<Dtype, number>>;
	  };

async function parseSingleFile(
	path: string,
	params: {
		repo: RepoDesignation;
		revision?: string;
		hubUrl?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SafetensorsFileHeader> {
	const blob = await downloadFile({ ...params, path });

	if (!blob) {
		throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors header length.`);
	}

	const bufLengthOfHeaderLE = await blob.slice(0, 8).arrayBuffer();
	const lengthOfHeader = new DataView(bufLengthOfHeaderLE).getBigUint64(0, true);
	// ^little-endian
	if (lengthOfHeader <= 0) {
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is malformed.`);
	}
	if (lengthOfHeader > MAX_HEADER_LENGTH) {
		throw new SafetensorParseError(
			`Failed to parse file ${path}: safetensor header is too big. Maximum supported size is ${MAX_HEADER_LENGTH} bytes.`
		);
	}

	try {
		// no validation for now, we assume it's a valid FileHeader.
		const header: SafetensorsFileHeader = JSON.parse(await blob.slice(8, 8 + Number(lengthOfHeader)).text());
		return header;
	} catch (err) {
		throw new SafetensorParseError(`Failed to parse file ${path}: safetensors header is not valid JSON.`);
	}
}

async function parseShardedIndex(
	path: string,
	params: {
		repo: RepoDesignation;
		revision?: string;
		hubUrl?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<{ index: SafetensorsIndexJson; headers: SafetensorsShardedHeaders }> {
	const indexBlob = await downloadFile({
		...params,
		path,
	});

	if (!indexBlob) {
		throw new SafetensorParseError(`Failed to parse file ${path}: failed to fetch safetensors index.`);
	}

	// no validation for now, we assume it's a valid IndexJson.
	let index: SafetensorsIndexJson;
	try {
		index = JSON.parse(await indexBlob.slice(0, 10_000_000).text());
	} catch (error) {
		throw new SafetensorParseError(`Failed to parse file ${path}: not a valid JSON.`);
	}

	const pathPrefix = path.slice(0, path.lastIndexOf("/") + 1);
	const filenames = [...new Set(Object.values(index.weight_map))];
	const shardedMap: SafetensorsShardedHeaders = Object.fromEntries(
		await promisesQueue(
			filenames.map(
				(filename) => async () =>
					[filename, await parseSingleFile(pathPrefix + filename, params)] satisfies [string, SafetensorsFileHeader]
			),
			PARALLEL_DOWNLOADS
		)
	);
	return { index, headers: shardedMap };
}

/**
 * Analyze model.safetensors.index.json or model.safetensors from a model hosted
 * on Hugging Face using smart range requests to extract its metadata.
 */
export async function parseSafetensorsMetadata(
	params: {
		/** Only models are supported */
		repo: RepoDesignation;
		/**
		 * Relative file path to safetensors file inside `repo`. Defaults to `SAFETENSORS_FILE` or `SAFETENSORS_INDEX_FILE` (whichever one exists).
		 */
		path?: string;
		/**
		 * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
		 *
		 * @default false
		 */
		computeParametersCount: true;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SetRequired<SafetensorsParseFromRepo, "parameterCount">>;
export async function parseSafetensorsMetadata(
	params: {
		/** Only models are supported */
		repo: RepoDesignation;
		/**
		 * Relative file path to safetensors file inside `repo`. Defaults to `SAFETENSORS_FILE` or `SAFETENSORS_INDEX_FILE` (whichever one exists).
		 */
		path?: string;
		/**
		 * Will include SafetensorsParseFromRepo["parameterCount"], an object containing the number of parameters for each DType
		 *
		 * @default false
		 */
		computeParametersCount?: boolean;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SafetensorsParseFromRepo>;
export async function parseSafetensorsMetadata(
	params: {
		repo: RepoDesignation;
		path?: string;
		computeParametersCount?: boolean;
		hubUrl?: string;
		revision?: string;
		/**
		 * Custom fetch function to use instead of the default one, for example to use a proxy or edit headers.
		 */
		fetch?: typeof fetch;
	} & Partial<CredentialsParams>
): Promise<SafetensorsParseFromRepo> {
	const repoId = toRepoId(params.repo);

	if (repoId.type !== "model") {
		throw new TypeError("Only model repos should contain safetensors files.");
	}

	if (RE_SAFETENSORS_FILE.test(params.path ?? "") || (await fileExists({ ...params, path: SAFETENSORS_FILE }))) {
		const header = await parseSingleFile(params.path ?? SAFETENSORS_FILE, params);
		return {
			sharded: false,
			header,
			...(params.computeParametersCount && {
				parameterCount: computeNumOfParamsByDtypeSingleFile(header),
			}),
		};
	} else if (
		RE_SAFETENSORS_INDEX_FILE.test(params.path ?? "") ||
		(await fileExists({ ...params, path: SAFETENSORS_INDEX_FILE }))
	) {
		const { index, headers } = await parseShardedIndex(params.path ?? SAFETENSORS_INDEX_FILE, params);
		return {
			sharded: true,
			index,
			headers,
			...(params.computeParametersCount && {
				parameterCount: computeNumOfParamsByDtypeSharded(headers),
			}),
		};
	} else {
		throw new Error("model id does not seem to contain safetensors weights");
	}
}

function computeNumOfParamsByDtypeSingleFile(header: SafetensorsFileHeader): Partial<Record<Dtype, number>> {
	const counter: Partial<Record<Dtype, number>> = {};
	const tensors = omit(header, "__metadata__");

	for (const [, v] of typedEntries(tensors)) {
		if (v.shape.length === 0) {
			continue;
		}
		counter[v.dtype] = (counter[v.dtype] ?? 0) + v.shape.reduce((a, b) => a * b);
	}
	return counter;
}

function computeNumOfParamsByDtypeSharded(shardedMap: SafetensorsShardedHeaders): Partial<Record<Dtype, number>> {
	const counter: Partial<Record<Dtype, number>> = {};
	for (const header of Object.values(shardedMap)) {
		for (const [k, v] of typedEntries(computeNumOfParamsByDtypeSingleFile(header))) {
			counter[k] = (counter[k] ?? 0) + (v ?? 0);
		}
	}
	return counter;
}
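A usage sketch for `parseSafetensorsMetadata`, based on the sharded bloom case in the spec file above; only the safetensors header(s) are fetched via range requests, never the weights themselves.

```ts
import { parseSafetensorsMetadata } from "@huggingface/hub";

const parsed = await parseSafetensorsMetadata({
	repo: "bigscience/bloom",
	computeParametersCount: true,
});

// Sharded repos expose one header per shard; single-file repos expose `header`.
console.log(parsed.sharded ? `${Object.keys(parsed.headers).length} shards` : "single file");
console.log(parsed.parameterCount); // e.g. { BF16: 176_247_271_424 }
```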