File size: 2,501 Bytes
3440f83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
import json
import os
import urllib
from tqdm import tqdm

from vlmo.config import config, _loss_names  # noqa
from vlmo.modules import VLMo
from vlmo.transforms import keys_to_transforms

def _download(url: str, root: str):
    """Download ``url`` into ``root`` unless a file of that name already exists.

    The cache file name is ``os.path.basename(url)``. An existing regular file
    with that name is treated as a completed download and returned as-is.

    Args:
        url: Location of the file to fetch.
        root: Cache directory; created if it does not exist.

    Returns:
        Path of the downloaded (or previously cached) file.

    Raises:
        RuntimeError: if the target path exists but is not a regular file.
    """
    # `import urllib` alone does not load the `urllib.request` submodule, so
    # `urllib.request.urlopen` could raise AttributeError unless some other
    # module happened to import it first.
    import urllib.request

    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    download_target = os.path.join(root, filename)

    if os.path.exists(download_target) and not os.path.isfile(download_target):
        raise RuntimeError(f"{download_target} exists and is not a regular file")

    if os.path.isfile(download_target):
        return download_target

    # Stream into a side file and rename only on success, so an interrupted
    # download is never mistaken for a complete cached file on the next call.
    partial_target = download_target + ".part"
    try:
        with urllib.request.urlopen(url) as source, open(partial_target, "wb") as output:
            # Content-Length is optional (e.g. chunked responses); tqdm accepts
            # total=None and then shows a running byte count.
            content_length = source.info().get("Content-Length")
            total = int(content_length) if content_length is not None else None
            with tqdm(
                total=total, ncols=80, unit="iB", unit_scale=True, unit_divisor=1024
            ) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break
                    output.write(buffer)
                    loop.update(len(buffer))
        os.replace(partial_target, download_target)
    except BaseException:
        # Clean up the partial file on any failure (including Ctrl-C), then re-raise.
        try:
            os.remove(partial_target)
        except OSError:
            pass
        raise

    return download_target


def config_setting(custom_config: dict):
    """Build the default config and overlay ``custom_config`` on top of it.

    Args:
        custom_config: Key/value overrides; each entry is assigned into the
            default config, replacing any existing value for that key.

    Returns:
        The config object produced by the imported ``config()`` with the
        overrides applied (presumably a dict-like mapping — confirm against
        ``vlmo.config``).
    """
    # Call the imported ``config`` directly; the previous ``eval("config")``
    # performed exactly the same module-global lookup, just via eval.
    cfg = config()
    # Item-by-item assignment (rather than ``update``) keeps any custom
    # ``__setitem__`` behaviour of the returned config object.
    for key, value in custom_config.items():
        cfg[key] = value
    return cfg


def load_from_config(model_config):
    """Build a VLMo model and its text/image preprocessors from a config.

    Args:
        model_config: Either a path to a JSON file holding an object, or a
            dict. A dict argument is copied and never mutated. Checkpoint
            source keys, in priority order:

            - ``model_path``: local checkpoint, used if the path exists.
            - ``model_url``: downloaded to ``~/.cache/m2_encoder``.
            - ``modelscope`` + ``model_file``: keyword arguments for
              ``modelscope.snapshot_download`` and the checkpoint file name
              inside the downloaded directory; used when neither of the
              above applies.

            ``model_path`` and ``model_url`` are always removed. ``modelscope``
            and ``model_file`` are removed only when that branch is taken.
            All remaining keys override the defaults via ``config_setting``.

    Returns:
        ``(model, [txt_processor, img_processor])``.

    Raises:
        TypeError: if ``model_config`` is neither a str nor a dict, or the
            JSON file does not contain an object.
        ValueError: if no checkpoint source can be resolved.
    """
    if isinstance(model_config, str):
        with open(model_config, "r", encoding="utf-8") as f:
            model_config = json.load(f)
        if not isinstance(model_config, dict):
            raise TypeError(
                f"model config JSON must be an object, got {type(model_config).__name__}"
            )
    elif isinstance(model_config, dict):
        # Work on a copy: the pops below would otherwise strip keys from the
        # caller's dict and break a second call with the same config.
        model_config = dict(model_config)
    else:
        raise TypeError(
            "model_config must be a JSON file path or a dict, "
            f"got {type(model_config).__name__}"
        )

    model_url = model_config.pop("model_url", None)
    model_path = model_config.pop("model_path", None)
    if model_path and os.path.exists(model_path):
        load_path = model_path
    elif model_url:
        load_path = _download(model_url, os.path.expanduser("~/.cache/m2_encoder"))
    else:
        modelscope_cfg = model_config.pop("modelscope", None)
        model_file = model_config.pop("model_file", None)
        # Validate before importing/downloading so a bad config fails fast
        # with a clear message instead of `**None` TypeError or a late KeyError.
        if modelscope_cfg is None or model_file is None:
            raise ValueError(
                "could not resolve a checkpoint: provide an existing 'model_path' "
                f"(got {model_path!r}), a 'model_url', or both 'modelscope' and 'model_file'"
            )
        from modelscope import snapshot_download
        model_dir = snapshot_download(**modelscope_cfg)
        load_path = os.path.join(model_dir, model_file)

    cfg = config_setting(model_config)
    cfg["load_path"] = load_path

    if cfg["flash_attn"]:
        # Patch is applied before the model is constructed.
        from vlmo.utils.patch_utils import patch_torch_scale_with_flash_attn
        patch_torch_scale_with_flash_attn()

    model = VLMo(cfg)

    from vlmo.modules.vlmo_module import get_pretrained_tokenizer
    txt_processor = get_pretrained_tokenizer(cfg["tokenizer_type"], from_pretrained=cfg["tokenizer"])
    # keys_to_transforms returns a sequence; only the first transform is used.
    img_processor = keys_to_transforms(cfg["val_transform_keys"], size=cfg["image_size"])[0]

    return model, [txt_processor, img_processor]