Athspi commited on
Commit
bef7195
·
verified ·
1 Parent(s): d4def9e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -15
app.py CHANGED
@@ -162,22 +162,32 @@ def convert_model_to_onnx():
162
  dummy_prompt = "tara: Hello"
163
  dummy_input = tokenizer(dummy_prompt, return_tensors="pt").input_ids.to(device)
164
  file_path = "orpheus_model.onnx"
 
 
 
 
 
 
 
 
 
 
 
 
165
  try:
166
- # Disable Torch Dynamo during export to avoid FX-tracing issues.
167
- with torch._dynamo.disable():
168
- torch.onnx.export(
169
- model,
170
- dummy_input,
171
- file_path,
172
- export_params=True,
173
- opset_version=14,
174
- input_names=["input_ids"],
175
- output_names=["logits"],
176
- dynamic_axes={
177
- "input_ids": {0: "batch_size", 1: "sequence_length"},
178
- "logits": {0: "batch_size", 1: "sequence_length"}
179
- },
180
- )
181
  return f"Model converted to ONNX and saved as '{file_path}'."
182
  except Exception as e:
183
  return f"Error during ONNX conversion: {e}"
 
162
  dummy_prompt = "tara: Hello"
163
  dummy_input = tokenizer(dummy_prompt, return_tensors="pt").input_ids.to(device)
164
  file_path = "orpheus_model.onnx"
165
+
166
+ # Ensure the model is in evaluation mode and not compiled
167
+ model.eval()
168
+
169
+ # Reset Torch Dynamo to avoid FX-tracing issues during export.
170
+ if hasattr(torch, "_dynamo"):
171
+ try:
172
+ torch._dynamo.reset()
173
+ print("Torch Dynamo reset before ONNX export")
174
+ except Exception as e:
175
+ print(f"Warning: Torch Dynamo reset failed - {e}")
176
+
177
  try:
178
+ torch.onnx.export(
179
+ model,
180
+ dummy_input,
181
+ file_path,
182
+ export_params=True,
183
+ opset_version=14,
184
+ input_names=["input_ids"],
185
+ output_names=["logits"],
186
+ dynamic_axes={
187
+ "input_ids": {0: "batch_size", 1: "sequence_length"},
188
+ "logits": {0: "batch_size", 1: "sequence_length"}
189
+ },
190
+ )
 
 
191
  return f"Model converted to ONNX and saved as '{file_path}'."
192
  except Exception as e:
193
  return f"Error during ONNX conversion: {e}"