Callmebowoo-22 commited on
Commit
0290183
·
verified ·
1 Parent(s): 89143a1

Update utils/models.py

Browse files
Files changed (1) hide show
  1. utils/models.py +36 -32
utils/models.py CHANGED
@@ -1,35 +1,29 @@
1
- from transformers import AutoModelForTimeSeriesPrediction, pipeline
2
  import torch
3
  import numpy as np
 
 
4
 
5
- device = "cuda" if torch.cuda.is_available() else "cpu"
6
-
7
- def predict_umkm(data):
8
  try:
9
- # ===== 1. Validasi Data =====
10
- demand_values = data['demand'].values.astype(float)
11
- if len(demand_values) < 3:
12
- raise ValueError("Data historis terlalu pendek (min 3 titik)")
13
-
14
- # ===== 2. GRANITE-TTM Forecasting =====
15
- # Format input khusus untuk model time series IBM
16
- inputs = {
17
- "past_values": torch.tensor(demand_values, dtype=torch.float32).unsqueeze(0).to(device),
18
- "static_categorical_features": torch.zeros(1, 1, dtype=torch.long).to(device)
19
- }
20
 
21
- # Load model dengan config yang benar
22
- model = AutoModelForTimeSeriesPrediction.from_pretrained(
23
- "ibm/granite-timeseries-ttm-r2",
24
- trust_remote_code=True
25
- ).to(device)
26
 
27
- # Generate prediksi
28
  with torch.no_grad():
29
- outputs = model(**inputs)
30
- predictions = outputs.last_hidden_state.mean(dim=1).squeeze()
31
 
32
- # ===== 3. Format untuk Chronos-T5 =====
33
  chronos = pipeline(
34
  "text-generation",
35
  model="amazon/chronos-t5-small",
@@ -38,18 +32,28 @@ def predict_umkm(data):
38
 
39
  prompt = f"""
40
  [INSTRUCTION]
41
- Berikan rekomendasi stok untuk 7 hari ke depan berdasarkan:
42
- - Prediksi demand: {predictions.cpu().numpy().tolist()[:7]}
43
  - Stok saat ini: {data['supply'].iloc[-1]}
44
- - Tren: {'↑' if predictions[-1] > predictions[0] else '↓'}
45
 
46
  [FORMAT]
47
- 1 kalimat dengan angka spesifik
 
48
  [/FORMAT]
49
  """
50
 
51
- result = chronos(prompt, max_new_tokens=50)[0]['generated_text']
52
- return result.split("[/FORMAT]")[-1].strip()
53
-
 
 
 
 
 
 
 
 
 
54
  except Exception as e:
55
- return f"⚠️ Kesalahan sistem: {str(e)}"
 
 
1
  import torch
2
  import numpy as np
3
+ from tsfm_public.toolkit.get_model import get_model
4
+ from transformers import pipeline
5
 
6
+ def predict_umkm(data, prediction_length=7, confidence=0.85):
7
+ device = "cuda" if torch.cuda.is_available() else "cpu"
8
+
9
  try:
10
+ # ===== 1. GRANITE-TTM Forecasting =====
11
+ model = get_model(
12
+ model_path="ibm-granite/granite-timeseries-ttm-r2",
13
+ context_length=min(512, len(data)),
14
+ prediction_length=prediction_length,
15
+ device=device
16
+ )
 
 
 
 
17
 
18
+ # Format input
19
+ inputs = torch.tensor(data['demand'].values, dtype=torch.float32)
20
+ inputs = inputs.unsqueeze(0).to(device) # Shape: [1, seq_len]
 
 
21
 
22
+ # Prediksi
23
  with torch.no_grad():
24
+ preds = model.generate(inputs).cpu().numpy().flatten()
 
25
 
26
+ # ===== 2. Chronos-T5 Decision =====
27
  chronos = pipeline(
28
  "text-generation",
29
  model="amazon/chronos-t5-small",
 
32
 
33
  prompt = f"""
34
  [INSTRUCTION]
35
+ Berikan rekomendasi untuk manajemen inventory dengan:
36
+ - Prediksi {prediction_length} hari: {preds.tolist()}
37
  - Stok saat ini: {data['supply'].iloc[-1]}
38
+ - Tingkat kepercayaan: {confidence*100}%
39
 
40
  [FORMAT]
41
+ 1 kalimat dalam Bahasa Indonesia dengan angka spesifik.
42
+ Estimasi ROI dalam range persentase.
43
  [/FORMAT]
44
  """
45
 
46
+ response = chronos(prompt, max_length=150)[0]['generated_text']
47
+
48
+ # Ekstrak teks rekomendasi
49
+ rec_text = response.split("[/FORMAT]")[-1].strip()
50
+
51
+ return {
52
+ "text": rec_text,
53
+ "predictions": preds.tolist(),
54
+ "roi": confidence * 0.8, # Simulasi ROI
55
+ "confidence": confidence
56
+ }
57
+
58
  except Exception as e:
59
+ return {"error": str(e)}