Update sales_inference.py
sales_inference.py (+366 -72)
import os
import json
import numpy as np
import torch
import torch.nn as nn
from openai import AzureOpenAI  # Use the new AzureOpenAI client
from stable_baselines3 import PPO
from stable_baselines3.common.torch_layers import BaseFeaturesExtractor
import gymnasium as gym
from gymnasium import spaces
from dataclasses import dataclass
from typing import List, Dict, Any
import argparse
import logging

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    handlers=[logging.StreamHandler()]
)
logger = logging.getLogger(__name__)

# GPU Setup
if torch.cuda.is_available():
    device = torch.device("cuda")
    logger.info(f"Using GPU: {torch.cuda.get_device_name(0)}")
else:
    device = torch.device("cpu")
    logger.info("GPU not available, using CPU for inference")

# --- Replicated/Necessary Classes from train.py ---

@dataclass
class ConversationState:
    conversation_history: List[Dict[str, str]]
    embedding: np.ndarray
    conversation_metrics: Dict[str, float]
    turn_number: int
    conversion_probabilities: List[float]

    @property
    def state_vector(self) -> np.ndarray:
        metric_values = np.array(list(self.conversation_metrics.values()), dtype=np.float32)
        turn_info = np.array([self.turn_number], dtype=np.float32)
        padded_probs = np.zeros(10, dtype=np.float32)
        probs_to_pad = self.conversion_probabilities[-10:]
        padded_probs[:len(probs_to_pad)] = probs_to_pad

        return np.concatenate([
            self.embedding,
            metric_values,
            turn_info,
            padded_probs
        ]).astype(np.float32)

class CustomLN(BaseFeaturesExtractor):
    def __init__(self, observation_space: gym.spaces.Box, features_dim: int = 128):
        super().__init__(observation_space, features_dim)
        n_input_channels = observation_space.shape[0]
        self.linear_network = nn.Sequential(
            nn.Linear(n_input_channels, 512),
            nn.ReLU(),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, features_dim),
            nn.ReLU(),
        ).to(device)

    def forward(self, observations: torch.Tensor) -> torch.Tensor:
        return self.linear_network(observations)

# --- Azure OpenAI Embedding Function ---

def get_azure_openai_embedding(
    text: str,
    client: AzureOpenAI,
    deployment_name: str
) -> np.ndarray:
    """Gets an embedding from Azure OpenAI for the given text."""
    try:
        response = client.embeddings.create(
            input=text,
            model=deployment_name  # For Azure, this is the deployment name
        )
        embedding_vector = np.array(response.data[0].embedding, dtype=np.float32)
        logger.debug(f"Received embedding from Azure. Shape: {embedding_vector.shape}")
        return embedding_vector
    except Exception as e:
        logger.error(f"Error getting embedding from Azure OpenAI: {e}")
        # Fall back to a zero vector of a common dimension rather than raising.
        # For text-embedding-3-large the dimension is 3072; for ada-002 it is 1536.
        logger.warning("Falling back to zero embedding. This will impact prediction quality.")
        # Ideally the caller would determine the fallback dimension from the actual
        # deployment model; for simplicity we assume 3072 on error.
        return np.zeros(3072, dtype=np.float32)  # Default to text-embedding-3-large dim

def process_raw_embedding(
    raw_embedding: np.ndarray,
    turn: int,
    max_turns_for_scaling: int,
    target_model_embedding_dim: int,  # The dimension the model's observation space expects
    use_miniembeddings: bool
) -> np.ndarray:
    """
    Scales and potentially reduces/pads the raw embedding (from Azure)
    to match the model's expected input dimension and characteristics.
    """
    dim_of_raw_embedding = len(raw_embedding)
    logger.debug(f"Processing raw_embedding. Dim: {dim_of_raw_embedding}, Target model dim: {target_model_embedding_dim}, Use mini: {use_miniembeddings}")

    # 1. Apply turn-based dynamic scaling (mimicking training)
    progress = min(1.0, turn / max_turns_for_scaling)
    scaled_embedding = raw_embedding * (0.6 + 0.4 * progress)

    # 2. Adjust dimension to target_model_embedding_dim
    if use_miniembeddings and dim_of_raw_embedding > target_model_embedding_dim:
        logger.debug(f"Applying mini-embedding reduction from {dim_of_raw_embedding} to {target_model_embedding_dim}")
        if target_model_embedding_dim <= 0:
            logger.error("Target model embedding dimension is <= 0. Cannot pool.")
            return np.zeros(1, dtype=np.float32)  # Return a minimal valid array

        pool_factor = dim_of_raw_embedding // target_model_embedding_dim
        if pool_factor == 0:
            pool_factor = 1

        num_elements_to_pool = pool_factor * target_model_embedding_dim

        # If not enough elements for perfect pooling (e.g. raw_dim=5, target_dim=3 -> pool_factor=1, num_elements_to_pool=3)
        # or too many (e.g. raw_dim=5, target_dim=2 -> pool_factor=2, num_elements_to_pool=4),
        # we'll pool from the available part of scaled_embedding.
        elements_for_pooling = scaled_embedding[:num_elements_to_pool] if num_elements_to_pool <= dim_of_raw_embedding else scaled_embedding

        if len(elements_for_pooling) < target_model_embedding_dim:  # Not enough elements even to form the target-dim vector
            logger.warning(f"Not enough elements ({len(elements_for_pooling)}) to pool into target_dim ({target_model_embedding_dim}). Padding result.")
            reduced_embedding = np.zeros(target_model_embedding_dim, dtype=np.float32)
            fill_len = min(len(elements_for_pooling), target_model_embedding_dim)
            reduced_embedding[:fill_len] = elements_for_pooling[:fill_len]  # Simplified: take the first elements if pooling fails
        else:
            try:
                # Adjust elements_for_pooling to be perfectly divisible if necessary
                reshapable_length = (len(elements_for_pooling) // pool_factor) * pool_factor
                reshaped_for_pooling = elements_for_pooling[:reshapable_length].reshape(-1, pool_factor)  # -1 infers target_model_embedding_dim or similar

                # Ensure the first dimension of the reshaped array matches target_model_embedding_dim
                if reshaped_for_pooling.shape[0] > target_model_embedding_dim:
                    reshaped_for_pooling = reshaped_for_pooling[:target_model_embedding_dim, :]
                elif reshaped_for_pooling.shape[0] < target_model_embedding_dim:
                    # This case should ideally be handled by padding the result
                    logger.warning(f"Pooling resulted in fewer dimensions ({reshaped_for_pooling.shape[0]}) than target ({target_model_embedding_dim}). Will pad.")
                    temp_reduced = np.mean(reshaped_for_pooling, axis=1)
                    reduced_embedding = np.zeros(target_model_embedding_dim, dtype=np.float32)
                    reduced_embedding[:len(temp_reduced)] = temp_reduced
                else:
                    reduced_embedding = np.mean(reshaped_for_pooling, axis=1)

            except ValueError as e:
                logger.error(f"Reshape for pooling failed: {e}. Lengths: elements_for_pooling={len(elements_for_pooling)}, pool_factor={pool_factor}. Falling back to simple truncation/padding.")
                if dim_of_raw_embedding > target_model_embedding_dim:
                    reduced_embedding = scaled_embedding[:target_model_embedding_dim]
                else:
                    reduced_embedding = np.zeros(target_model_embedding_dim, dtype=np.float32)
                    reduced_embedding[:dim_of_raw_embedding] = scaled_embedding

        processed_embedding = reduced_embedding

    elif dim_of_raw_embedding == target_model_embedding_dim:
        processed_embedding = scaled_embedding
    elif dim_of_raw_embedding > target_model_embedding_dim:
        logger.debug(f"Truncating embedding from {dim_of_raw_embedding} to {target_model_embedding_dim}")
        processed_embedding = scaled_embedding[:target_model_embedding_dim]
    else:
        logger.debug(f"Padding embedding from {dim_of_raw_embedding} to {target_model_embedding_dim}")
        processed_embedding = np.zeros(target_model_embedding_dim, dtype=np.float32)
        processed_embedding[:dim_of_raw_embedding] = scaled_embedding

    if len(processed_embedding) != target_model_embedding_dim:
        logger.warning(f"Dimension mismatch after processing. Expected {target_model_embedding_dim}, got {len(processed_embedding)}. Adjusting...")
        final_embedding = np.zeros(target_model_embedding_dim, dtype=np.float32)
        fill_len = min(len(processed_embedding), target_model_embedding_dim)
        final_embedding[:fill_len] = processed_embedding[:fill_len]
        return final_embedding.astype(np.float32)

    return processed_embedding.astype(np.float32)

# --- Main Prediction Logic ---

def predict_conversation_trajectory(
    model: PPO,
    azure_openai_client: AzureOpenAI,
    azure_deployment_name: str,
    conversation_messages: List[Dict[str, str]],
    initial_metrics: Dict[str, float],
    model_expected_embedding_dim: int,
    use_miniembeddings_on_azure_emb: bool,
    max_conversation_turns_scaling: int = 20
):
    logger.info(f"Starting prediction. Model expects embedding_dim: {model_expected_embedding_dim}. use_mini_on_azure: {use_miniembeddings_on_azure_emb}")

    current_conversation_history_text = []
    current_conversation_history_struct = []
    agent_predicted_probabilities = []
    output_predictions = []

    num_metrics = 5
    expected_obs_dim = model_expected_embedding_dim + num_metrics + 1 + 10
    if model.observation_space.shape[0] != expected_obs_dim:
        logger.error(f"CRITICAL: Model observation space dimension mismatch! Model expects total obs_dim {model.observation_space.shape[0]}, "
                     f"but calculations suggest {expected_obs_dim} based on model_expected_embedding_dim={model_expected_embedding_dim}. "
                     f"Ensure --embedding_dim matches the dimension used for the embedding component during training.")
        inferred_emb_dim = model.observation_space.shape[0] - num_metrics - 1 - 10
        logger.error(f"The model might have been trained with an embedding component of dimension: {inferred_emb_dim}")
        raise ValueError("Observation space dimension mismatch. Check --embedding_dim.")

    for turn_idx, message_info in enumerate(conversation_messages):
        speaker = message_info.get("speaker", "unknown")
        message = message_info.get("message", "")

        current_conversation_history_struct.append(message_info)
        current_conversation_history_text.append(f"{speaker}: {message}")

        text_for_embedding = "\n".join(current_conversation_history_text)
        if not text_for_embedding.strip():
            logger.warning("Empty text for embedding at turn_idx %s, using zero vector from Azure (or fallback).", turn_idx)
            # Attempt to get an embedding for a neutral character to get the shape, or use a known default.
            # This path should be rare if conversations always start with text.
            raw_turn_embedding = get_azure_openai_embedding(" ", azure_openai_client, azure_deployment_name)
            if np.all(raw_turn_embedding == 0):  # If the fallback was hit
                logger.warning("Fallback zero embedding used for empty text. Assuming 3072 dim if the Azure call failed internally.")
                raw_turn_embedding = np.zeros(3072, dtype=np.float32)  # Default to text-embedding-3-large dim
        else:
            raw_turn_embedding = get_azure_openai_embedding(
                text_for_embedding,
                azure_openai_client,
                azure_deployment_name
            )

        final_turn_embedding = process_raw_embedding(
            raw_turn_embedding,
            turn_idx,
            max_conversation_turns_scaling,
            model_expected_embedding_dim,
            use_miniembeddings_on_azure_emb
        )

        if final_turn_embedding.shape[0] != model_expected_embedding_dim:
            logger.error(f"Embedding dimension mismatch after processing. Expected {model_expected_embedding_dim}, got {final_turn_embedding.shape[0]}. Critical error.")
            raise ValueError("Embedding dimension error after processing.")

        metrics = initial_metrics.copy()
        metrics['conversation_length'] = len(current_conversation_history_struct)
        metrics['progress'] = min(1.0, turn_idx / max_conversation_turns_scaling)
        if 'outcome' not in metrics:
            metrics['outcome'] = 0.5

        state = ConversationState(
            conversation_history=current_conversation_history_struct,
            embedding=final_turn_embedding,
            conversation_metrics=metrics,
            turn_number=turn_idx,
            conversion_probabilities=agent_predicted_probabilities
        )

        observation_vector = state.state_vector

        if observation_vector.shape[0] != model.observation_space.shape[0]:
            logger.error(f"Observation vector dimension mismatch before prediction! Expected {model.observation_space.shape[0]}, Got {observation_vector.shape[0]}")
            raise ValueError("Observation vector dimension mismatch.")

        action_probs, _ = model.predict(observation_vector, deterministic=True)
        predicted_prob_this_turn = float(action_probs[0])

        output_predictions.append({
            "turn": turn_idx + 1,
            "speaker": speaker,
            "message": message,
            "predicted_conversion_probability": predicted_prob_this_turn
        })
        agent_predicted_probabilities.append(predicted_prob_this_turn)

    return output_predictions

def main():
    parser = argparse.ArgumentParser(description="Run inference with Azure OpenAI embeddings.")
    parser.add_argument("--model_path", type=str, required=True, help="Path to the trained PPO model (.zip file).")
    parser.add_argument("--conversation_json", type=str, required=True,
                        help="JSON string or path to JSON file for the conversation.")

    parser.add_argument("--azure_api_key", type=str, required=True, help="Azure OpenAI API Key.")
    parser.add_argument("--azure_endpoint", type=str, required=True, help="Azure OpenAI Endpoint URL.")
    parser.add_argument("--azure_deployment_name", type=str, required=True, help="Azure OpenAI embedding deployment name (e.g., for text-embedding-3-large).")
    parser.add_argument("--azure_api_version", type=str, default="2023-12-01-preview", help="Azure OpenAI API Version (e.g., 2023-05-15 or 2023-12-01-preview for newer models).")

    parser.add_argument("--embedding_dim", type=int, required=True,
                        help="The dimension of the embedding vector component EXPECTED BY THE PPO MODEL's observation space.")
    parser.add_argument("--use_miniembeddings", action="store_true",
                        help="Flag if the Azure OpenAI embedding should be reduced (if larger than --embedding_dim) using the mini-embedding logic.")
    parser.add_argument("--max_turns_scaling", type=int, default=20,
                        help="The 'max_turns' value used for progress scaling (default: 20).")
    args = parser.parse_args()

    try:
        azure_client = AzureOpenAI(
            api_key=args.azure_api_key,
            azure_endpoint=args.azure_endpoint,
            api_version=args.azure_api_version
        )
        logger.info("Testing Azure OpenAI connection by embedding a short string...")
        test_embedding = get_azure_openai_embedding("test connection", azure_client, args.azure_deployment_name)
        logger.info(f"Azure OpenAI connection successful. Received test embedding of shape: {test_embedding.shape}")
        # This also implicitly tells us the dimension of the deployed Azure model.
        # We could store test_embedding.shape[0] and use it, but process_raw_embedding gets it anyway.
    except Exception as e:
        logger.error(f"Failed to initialize or test Azure OpenAI client: {e}")
        return

    try:
        if os.path.exists(args.conversation_json):
            with open(args.conversation_json, 'r') as f:
                sample_conversation = json.load(f)
        else:
            sample_conversation = json.loads(args.conversation_json)
        if not isinstance(sample_conversation, list):
            raise ValueError("Conversation JSON must be a list of message objects.")
    except Exception as e:
        logger.error(f"Error loading conversation JSON: {e}")
        return

    initial_metrics = {
        'customer_engagement': 0.5, 'sales_effectiveness': 0.5,
        'conversation_length': 0, 'outcome': 0.5, 'progress': 0.0
    }

    try:
        model = PPO.load(args.model_path, device=device)
        logger.info(f"Model loaded from {args.model_path}")
        logger.info(f"Model's observation space shape: {model.observation_space.shape}")
    except Exception as e:
        logger.error(f"Error loading PPO model: {e}")
        return

    predictions = predict_conversation_trajectory(
        model,
        azure_client,
        args.azure_deployment_name,
        sample_conversation,
        initial_metrics,
        model_expected_embedding_dim=args.embedding_dim,
        use_miniembeddings_on_azure_emb=args.use_miniembeddings,
        max_conversation_turns_scaling=args.max_turns_scaling
    )

    print("\n--- Conversation Predictions (with Azure OpenAI Embeddings) ---")
    for pred_info in predictions:
        print(f"Turn {pred_info['turn']} ({pred_info['speaker']}): \"{pred_info['message'][:60]}...\" -> Probability: {pred_info['predicted_conversion_probability']:.4f}")

if __name__ == "__main__":
    # python sales_inference.py \
    #     --model_path models/sales_conversion_model.zip \
    #     --conversation_json sample_conv.json \
    #     --azure_api_key "YOUR_AZURE_API_KEY" \
    #     --azure_endpoint "YOUR_AZURE_ENDPOINT" \
    #     --azure_deployment_name "your-text-embedding-3-large-deployment-name" \
    #     --azure_api_version "2023-12-01-preview" \
    #     --embedding_dim 1024 \
    #     --use_miniembeddings
    #
    # (The above example assumes your PPO model was trained expecting 1024-dim embeddings,
    #  and that the text-embedding-3-large output (3072-dim) will be reduced to 1024.)
    #
    # If your PPO model was trained directly with 3072-dim embeddings:
    # python sales_inference.py \
    #     --model_path models/sales_conversion_model.zip \
    #     --conversation_json sample_conv.json \
    #     --azure_api_key "YOUR_AZURE_API_KEY" \
    #     --azure_endpoint "YOUR_AZURE_ENDPOINT" \
    #     --azure_deployment_name "your-text-embedding-3-large-deployment-name" \
    #     --azure_api_version "2023-12-01-preview" \
    #     --embedding_dim 3072
    # (Do NOT specify --use_miniembeddings in this case, as 3072 (Azure) == 3072 (model).)

    main()
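Note on input format: --conversation_json accepts either a file path or an inline JSON string, and in both cases the payload must be a list of objects whose "speaker" and "message" keys are what predict_conversation_trajectory reads. A minimal, hypothetical way to produce a sample_conv.json in that shape (the dialogue content here is invented purely for illustration):

import json

sample_conversation = [
    {"speaker": "customer", "message": "Hi, I'm comparing CRM tools for a ten-person sales team."},
    {"speaker": "sales_rep", "message": "Happy to help. What does your current process look like?"},
    {"speaker": "customer", "message": "Mostly spreadsheets. Email integration and reporting matter most."},
]

with open("sample_conv.json", "w") as f:
    json.dump(sample_conversation, f, indent=2)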
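Note on observation size: each turn, state_vector concatenates the processed embedding, the five metric values, the turn number, and the last ten predicted probabilities padded to length 10, and predict_conversation_trajectory checks that total against the model's observation space. A minimal sketch of that arithmetic, assuming the 1024-dim example from the usage comments above (the value itself is not a requirement):

embedding_dim = 1024   # whatever is passed as --embedding_dim (example value only)
num_metrics = 5        # customer_engagement, sales_effectiveness, conversation_length, outcome, progress
expected_obs_dim = embedding_dim + num_metrics + 1 + 10   # + turn_number + 10 padded conversion probabilities
print(expected_obs_dim)  # 1040 -- must equal model.observation_space.shape[0]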
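Note on the --use_miniembeddings path: process_raw_embedding reduces a larger Azure embedding by mean-pooling groups of consecutive elements. A standalone sketch of the same reduction under the assumptions of the first usage example (3072-dim text-embedding-3-large output, 1024-dim model, so pool_factor = 3); the random vector is only a stand-in for a real embedding, and the turn-based scaling step is omitted:

import numpy as np

raw = np.random.randn(3072).astype(np.float32)   # stand-in for an Azure embedding
target_dim = 1024
pool_factor = len(raw) // target_dim             # 3072 // 1024 = 3
reduced = raw[:pool_factor * target_dim].reshape(-1, pool_factor).mean(axis=1)
print(reduced.shape)                             # (1024,) -- each element averages 3 consecutive inputs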