Skip to content

Commit 4968f0f

Browse files
authored
Support loading local models using relative paths, absolute paths, and model directory (#1268)
* Differentiate between paths and ids for local loading * Add local path unit test * Prevent invalid network requests when using invalid model id
1 parent e584042 commit 4968f0f

File tree

2 files changed

+48
-15
lines changed

2 files changed

+48
-15
lines changed

src/utils/hub.js

+36-15
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,22 @@ function isValidUrl(string, protocols = null, validHosts = null) {
190190
return true;
191191
}
192192

193+
const REPO_ID_REGEX = /^(\b[\w\-.]+\b\/)?\b[\w\-.]{1,96}\b$/;
194+
195+
/**
196+
* Tests whether a string is a valid Hugging Face model ID or not.
197+
* Adapted from https://github.com/huggingface/huggingface_hub/blob/6378820ebb03f071988a96c7f3268f5bdf8f9449/src/huggingface_hub/utils/_validators.py#L119-L170
198+
*
199+
* @param {string} string The string to test
200+
* @returns {boolean} True if the string is a valid model ID, false otherwise.
201+
*/
202+
function isValidHfModelId(string) {
203+
if (!REPO_ID_REGEX.test(string)) return false;
204+
if (string.includes("..") || string.includes("--")) return false;
205+
if (string.endsWith(".git") || string.endsWith(".ipynb")) return false;
206+
return true;
207+
}
208+
193209
/**
194210
* Helper function to get a file, using either the Fetch API or FileSystem API.
195211
*
@@ -442,27 +458,28 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
442458
}
443459

444460
const revision = options.revision ?? 'main';
461+
const requestURL = pathJoin(path_or_repo_id, filename);
445462

446-
let requestURL = pathJoin(path_or_repo_id, filename);
447-
let cachePath = pathJoin(env.localModelPath, requestURL);
448-
449-
let localPath = requestURL;
450-
let remoteURL = pathJoin(
463+
const validModelId = isValidHfModelId(path_or_repo_id);
464+
const localPath = validModelId
465+
? pathJoin(env.localModelPath, requestURL)
466+
: requestURL;
467+
const remoteURL = pathJoin(
451468
env.remoteHost,
452469
env.remotePathTemplate
453470
.replaceAll('{model}', path_or_repo_id)
454471
.replaceAll('{revision}', encodeURIComponent(revision)),
455472
filename
456473
);
457474

458-
// Choose cache key for filesystem cache
459-
// When using the main revision (default), we use the request URL as the cache key.
460-
// If a specific revision is requested, we account for this in the cache key.
461-
let fsCacheKey = revision === 'main' ? requestURL : pathJoin(path_or_repo_id, revision, filename);
462-
463475
/** @type {string} */
464476
let cacheKey;
465-
let proposedCacheKey = cache instanceof FileCache ? fsCacheKey : remoteURL;
477+
const proposedCacheKey = cache instanceof FileCache
478+
// Choose cache key for filesystem cache
479+
// When using the main revision (default), we use the request URL as the cache key.
480+
// If a specific revision is requested, we account for this in the cache key.
481+
? revision === 'main' ? requestURL : pathJoin(path_or_repo_id, revision, filename)
482+
: remoteURL;
466483

467484
// Whether to cache the final response in the end.
468485
let toCacheResponse = false;
@@ -475,11 +492,10 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
475492
// 1. We first try to get from cache using the local path. In some environments (like deno),
476493
// non-URL cache keys are not allowed. In these cases, `response` will be undefined.
477494
// 2. If no response is found, we try to get from cache using the remote URL or file system cache.
478-
response = await tryCache(cache, cachePath, proposedCacheKey);
495+
response = await tryCache(cache, localPath, proposedCacheKey);
479496
}
480497

481498
const cacheHit = response !== undefined;
482-
483499
if (response === undefined) {
484500
// Caching not available, or file is not cached, so we perform the request
485501

@@ -497,9 +513,9 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
497513
console.warn(`Unable to load from local path "${localPath}": "${e}"`);
498514
}
499515
} else if (options.local_files_only) {
500-
throw new Error(`\`local_files_only=true\`, but attempted to load a remote file from: ${localPath}.`);
516+
throw new Error(`\`local_files_only=true\`, but attempted to load a remote file from: ${requestURL}.`);
501517
} else if (!env.allowRemoteModels) {
502-
throw new Error(`\`env.allowRemoteModels=false\`, but attempted to load a remote file from: ${localPath}.`);
518+
throw new Error(`\`env.allowRemoteModels=false\`, but attempted to load a remote file from: ${requestURL}.`);
503519
}
504520
}
505521

@@ -519,6 +535,11 @@ export async function getModelFile(path_or_repo_id, filename, fatal = true, opti
519535
return null;
520536
}
521537
}
538+
if (!validModelId) {
539+
// Before making any requests to the remote server, we check if the model ID is valid.
540+
// This prevents unnecessary network requests for invalid model IDs.
541+
throw Error(`Local file missing at "${localPath}" and download aborted due to invalid model ID "${path_or_repo_id}".`);
542+
}
522543

523544
// File not found locally, so we try to download it from the remote server
524545
response = await getFile(remoteURL);

tests/utils/hub.test.js

+12
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import { AutoModel, PreTrainedModel } from "../../src/models.js";
22

33
import { MAX_TEST_EXECUTION_TIME, DEFAULT_MODEL_OPTIONS } from "../init.js";
4+
import fs from "fs";
45

56
// TODO: Set cache folder to a temp directory
67

@@ -36,5 +37,16 @@ describe("Hub", () => {
3637
},
3738
MAX_TEST_EXECUTION_TIME,
3839
);
40+
41+
const localPath = "./models/hf-internal-testing/tiny-random-T5ForConditionalGeneration";
42+
(fs.existsSync(localPath) ? it : it.skip)(
43+
"should load a model from a local path",
44+
async () => {
45+
// 4. Ensure we can load a model from a local path
46+
const model = await AutoModel.from_pretrained(localPath, DEFAULT_MODEL_OPTIONS);
47+
expect(model).toBeInstanceOf(PreTrainedModel);
48+
},
49+
MAX_TEST_EXECUTION_TIME,
50+
);
3951
});
4052
});

0 commit comments

Comments
 (0)