nyasukun commited on
Commit
3ab7a95
Β·
1 Parent(s): 6f6b422
Files changed (1) hide show
  1. app.py +10 -4
app.py CHANGED
@@ -116,14 +116,17 @@ def generate_text_local(model_path, text):
116
  try:
117
  logger.info(f"Running local text generation with {model_path}")
118
  pipeline = pipelines[model_path]
119
- pipeline.to("cuda")
 
 
 
 
120
  outputs = pipeline(
121
  text,
122
  max_new_tokens=40,
123
  do_sample=False,
124
  num_return_sequences=1
125
  )
126
- pipeline.to("cpu")
127
  return outputs[0]["generated_text"]
128
  except Exception as e:
129
  logger.error(f"Error in local text generation with {model_path}: {str(e)}")
@@ -149,9 +152,12 @@ def classify_text_local(model_path, text):
149
  try:
150
  logger.info(f"Running local classification with {model_path}")
151
  pipeline = pipelines[model_path]
152
- pipeline.to("cuda")
 
 
 
 
153
  result = pipeline(text)
154
- pipeline.to("cpu")
155
  return str(result)
156
  except Exception as e:
157
  logger.error(f"Error in local classification with {model_path}: {str(e)}")
 
116
  try:
117
  logger.info(f"Running local text generation with {model_path}")
118
  pipeline = pipelines[model_path]
119
+
120
+ # γƒ‡γƒγ‚€γ‚Ήζƒ…ε ±γ‚’γƒ­γ‚°γ«θ¨˜ιŒ²
121
+ device_info = next(pipeline.model.parameters()).device
122
+ logger.info(f"Model {model_path} is running on device: {device_info}")
123
+
124
  outputs = pipeline(
125
  text,
126
  max_new_tokens=40,
127
  do_sample=False,
128
  num_return_sequences=1
129
  )
 
130
  return outputs[0]["generated_text"]
131
  except Exception as e:
132
  logger.error(f"Error in local text generation with {model_path}: {str(e)}")
 
152
  try:
153
  logger.info(f"Running local classification with {model_path}")
154
  pipeline = pipelines[model_path]
155
+
156
+ # γƒ‡γƒγ‚€γ‚Ήζƒ…ε ±γ‚’γƒ­γ‚°γ«θ¨˜ιŒ²
157
+ device_info = next(pipeline.model.parameters()).device
158
+ logger.info(f"Model {model_path} is running on device: {device_info}")
159
+
160
  result = pipeline(text)
 
161
  return str(result)
162
  except Exception as e:
163
  logger.error(f"Error in local classification with {model_path}: {str(e)}")