File size: 2,892 Bytes
b4959be
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-23T04:19:23.488642Z",
     "start_time": "2021-07-23T04:19:21.854534Z"
    }
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import torch.backends.cudnn as cudnn\n",
    "import yaml\n",
    "from train import train\n",
    "from utils import AttrDict\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-23T04:19:23.885144Z",
     "start_time": "2021-07-23T04:19:23.880564Z"
    },
    "code_folding": []
   },
   "outputs": [],
   "source": [
    "# cuDNN flags for this session: benchmark mode on, deterministic mode off.\n",
    "cudnn.benchmark = True\n",
    "cudnn.deterministic = False"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-23T04:19:24.119144Z",
     "start_time": "2021-07-23T04:19:24.112032Z"
    },
    "code_folding": [
     0
    ]
   },
   "outputs": [],
   "source": [
    "def get_config(file_path):\n",
    "    \"\"\"Load a YAML training config and derive its character set.\n",
    "\n",
    "    Parameters\n",
    "    ----------\n",
    "    file_path : str\n",
    "        Path to the YAML config file; it is read as UTF-8.\n",
    "\n",
    "    Returns\n",
    "    -------\n",
    "    AttrDict\n",
    "        The parsed config with `character` set. When `lang_char` is the\n",
    "        string 'None', `character` is the sorted set of characters found in\n",
    "        the 'words' column of each selected dataset's labels.csv; otherwise\n",
    "        it is `number + symbol + lang_char`.\n",
    "\n",
    "    Side effect: creates ./saved_models/<experiment_name> if it is missing.\n",
    "    \"\"\"\n",
    "    with open(file_path, 'r', encoding=\"utf8\") as stream:\n",
    "        opt = AttrDict(yaml.safe_load(stream))\n",
    "\n",
    "    if opt.lang_char == 'None':\n",
    "        # No explicit alphabet configured: collect it from the training labels.\n",
    "        charset = set()\n",
    "        for data in opt['select_data'].split('-'):\n",
    "            csv_path = os.path.join(opt['train_data'], data, 'labels.csv')\n",
    "            # The regex separator matches only the leading field and its comma,\n",
    "            # so commas inside the label text stay part of 'words'.\n",
    "            # dtype=str keeps labels verbatim: without it an all-numeric column\n",
    "            # is inferred as integers and ''.join raises TypeError.\n",
    "            df = pd.read_csv(csv_path, sep='^([^,]+),', engine='python',\n",
    "                             usecols=['filename', 'words'], keep_default_na=False,\n",
    "                             dtype={'words': str})\n",
    "            charset.update(''.join(df['words']))\n",
    "        opt.character = ''.join(sorted(charset))\n",
    "    else:\n",
    "        opt.character = opt.number + opt.symbol + opt.lang_char\n",
    "    os.makedirs(f'./saved_models/{opt.experiment_name}', exist_ok=True)\n",
    "    return opt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2021-07-23T04:49:07.045060Z",
     "start_time": "2021-07-23T04:20:15.050992Z"
    }
   },
   "outputs": [],
   "source": [
    "# Build the options from the en_filtered config file, then start training with amp=False.\n",
    "opt = get_config(\"config_files/en_filtered_config.yaml\")\n",
    "train(opt, amp=False)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}