nyasukun committed
Commit 93a0da6 · 1 Parent(s): f5778e9
Files changed (1): app.py (+43, -26)
app.py CHANGED
@@ -22,22 +22,25 @@ INFERENCE_API = "api"
 # Model definitions
 TEXT_GENERATION_MODELS = [
     {
-        "name": "Zephyr-7B",
-        "description": "Specialized in understanding context and nuance",
+        "name": "Llama-2-7b-chat-hf",
+        "description": "Llama-2-7b-chat-hf",
+        "chat_model": True,
         "type": INFERENCE_API,
-        "model_id": "HuggingFaceH4/zephyr-7b-beta"
+        "model_id": "meta-llama/Llama-2-7b-chat-hf"
     },
     {
-        "name": "Llama-2",
-        "description": "Known for its robust performance in content analysis",
-        "type": LOCAL,
-        "model_path": "meta-llama/Llama-2-7b-hf"
+        "name": "TinyLlama-1.1B-Chat-v1.0",
+        "description": "TinyLlama-1.1B-Chat-v1.0",
+        "chat_model": True,
+        "type": INFERENCE_API,
+        "model_id": "tinyllama/TinyLlama-1.1B-Chat-v1.0"
     },
     {
-        "name": "Mistral-7B",
-        "description": "Offers precise and detailed text evaluation",
+        "name": "TinyLlama_v1.1_math_code",
+        "description": "TinyLlama_v1.1_math_code",
+        "chat_model": False,
         "type": LOCAL,
-        "model_path": "mistralai/Mistral-7B-v0.1"
+        "model_path": "TinyLlama/TinyLlama_v1.1_math_code"
     }
 ]

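The registry entries now carry a `chat_model` flag alongside the existing `type` field, so downstream code can choose the right call style per model. A minimal sketch of how an entry is consumed, assuming LOCAL is a string constant like INFERENCE_API (only INFERENCE_API = "api" is visible in this hunk; the describe helper is hypothetical):

LOCAL = "local"  # assumed value; not shown in this hunk
INFERENCE_API = "api"

def describe(model: dict) -> str:
    # LOCAL entries carry "model_path"; API entries carry "model_id".
    ref = model["model_path"] if model["type"] == LOCAL else model["model_id"]
    style = "chat" if model["chat_model"] else "plain completion"
    return f'{model["name"]}: {style} model, served {model["type"]} ({ref})'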
@@ -111,7 +114,7 @@ def preload_local_models():
         logger.error(f"Error preloading model {model_path}: {str(e)}")

 @spaces.GPU
-def generate_text_local(model_path, text):
+def generate_text_local(model_path, chat_model, text):
     """Text generation with a local model"""
     try:
         logger.info(f"Running local text generation with {model_path}")
@@ -129,13 +132,20 @@ def generate_text_local(model_path, text):
         device_info = next(pipeline.model.parameters()).device
         logger.info(f"Model {model_path} is running on device: {device_info}")

-        outputs = pipeline(
-            text,
-            max_new_tokens=40,
-            do_sample=False,
-            num_return_sequences=1
-        )
-
+        if chat_model:
+            outputs = pipeline(
+                [{"role": "user", "content": text}],
+                max_new_tokens=40,
+                do_sample=False,
+                num_return_sequences=1
+            )
+        else:
+            outputs = pipeline(
+                text,
+                max_new_tokens=40,
+                do_sample=False,
+                num_return_sequences=1
+            )
         # Move the model back to the CPU
         pipeline.model = pipeline.model.to("cpu")
         if hasattr(pipeline, "device"):
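The branch works because recent transformers text-generation pipelines accept either a plain prompt string or a chat-format message list (for tokenizers that define a chat template). A standalone sketch using model ids from this commit; the structure of the returned outputs differs between the two input forms:

from transformers import pipeline

pipe = pipeline("text-generation", model="tinyllama/TinyLlama-1.1B-Chat-v1.0")

# Chat models: the tokenizer's chat template is applied to the messages.
chat_out = pipe([{"role": "user", "content": "Hello!"}],
                max_new_tokens=40, do_sample=False, num_return_sequences=1)

# Base models (e.g. TinyLlama_v1.1_math_code): the prompt is tokenized as-is.
plain_out = pipe("def fibonacci(n):",
                 max_new_tokens=40, do_sample=False, num_return_sequences=1)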
@@ -146,15 +156,22 @@ def generate_text_local(model_path, text):
         logger.error(f"Error in local text generation with {model_path}: {str(e)}")
         return f"Error: {str(e)}"

-def generate_text_api(model_id, text):
+def generate_text_api(model_id, chat_model, text):
     """Text generation via the API"""
     try:
         logger.info(f"Running API text generation with {model_id}")
-        response = api_clients[model_id].text_generation(
-            text,
-            max_new_tokens=40,
-            temperature=0.7
-        )
+        if chat_model:
+            response = api_clients[model_id].chat.completions.create(
+                messages=[{"role": "user", "content": text}],
+                max_tokens=40,
+                temperature=0.7
+            )
+            response = response.choices[0].message.content
+        else:
+            response = api_clients[model_id].text_generation(
+                text,
+                max_new_tokens=40,
+                temperature=0.7)
         return response
     except Exception as e:
         logger.error(f"Error in API text generation with {model_id}: {str(e)}")
@@ -214,11 +231,11 @@ def handle_invoke(text, selected_types):
         for model in TEXT_GENERATION_MODELS:
             if model["type"] in selected_types:
                 if model["type"] == LOCAL:
-                    future = executor.submit(generate_text_local, model["model_path"], text)
+                    future = executor.submit(generate_text_local, model["model_path"], model["chat_model"], text)
                     futures.append(future)
                     futures_to_model[future] = model
                 else:  # api
-                    future = executor.submit(generate_text_api, model["model_id"], text)
+                    future = executor.submit(generate_text_api, model["model_id"], model["chat_model"], text)
                     futures.append(future)
                     futures_to_model[future] = model
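executor.submit forwards the extra model["chat_model"] positional argument straight to the worker functions. A sketch of the fan-out/fan-in around this loop, assuming results are collected with as_completed (the collection code sits outside this diff):

from concurrent.futures import ThreadPoolExecutor, as_completed

with ThreadPoolExecutor() as executor:
    futures_to_model = {}
    for model in TEXT_GENERATION_MODELS:
        if model["type"] == LOCAL:
            f = executor.submit(generate_text_local,
                                model["model_path"], model["chat_model"], text)
        else:  # api
            f = executor.submit(generate_text_api,
                                model["model_id"], model["chat_model"], text)
        futures_to_model[f] = model
    for future in as_completed(futures_to_model):
        print(futures_to_model[future]["name"], future.result())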
241