andstor committed on
Commit f8a5ad5 · verified · 1 Parent(s): 6dedfbb

Only do model name translation for Llama-2 and CodeLlama

Files changed (1): src/model_utils.py (+4 −4)
src/model_utils.py CHANGED
@@ -30,8 +30,8 @@ def extract_from_url(name: str):
     return path[1:]
 
 
-def translate_llama2(text):
-    "Translates llama-2 to its hf counterpart"
+def translate_llama(text):
+    "Translates Llama-2 and CodeLlama to its hf counterpart"
     if not text.endswith("-hf"):
         return text + "-hf"
     return text
@@ -39,8 +39,8 @@ def translate_llama2(text):
 
 def get_model(model_name: str, library: str, access_token: str):
     "Finds and grabs model from the Hub, and initializes on `meta`"
-    if "meta-llama" in model_name:
-        model_name = translate_llama2(model_name)
+    if "meta-llama/Llama-2-" in model_name or "meta-llama/CodeLlama-" in model_name:
+        model_name = translate_llama(model_name)
     if library == "auto":
         library = None
     model_name = extract_from_url(model_name)
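
For reference, a minimal sketch (not part of the commit) of how the updated check and translate_llama behave; the helper below mirrors the patched code, and the model IDs are illustrative examples only:

# Minimal sketch, assuming the patched translate_llama and the new check in get_model.
def translate_llama(text):
    "Translates Llama-2 and CodeLlama to its hf counterpart"
    if not text.endswith("-hf"):
        return text + "-hf"
    return text

examples = [
    "meta-llama/Llama-2-7b",     # matched  -> "meta-llama/Llama-2-7b-hf"
    "meta-llama/CodeLlama-7b",   # matched  -> "meta-llama/CodeLlama-7b-hf"
    "meta-llama/Llama-2-7b-hf",  # matched, already ends in "-hf" -> unchanged
    "gpt2",                      # not matched by the new check -> unchanged
]

for model_name in examples:
    # Same condition as the patched get_model: only Llama-2 and CodeLlama IDs are translated.
    if "meta-llama/Llama-2-" in model_name or "meta-llama/CodeLlama-" in model_name:
        model_name = translate_llama(model_name)
    print(model_name)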