nyasukun commited on
Commit
2b0c283
Β·
1 Parent(s): 3ab7a95
Files changed (1) hide show
  1. app.py +18 -0
app.py CHANGED
@@ -117,6 +117,10 @@ def generate_text_local(model_path, text):
117
  logger.info(f"Running local text generation with {model_path}")
118
  pipeline = pipelines[model_path]
119
 
 
 
 
 
120
  # γƒ‡γƒγ‚€γ‚Ήζƒ…ε ±γ‚’γƒ­γ‚°γ«θ¨˜ιŒ²
121
  device_info = next(pipeline.model.parameters()).device
122
  logger.info(f"Model {model_path} is running on device: {device_info}")
@@ -127,6 +131,11 @@ def generate_text_local(model_path, text):
127
  do_sample=False,
128
  num_return_sequences=1
129
  )
 
 
 
 
 
130
  return outputs[0]["generated_text"]
131
  except Exception as e:
132
  logger.error(f"Error in local text generation with {model_path}: {str(e)}")
@@ -153,11 +162,20 @@ def classify_text_local(model_path, text):
153
  logger.info(f"Running local classification with {model_path}")
154
  pipeline = pipelines[model_path]
155
 
 
 
 
 
156
  # γƒ‡γƒγ‚€γ‚Ήζƒ…ε ±γ‚’γƒ­γ‚°γ«θ¨˜ιŒ²
157
  device_info = next(pipeline.model.parameters()).device
158
  logger.info(f"Model {model_path} is running on device: {device_info}")
159
 
160
  result = pipeline(text)
 
 
 
 
 
161
  return str(result)
162
  except Exception as e:
163
  logger.error(f"Error in local classification with {model_path}: {str(e)}")
 
117
  logger.info(f"Running local text generation with {model_path}")
118
  pipeline = pipelines[model_path]
119
 
120
+ # γƒ’γƒ‡γƒ«γ‚’ζ˜Žη€Ίηš„γ« GPU に移動
121
+ if hasattr(pipeline.model, "to"):
122
+ pipeline.model.to("cuda")
123
+
124
  # γƒ‡γƒγ‚€γ‚Ήζƒ…ε ±γ‚’γƒ­γ‚°γ«θ¨˜ιŒ²
125
  device_info = next(pipeline.model.parameters()).device
126
  logger.info(f"Model {model_path} is running on device: {device_info}")
 
131
  do_sample=False,
132
  num_return_sequences=1
133
  )
134
+
135
+ # ヒデルを CPU γ«ζˆ»γ™
136
+ if hasattr(pipeline.model, "to"):
137
+ pipeline.model.to("cpu")
138
+
139
  return outputs[0]["generated_text"]
140
  except Exception as e:
141
  logger.error(f"Error in local text generation with {model_path}: {str(e)}")
 
162
  logger.info(f"Running local classification with {model_path}")
163
  pipeline = pipelines[model_path]
164
 
165
+ # γƒ’γƒ‡γƒ«γ‚’ζ˜Žη€Ίηš„γ« GPU に移動
166
+ if hasattr(pipeline.model, "to"):
167
+ pipeline.model.to("cuda")
168
+
169
  # γƒ‡γƒγ‚€γ‚Ήζƒ…ε ±γ‚’γƒ­γ‚°γ«θ¨˜ιŒ²
170
  device_info = next(pipeline.model.parameters()).device
171
  logger.info(f"Model {model_path} is running on device: {device_info}")
172
 
173
  result = pipeline(text)
174
+
175
+ # ヒデルを CPU γ«ζˆ»γ™
176
+ if hasattr(pipeline.model, "to"):
177
+ pipeline.model.to("cpu")
178
+
179
  return str(result)
180
  except Exception as e:
181
  logger.error(f"Error in local classification with {model_path}: {str(e)}")