app.py
CHANGED
@@ -22,22 +22,25 @@ INFERENCE_API = "api"
 # Model definitions
 TEXT_GENERATION_MODELS = [
     {
-        "name": "
-        "description": "
+        "name": "Llama-2-7b-chat-hf",
+        "description": "Llama-2-7b-chat-hf",
+        "chat_model": True,
         "type": INFERENCE_API,
-        "model_id": "
+        "model_id": "meta-llama/Llama-2-7b-chat-hf"
     },
     {
-        "name": "
-        "description": "
-        "
-        "
+        "name": "TinyLlama-1.1B-Chat-v1.0",
+        "description": "TinyLlama-1.1B-Chat-v1.0",
+        "chat_model": True,
+        "type": INFERENCE_API,
+        "model_id": "tinyllama/TinyLlama-1.1B-Chat-v1.0"
     },
     {
-        "name": "
-        "description": "
+        "name": "TinyLlama_v1.1_math_code",
+        "description": "TinyLlama_v1.1_math_code",
+        "chat_model": False,
         "type": LOCAL,
-        "model_path": "
+        "model_path": "TinyLlama/TinyLlama_v1.1_math_code"
     }
 ]

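Note: each entry's required key now depends on its `type` (`model_id` for the Inference API, `model_path` for local models), and the dispatch code later indexes `model["chat_model"]` unconditionally. A startup sanity check can fail fast on a malformed entry. The sketch below is hypothetical, not part of this diff, and assumes `LOCAL == "local"` to mirror `INFERENCE_API == "api"`:

```python
REQUIRED_BY_TYPE = {"api": "model_id", "local": "model_path"}  # assumes LOCAL == "local"

def validate_models(models):
    """Hypothetical fail-fast check for the keys handle_invoke will index."""
    for m in models:
        missing = [k for k in ("name", "type", "chat_model") if k not in m]
        type_key = REQUIRED_BY_TYPE.get(m.get("type"))
        if type_key and type_key not in m:
            missing.append(type_key)
        if missing:
            raise ValueError(f"Model entry {m.get('name', '<unnamed>')} missing keys: {missing}")

validate_models(TEXT_GENERATION_MODELS)
```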
@@ -111,7 +114,7 @@ def preload_local_models():
         logger.error(f"Error preloading model {model_path}: {str(e)}")

 @spaces.GPU
-def generate_text_local(model_path, text):
+def generate_text_local(model_path, chat_model, text):
    """Text generation with a local model"""
    try:
        logger.info(f"Running local text generation with {model_path}")
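Note: `@spaces.GPU` is the ZeroGPU decorator; hardware is attached only while the decorated call runs, which is why the function moves the pipeline back to the CPU before returning (next hunk). A minimal sketch of that pattern, illustrative rather than this app's exact code:

```python
import spaces
import torch

@spaces.GPU(duration=60)  # ZeroGPU: attach a GPU for up to 60 s per call
def run_on_gpu(model, batch):
    model.to("cuda")
    try:
        with torch.no_grad():
            return model(**{k: v.to("cuda") for k, v in batch.items()})
    finally:
        model.to("cpu")  # free GPU memory before the allocation is released
```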
@@ -129,13 +132,20 @@ def generate_text_local(model_path, text):
         device_info = next(pipeline.model.parameters()).device
         logger.info(f"Model {model_path} is running on device: {device_info}")

-
-
-
-
-
-
-
+        if chat_model:
+            outputs = pipeline(
+                [{"role": "user", "content": text}],
+                max_new_tokens=40,
+                do_sample=False,
+                num_return_sequences=1
+            )
+        else:
+            outputs = pipeline(
+                text,
+                max_new_tokens=40,
+                do_sample=False,
+                num_return_sequences=1
+            )
         # Move the model back to the CPU
         pipeline.model = pipeline.model.to("cpu")
         if hasattr(pipeline, "device"):
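Note: the chat branch relies on recent `transformers` (roughly 4.42+), where a text-generation pipeline accepts an OpenAI-style message list and applies the tokenizer's chat template; base models still take a raw prompt string. A self-contained sketch of both call shapes, with the model choice purely illustrative:

```python
from transformers import pipeline

pipe = pipeline("text-generation", model="TinyLlama/TinyLlama-1.1B-Chat-v1.0")

# Chat-tuned model: pass messages; the reply is the last message in
# the returned conversation.
chat_out = pipe(
    [{"role": "user", "content": "Explain dropout in one sentence."}],
    max_new_tokens=40,
    do_sample=False,
)
print(chat_out[0]["generated_text"][-1]["content"])

# Base-model style: pass a plain prompt string.
base_out = pipe("def fibonacci(n):", max_new_tokens=40, do_sample=False)
print(base_out[0]["generated_text"])
```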
@@ -146,15 +156,22 @@ def generate_text_local(model_path, text):
         logger.error(f"Error in local text generation with {model_path}: {str(e)}")
         return f"Error: {str(e)}"

-def generate_text_api(model_id, text):
+def generate_text_api(model_id, chat_model, text):
     """Text generation via API"""
     try:
         logger.info(f"Running API text generation with {model_id}")
-
-
-
-
-
+        if chat_model:
+            response = api_clients[model_id].chat.completions.create(
+                messages=[{"role": "user", "content": text}],
+                max_tokens=40,
+                temperature=0.7
+            )
+            response = response.choices[0].message.content
+        else:
+            response = api_clients[model_id].text_generation(
+                text,
+                max_new_tokens=40,
+                temperature=0.7)
         return response
     except Exception as e:
         logger.error(f"Error in API text generation with {model_id}: {str(e)}")
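Note: both branches end by returning a plain string; the chat route unwraps `choices[0].message.content`, while `text_generation` already returns one. The calls imply `api_clients[model_id]` is a `huggingface_hub.InferenceClient`, which exposes both routes; a sketch under that assumption (the token is a placeholder):

```python
from huggingface_hub import InferenceClient

client = InferenceClient("meta-llama/Llama-2-7b-chat-hf", token="hf_...")  # placeholder token

# Chat-tuned model: OpenAI-compatible chat-completion route.
chat = client.chat.completions.create(
    messages=[{"role": "user", "content": "Hello!"}],
    max_tokens=40,
    temperature=0.7,
)
print(chat.choices[0].message.content)

# Base model: raw text-generation route returns the string directly.
print(client.text_generation("Once upon a time", max_new_tokens=40, temperature=0.7))
```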
@@ -214,11 +231,11 @@ def handle_invoke(text, selected_types):
     for model in TEXT_GENERATION_MODELS:
         if model["type"] in selected_types:
             if model["type"] == LOCAL:
-                future = executor.submit(generate_text_local, model["model_path"], text)
+                future = executor.submit(generate_text_local, model["model_path"], model["chat_model"], text)
                 futures.append(future)
                 futures_to_model[future] = model
             else: # api
-                future = executor.submit(generate_text_api, model["model_id"], text)
+                future = executor.submit(generate_text_api, model["model_id"], model["chat_model"], text)
                 futures.append(future)
                 futures_to_model[future] = model

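Note: `handle_invoke` fans each selected model out to a worker thread, and `futures_to_model` lets results be attributed to the model that produced them as they complete. A compressed, self-contained sketch of that dispatch pattern (helper names simplified; assumes `LOCAL == "local"`):

```python
from concurrent.futures import ThreadPoolExecutor, as_completed

def fan_out(text, models, run_local, run_api):
    """Run every model concurrently; return {model name: generated text}."""
    with ThreadPoolExecutor() as executor:
        futures_to_model = {}
        for model in models:
            if model["type"] == "local":
                fut = executor.submit(run_local, model["model_path"], model["chat_model"], text)
            else:  # api
                fut = executor.submit(run_api, model["model_id"], model["chat_model"], text)
            futures_to_model[fut] = model
        # Collect in completion order so one slow model doesn't block the rest.
        return {futures_to_model[f]["name"]: f.result() for f in as_completed(futures_to_model)}
```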