Spaces:
Runtime error
Runtime error
added more math tools
Browse files- README.md +72 -25
- app.py +76 -24
- maths/elementary/arithmetic.py +53 -0
- maths/elementary/arithmetic_interface.py +25 -1
- maths/highschool/trigonometry.py +180 -0
- maths/highschool/trigonometry_interface.py +57 -1
- maths/middleschool/algebra.py +154 -0
- maths/middleschool/algebra_interface.py +43 -1
- maths/university/__init__.py +6 -0
- maths/university/calculus.py +223 -0
- maths/university/calculus_interface.py +104 -1
- maths/university/differential_equations.py +236 -0
- maths/university/differential_equations_interface.py +172 -0
- maths/university/linear_algebra.py +249 -0
- maths/university/linear_algebra_interface.py +197 -0
- maths/university/operations_research/BranchAndBoundSolver.py +384 -0
- maths/university/operations_research/DualSimplexSolver.py +443 -0
- maths/university/operations_research/bnb.ipynb +0 -0
- maths/university/operations_research/dual.ipynb +153 -0
- maths/university/operations_research/get_user_input.py +46 -0
- maths/university/operations_research/simplex_solver_with_steps.py +162 -0
- maths/university/operations_research/solve_lp_via_dual.py +317 -0
- maths/university/operations_research/solve_primal_directly.py +49 -0
- maths/university/tests/__init__.py +2 -0
- maths/university/tests/test_differential_equations.py +133 -0
- maths/university/tests/test_linear_algebra.py +132 -0
- requirements.txt +14 -1
README.md
CHANGED
@@ -13,7 +13,7 @@ short_description: Mathematics tools for different educational levels
|
|
13 |
|
14 |
# Math Education Tools
|
15 |
|
16 |
-
This web application provides
|
17 |
|
18 |
## Features
|
19 |
|
@@ -27,50 +27,94 @@ This web application provides various mathematical tools organized by educationa
|
|
27 |
- Subtraction: Subtract one number from another
|
28 |
- Multiplication: Multiply two numbers
|
29 |
- Division: Divide one number by another
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
### Middle School Math
|
32 |
|
33 |
-
- Linear Equation Solver: Solve equations of the form ax = b
|
34 |
-
- Quadratic Expression Evaluator: Calculate the value of ax² + bx + c for a given x
|
|
|
|
|
|
|
35 |
|
36 |
### High School Math
|
37 |
|
38 |
-
- Trigonometry Calculator: Calculate sine, cosine, and tangent values for angles in degrees
|
|
|
|
|
|
|
39 |
|
40 |
### University Math
|
41 |
|
42 |
-
|
43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
44 |
|
45 |
## Project Structure
|
46 |
|
47 |
```
|
48 |
-
├── app.py # Main application file
|
49 |
-
├──
|
|
|
50 |
│ ├── __init__.py
|
51 |
-
│ ├── elementary/ # Elementary school level
|
52 |
│ │ ├── __init__.py
|
53 |
-
│ │
|
54 |
-
│
|
|
|
55 |
│ │ ├── __init__.py
|
56 |
-
│ │
|
57 |
-
│
|
|
|
58 |
│ │ ├── __init__.py
|
59 |
-
│ │
|
60 |
-
│ └──
|
|
|
61 |
│ ├── __init__.py
|
62 |
-
│
|
63 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
├── __init__.py
|
65 |
-
|
|
|
66 |
```
|
67 |
|
68 |
## Getting Started
|
69 |
|
70 |
1. Install the required packages:
|
71 |
|
72 |
-
```
|
73 |
-
pip install
|
74 |
```
|
75 |
|
76 |
2. Run the application:
|
@@ -83,12 +127,15 @@ This web application provides various mathematical tools organized by educationa
|
|
83 |
|
84 |
## Extending the Project
|
85 |
|
86 |
-
To add new mathematical functions:
|
87 |
|
88 |
-
1.
|
89 |
-
2.
|
90 |
-
3.
|
91 |
-
4.
|
|
|
|
|
|
|
92 |
|
93 |
## License
|
94 |
|
|
|
13 |
|
14 |
# Math Education Tools
|
15 |
|
16 |
+
This web application provides a comprehensive suite of mathematical tools and interactive calculators, categorized by educational levels from elementary school through university. Built with Python and featuring user-friendly Gradio interfaces, it aims to assist with learning and performing a wide array of mathematical operations.
|
17 |
|
18 |
## Features
|
19 |
|
|
|
27 |
- Subtraction: Subtract one number from another
|
28 |
- Multiplication: Multiply two numbers
|
29 |
- Division: Divide one number by another
|
30 |
+
- Greatest Common Divisor (GCD): Find the GCD of two integers.
|
31 |
+
- Least Common Multiple (LCM): Find the LCM of two integers.
|
32 |
+
- Prime Number Checker: Check if a number is prime.
|
33 |
+
- Array Calculation: Perform basic operations on a list of numbers.
|
34 |
+
- Array Calculation with Visualization: Array calculation with a number line plot.
|
35 |
|
36 |
### Middle School Math
|
37 |
|
38 |
+
- Linear Equation Solver: Solve equations of the form ax = b.
|
39 |
+
- Quadratic Expression Evaluator: Calculate the value of ax² + bx + c for a given x.
|
40 |
+
- Quadratic Equation Solver: Find roots of ax² + bx + c = 0.
|
41 |
+
- Radical Simplifier: Simplify square roots (e.g., √12 to 2√3).
|
42 |
+
- Polynomial Operations: Add, subtract, and multiply polynomials.
|
43 |
|
44 |
### High School Math
|
45 |
|
46 |
+
- Trigonometry Calculator: Calculate sine, cosine, and tangent values for angles in degrees.
|
47 |
+
- Inverse Trigonometric Functions: Calculate asin, acos, atan in degrees.
|
48 |
+
- Trigonometric Equation Solver: Solve basic trigonometric equations.
|
49 |
+
- Trigonometric Identities: Demonstrate common identities like sin²(x) + cos²(x) = 1.
|
50 |
|
51 |
### University Math
|
52 |
|
53 |
+
#### Calculus
|
54 |
+
|
55 |
+
- Polynomial Derivative: Find the derivative of a polynomial function.
|
56 |
+
- Polynomial Integration: Find the indefinite integral of a polynomial function.
|
57 |
+
- Limits: Calculate the limit of an expression as a variable approaches a point.
|
58 |
+
- Series Expansion: Compute Taylor series for functions. Includes examples for Fourier series.
|
59 |
+
- Partial Derivatives: Compute partial derivatives of multi-variable expressions.
|
60 |
+
- Multiple Integrals: Evaluate definite multiple integrals.
|
61 |
+
|
62 |
+
#### Linear Algebra
|
63 |
+
|
64 |
+
- Matrix Operations: Addition, subtraction, multiplication.
|
65 |
+
- Matrix Analysis: Determinant and inverse of matrices.
|
66 |
+
- Vector Operations: Addition, subtraction, dot product, and cross product (for 3D vectors).
|
67 |
+
- System Solver: Solve systems of linear equations (Ax = B).
|
68 |
+
|
69 |
+
#### Differential Equations
|
70 |
+
|
71 |
+
- First-Order ODEs: Solve single or systems of first-order ordinary differential equations with initial conditions.
|
72 |
+
- Second-Order ODEs: Solve second-order ordinary differential equations by converting to a system.
|
73 |
+
- Solution Plotting: Visualize solutions for ODEs.
|
74 |
|
75 |
## Project Structure
|
76 |
|
77 |
```
|
78 |
+
├── app.py # Main Gradio application file
|
79 |
+
├── requirements.txt # Python package dependencies
|
80 |
+
├── maths/ # Core mathematics modules
|
81 |
│ ├── __init__.py
|
82 |
+
│ ├── elementary/ # Elementary school level functions and interfaces
|
83 |
│ │ ├── __init__.py
|
84 |
+
│ │ ├── arithmetic.py # Logic for basic arithmetic, GCD, LCM, primes
|
85 |
+
│ │ └── arithmetic_interface.py # Gradio interfaces for arithmetic
|
86 |
+
│ ├── middleschool/ # Middle school level functions and interfaces
|
87 |
│ │ ├── __init__.py
|
88 |
+
│ │ ├── algebra.py # Logic for linear/quadratic equations, polynomials
|
89 |
+
│ │ └── algebra_interface.py # Gradio interfaces for algebra
|
90 |
+
│ ├── highschool/ # High school level functions and interfaces
|
91 |
│ │ ├── __init__.py
|
92 |
+
│ │ ├── trigonometry.py # Logic for trigonometric functions and equations
|
93 |
+
│ │ └── trigonometry_interface.py # Gradio interfaces for trigonometry
|
94 |
+
│ └── university/ # University level functions and interfaces
|
95 |
│ ├── __init__.py
|
96 |
+
│ ├── calculus.py # Logic for calculus (limits, derivatives, integrals, series)
|
97 |
+
│ ├── calculus_interface.py # Gradio interfaces for calculus
|
98 |
+
│ ├── linear_algebra.py # Logic for matrix/vector operations, linear systems
|
99 |
+
│ ├── linear_algebra_interface.py # Gradio interfaces for linear algebra
|
100 |
+
│ ├── differential_equations.py # Logic for solving ODEs
|
101 |
+
│ ├── differential_equations_interface.py # Gradio interfaces for ODEs
|
102 |
+
│ └── tests/ # Unit tests for university level modules
|
103 |
+
│ ├── __init__.py
|
104 |
+
│ ├── test_linear_algebra.py
|
105 |
+
│ └── test_differential_equations.py
|
106 |
+
└── utils/ # Utility functions and interfaces
|
107 |
├── __init__.py
|
108 |
+
├── text_utils.py # Logic for text processing utilities
|
109 |
+
└── text_utils_interface.py # Gradio interface for text utils
|
110 |
```
|
111 |
|
112 |
## Getting Started
|
113 |
|
114 |
1. Install the required packages:
|
115 |
|
116 |
+
```bash
|
117 |
+
pip install -r requirements.txt
|
118 |
```
|
119 |
|
120 |
2. Run the application:
|
|
|
127 |
|
128 |
## Extending the Project
|
129 |
|
130 |
+
To add new mathematical functions or tools:
|
131 |
|
132 |
+
1. **Create Logic**: Add your mathematical function/logic to an existing Python file in the appropriate `maths/<level>/` directory or create a new `.py` file for it. Ensure your functions have clear inputs, outputs, and docstrings.
|
133 |
+
2. **Create Interface**: In the corresponding `maths/<level>/` directory (or a subdirectory like `interfaces` if preferred), create a `_interface.py` file (e.g., `newfeature_interface.py`) or add to an existing one. In this file, import your function(s) and create a Gradio interface for each. Define clear input and output components (e.g., `gr.Textbox()`, `gr.Number()`, `gr.Plot()`, `gr.Image()`).
|
134 |
+
3. **Add Tests**: For more complex logic, especially at the university level, add unit tests in a corresponding `maths/<level>/tests/` directory (e.g., `test_newfeature.py`). Use the `unittest` module or `pytest`.
|
135 |
+
4. **Update Main App**: Import your new Gradio interface object(s) in the main `app.py` file. Add the interface object(s) to the list of interfaces for the relevant educational level tab (e.g., `university_interfaces_list`) and provide a corresponding name in the tab names list (e.g., `university_tab_names`).
|
136 |
+
5. **Update `__init__.py` Files**: Ensure your new modules and interface modules are importable by updating the `__init__.py` files in their respective directories if necessary (e.g., `from . import newfeature`, `from . import newfeature_interface`).
|
137 |
+
6. **Update `requirements.txt`**: If your new feature introduces new external dependencies, add them to `requirements.txt`. It's good practice to pin versions (e.g., `new_package==1.2.3`).
|
138 |
+
7. **Update README**: Document your new feature in the "Features" section and update the "Project Structure" if you added new files/directories.
|
139 |
|
140 |
## License
|
141 |
|
app.py
CHANGED
@@ -2,36 +2,88 @@ import gradio as gr
|
|
2 |
from utils.text_utils_interface import letter_counter_interface
|
3 |
from maths.elementary.arithmetic_interface import (
|
4 |
add_interface, subtract_interface, multiply_interface, divide_interface,
|
5 |
-
array_calc_interface, array_calc_vis_interface
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
6 |
)
|
7 |
-
from maths.middleschool.algebra_interface import solve_linear_equation_interface, evaluate_expression_interface
|
8 |
-
from maths.highschool.trigonometry_interface import trig_interface
|
9 |
-
from maths.university.calculus_interface import derivative_interface, integral_interface
|
10 |
|
11 |
# Group interfaces by education level
|
12 |
-
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
17 |
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
|
|
|
|
|
|
23 |
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
|
|
|
|
|
|
29 |
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
35 |
|
36 |
# Main demo with tabs for each education level
|
37 |
demo = gr.TabbedInterface(
|
|
|
2 |
from utils.text_utils_interface import letter_counter_interface
|
3 |
from maths.elementary.arithmetic_interface import (
|
4 |
add_interface, subtract_interface, multiply_interface, divide_interface,
|
5 |
+
array_calc_interface, array_calc_vis_interface,
|
6 |
+
gcd_interface, lcm_interface, is_prime_interface # Assuming these were added earlier from a subtask
|
7 |
+
)
|
8 |
+
from maths.middleschool.algebra_interface import (
|
9 |
+
solve_linear_equation_interface, evaluate_expression_interface,
|
10 |
+
solve_quadratic_interface, simplify_radical_interface, polynomial_interface # Assuming these
|
11 |
+
)
|
12 |
+
from maths.highschool.trigonometry_interface import (
|
13 |
+
trig_interface, inverse_trig_interface, solve_trig_equations_interface, trig_identities_interface # Assuming these
|
14 |
+
)
|
15 |
+
from maths.university.calculus_interface import (
|
16 |
+
derivative_interface, integral_interface,
|
17 |
+
limit_interface, taylor_series_interface, fourier_series_interface, # Assuming these
|
18 |
+
partial_derivative_interface, multiple_integral_interface # Assuming these
|
19 |
+
)
|
20 |
+
from maths.university.linear_algebra_interface import (
|
21 |
+
matrix_add_interface, matrix_subtract_interface, matrix_multiply_interface,
|
22 |
+
matrix_determinant_interface, matrix_inverse_interface,
|
23 |
+
vector_add_interface, vector_subtract_interface, vector_dot_product_interface,
|
24 |
+
vector_cross_product_interface, solve_linear_system_interface
|
25 |
+
)
|
26 |
+
from maths.university.differential_equations_interface import (
|
27 |
+
first_order_ode_interface, second_order_ode_interface
|
28 |
)
|
|
|
|
|
|
|
29 |
|
30 |
# Group interfaces by education level
|
31 |
+
# Note: I'm assuming previous subtasks correctly added GCD, LCM, Polynomial, Quadratic, etc. interfaces to their respective files.
|
32 |
+
# If not, this app.py will have import errors for those. I'm focusing on the current task's imports.
|
33 |
+
elementary_interfaces_list = [
|
34 |
+
add_interface, subtract_interface, multiply_interface, divide_interface,
|
35 |
+
gcd_interface, lcm_interface, is_prime_interface, # Added from subtask 1
|
36 |
+
array_calc_interface, array_calc_vis_interface
|
37 |
+
]
|
38 |
+
elementary_tab_names = [
|
39 |
+
"Addition", "Subtraction", "Multiplication", "Division",
|
40 |
+
"GCD", "LCM", "Prime Check", # Added from subtask 1
|
41 |
+
"Array Calculation", "Array Calc Viz"
|
42 |
+
]
|
43 |
|
44 |
+
middleschool_interfaces_list = [
|
45 |
+
solve_linear_equation_interface, evaluate_expression_interface,
|
46 |
+
solve_quadratic_interface, simplify_radical_interface, polynomial_interface # Added from subtask 2
|
47 |
+
]
|
48 |
+
middleschool_tab_names = [
|
49 |
+
"Linear Equations", "Evaluate Expressions",
|
50 |
+
"Quadratic Solver", "Radical Simplifier", "Polynomial Ops" # Added from subtask 2
|
51 |
+
]
|
52 |
|
53 |
+
highschool_interfaces_list = [
|
54 |
+
trig_interface, inverse_trig_interface,
|
55 |
+
solve_trig_equations_interface, trig_identities_interface # Added from subtask 3
|
56 |
+
]
|
57 |
+
highschool_tab_names = [
|
58 |
+
"Trig Functions", "Inverse Trig",
|
59 |
+
"Solve Trig Eqs", "Trig Identities" # Added from subtask 3
|
60 |
+
]
|
61 |
|
62 |
+
university_interfaces_list = [
|
63 |
+
derivative_interface, integral_interface,
|
64 |
+
limit_interface, taylor_series_interface, fourier_series_interface, # Added from subtask 4
|
65 |
+
partial_derivative_interface, multiple_integral_interface, # Added from subtask 4
|
66 |
+
matrix_add_interface, matrix_subtract_interface, matrix_multiply_interface,
|
67 |
+
matrix_determinant_interface, matrix_inverse_interface,
|
68 |
+
vector_add_interface, vector_subtract_interface, vector_dot_product_interface,
|
69 |
+
vector_cross_product_interface, solve_linear_system_interface,
|
70 |
+
first_order_ode_interface, second_order_ode_interface
|
71 |
+
]
|
72 |
+
university_tab_names = [
|
73 |
+
"Poly Derivatives", "Poly Integrals",
|
74 |
+
"Limits", "Taylor Series", "Fourier Series", # Added from subtask 4
|
75 |
+
"Partial Derivatives", "Multiple Integrals", # Added from subtask 4
|
76 |
+
"Matrix Add", "Matrix Subtract", "Matrix Multiply",
|
77 |
+
"Matrix Determinant", "Matrix Inverse",
|
78 |
+
"Vector Add", "Vector Subtract", "Vector Dot Product",
|
79 |
+
"Vector Cross Product", "Solve Linear System",
|
80 |
+
"1st Order ODE", "2nd Order ODE"
|
81 |
+
]
|
82 |
+
|
83 |
+
elementary_tab = gr.TabbedInterface(elementary_interfaces_list, elementary_tab_names, title="Elementary School Math")
|
84 |
+
middleschool_tab = gr.TabbedInterface(middleschool_interfaces_list, middleschool_tab_names, title="Middle School Math")
|
85 |
+
highschool_tab = gr.TabbedInterface(highschool_interfaces_list, highschool_tab_names, title="High School Math")
|
86 |
+
university_tab = gr.TabbedInterface(university_interfaces_list, university_tab_names, title="University Math")
|
87 |
|
88 |
# Main demo with tabs for each education level
|
89 |
demo = gr.TabbedInterface(
|
maths/elementary/arithmetic.py
CHANGED
@@ -189,3 +189,56 @@ def calculate_array_with_visualization(numbers: list, operations: list) -> tuple
|
|
189 |
# Visualization: show all intermediate results on the number line
|
190 |
fig = create_number_line_visualization(numbers, ' -> '.join(operations), result)
|
191 |
return result, fig
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
189 |
# Visualization: show all intermediate results on the number line
|
190 |
fig = create_number_line_visualization(numbers, ' -> '.join(operations), result)
|
191 |
return result, fig
|
192 |
+
|
193 |
+
|
194 |
+
def gcd(a: int, b: int) -> int:
|
195 |
+
"""Compute the greatest common divisor of two integers.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
a: The first integer.
|
199 |
+
b: The second integer.
|
200 |
+
|
201 |
+
Returns:
|
202 |
+
The greatest common divisor of a and b.
|
203 |
+
"""
|
204 |
+
while b:
|
205 |
+
a, b = b, a % b
|
206 |
+
return abs(a)
|
207 |
+
|
208 |
+
|
209 |
+
def lcm(a: int, b: int) -> int:
|
210 |
+
"""Compute the least common multiple of two integers.
|
211 |
+
|
212 |
+
Args:
|
213 |
+
a: The first integer.
|
214 |
+
b: The second integer.
|
215 |
+
|
216 |
+
Returns:
|
217 |
+
The least common multiple of a and b.
|
218 |
+
"""
|
219 |
+
if a == 0 or b == 0:
|
220 |
+
return 0
|
221 |
+
return abs(a * b) // gcd(a, b)
|
222 |
+
|
223 |
+
|
224 |
+
def is_prime(n: int) -> bool:
|
225 |
+
"""Check if a number is a prime number.
|
226 |
+
|
227 |
+
Args:
|
228 |
+
n: The number to check.
|
229 |
+
|
230 |
+
Returns:
|
231 |
+
True if n is prime, False otherwise.
|
232 |
+
"""
|
233 |
+
if n <= 1:
|
234 |
+
return False
|
235 |
+
if n <= 3:
|
236 |
+
return True
|
237 |
+
if n % 2 == 0 or n % 3 == 0:
|
238 |
+
return False
|
239 |
+
i = 5
|
240 |
+
while i * i <= n:
|
241 |
+
if n % i == 0 or n % (i + 2) == 0:
|
242 |
+
return False
|
243 |
+
i += 6
|
244 |
+
return True
|
maths/elementary/arithmetic_interface.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
-
from maths.elementary.arithmetic import add, subtract, multiply, divide, calculate_array, calculate_array_with_visualization
|
3 |
|
4 |
# Elementary Math Tab
|
5 |
add_interface = gr.Interface(
|
@@ -71,3 +71,27 @@ array_calc_vis_interface = gr.Interface(
|
|
71 |
title="Array Calculation with Visualization",
|
72 |
description="Calculate a sequence of numbers with specified operations and see a number line visualization."
|
73 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from maths.elementary.arithmetic import add, subtract, multiply, divide, calculate_array, calculate_array_with_visualization, gcd, lcm, is_prime
|
3 |
|
4 |
# Elementary Math Tab
|
5 |
add_interface = gr.Interface(
|
|
|
71 |
title="Array Calculation with Visualization",
|
72 |
description="Calculate a sequence of numbers with specified operations and see a number line visualization."
|
73 |
)
|
74 |
+
|
75 |
+
gcd_interface = gr.Interface(
|
76 |
+
fn=gcd,
|
77 |
+
inputs=[gr.Number(label="A", precision=0), gr.Number(label="B", precision=0)],
|
78 |
+
outputs="number",
|
79 |
+
title="Greatest Common Divisor (GCD)",
|
80 |
+
description="Compute the greatest common divisor of two integers."
|
81 |
+
)
|
82 |
+
|
83 |
+
lcm_interface = gr.Interface(
|
84 |
+
fn=lcm,
|
85 |
+
inputs=[gr.Number(label="A", precision=0), gr.Number(label="B", precision=0)],
|
86 |
+
outputs="number",
|
87 |
+
title="Least Common Multiple (LCM)",
|
88 |
+
description="Compute the least common multiple of two integers."
|
89 |
+
)
|
90 |
+
|
91 |
+
is_prime_interface = gr.Interface(
|
92 |
+
fn=is_prime,
|
93 |
+
inputs=[gr.Number(label="Number", precision=0)],
|
94 |
+
outputs="text", # Outputting as text to give a clear True/False message
|
95 |
+
title="Prime Number Check",
|
96 |
+
description="Check if a number is a prime number."
|
97 |
+
)
|
maths/highschool/trigonometry.py
CHANGED
@@ -14,3 +14,183 @@ def cos_degrees(angle):
|
|
14 |
def tan_degrees(angle):
|
15 |
"""Calculate the tangent of an angle in degrees."""
|
16 |
return math.tan(math.radians(angle))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
14 |
def tan_degrees(angle):
|
15 |
"""Calculate the tangent of an angle in degrees."""
|
16 |
return math.tan(math.radians(angle))
|
17 |
+
|
18 |
+
|
19 |
+
def inverse_trig_functions(value: float, function_name: str) -> str:
|
20 |
+
"""
|
21 |
+
Calculates inverse trigonometric functions (asin, acos, atan) in degrees.
|
22 |
+
|
23 |
+
Args:
|
24 |
+
value: The value to calculate the inverse trigonometric function for.
|
25 |
+
For asin and acos, must be between -1 and 1.
|
26 |
+
function_name: "asin", "acos", or "atan".
|
27 |
+
|
28 |
+
Returns:
|
29 |
+
A string representing the result in degrees, or an error message.
|
30 |
+
"""
|
31 |
+
if not isinstance(value, (int, float)):
|
32 |
+
return "Error: Input value must be a number."
|
33 |
+
|
34 |
+
func_name = function_name.lower()
|
35 |
+
result_rad = 0.0
|
36 |
+
|
37 |
+
if func_name == "asin":
|
38 |
+
if -1 <= value <= 1:
|
39 |
+
result_rad = math.asin(value)
|
40 |
+
else:
|
41 |
+
return "Error: Input for asin must be between -1 and 1."
|
42 |
+
elif func_name == "acos":
|
43 |
+
if -1 <= value <= 1:
|
44 |
+
result_rad = math.acos(value)
|
45 |
+
else:
|
46 |
+
return "Error: Input for acos must be between -1 and 1."
|
47 |
+
elif func_name == "atan":
|
48 |
+
result_rad = math.atan(value)
|
49 |
+
else:
|
50 |
+
return "Error: Invalid function name. Choose 'asin', 'acos', or 'atan'."
|
51 |
+
|
52 |
+
return f"{math.degrees(result_rad):.4f} degrees"
|
53 |
+
|
54 |
+
|
55 |
+
def solve_trig_equations(a: float, b: float, c: float, trig_func: str, interval_degrees: tuple[float, float] = (0, 360)) -> str:
|
56 |
+
"""
|
57 |
+
Solves basic trigonometric equations of the form a * func(x) + b = c.
|
58 |
+
Finds solutions for x within a given interval (in degrees).
|
59 |
+
|
60 |
+
Args:
|
61 |
+
a: Coefficient of the trigonometric function.
|
62 |
+
b: Constant term added to the function part.
|
63 |
+
c: Constant term on the other side of the equation.
|
64 |
+
trig_func: The trigonometric function ("sin", "cos", "tan").
|
65 |
+
interval_degrees: Tuple (min_angle, max_angle) for solutions in degrees.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
A string describing the solutions in degrees.
|
69 |
+
"""
|
70 |
+
if a == 0:
|
71 |
+
return "Error: Coefficient 'a' cannot be zero."
|
72 |
+
|
73 |
+
# Rearrange to func(x) = (c - b) / a
|
74 |
+
val = (c - b) / a
|
75 |
+
func = trig_func.lower()
|
76 |
+
solutions_deg = []
|
77 |
+
|
78 |
+
if func == "sin":
|
79 |
+
if not (-1 <= val <= 1):
|
80 |
+
return f"No solution: sin(x) cannot be {val:.4f}."
|
81 |
+
angle_rad_principal = math.asin(val)
|
82 |
+
elif func == "cos":
|
83 |
+
if not (-1 <= val <= 1):
|
84 |
+
return f"No solution: cos(x) cannot be {val:.4f}."
|
85 |
+
angle_rad_principal = math.acos(val)
|
86 |
+
elif func == "tan":
|
87 |
+
angle_rad_principal = math.atan(val)
|
88 |
+
else:
|
89 |
+
return "Error: Invalid trigonometric function. Choose 'sin', 'cos', or 'tan'."
|
90 |
+
|
91 |
+
# Convert principal solution to degrees
|
92 |
+
angle_deg_principal = math.degrees(angle_rad_principal)
|
93 |
+
|
94 |
+
min_interval_deg, max_interval_deg = interval_degrees
|
95 |
+
|
96 |
+
# Find solutions within the interval
|
97 |
+
# General solutions:
|
98 |
+
# sin(x) = sin(alpha) => x = n*360 + alpha OR x = n*360 + (180 - alpha)
|
99 |
+
# cos(x) = cos(alpha) => x = n*360 + alpha OR x = n*360 - alpha
|
100 |
+
# tan(x) = tan(alpha) => x = n*180 + alpha
|
101 |
+
|
102 |
+
for n in range(int(min_interval_deg / 360) - 2, int(max_interval_deg / 360) + 3): # Check a few cycles around the interval
|
103 |
+
if func == "sin":
|
104 |
+
sol1_deg = n * 360 + angle_deg_principal
|
105 |
+
sol2_deg = n * 360 + (180 - angle_deg_principal)
|
106 |
+
if min_interval_deg <= sol1_deg <= max_interval_deg:
|
107 |
+
solutions_deg.append(sol1_deg)
|
108 |
+
if min_interval_deg <= sol2_deg <= max_interval_deg:
|
109 |
+
solutions_deg.append(sol2_deg)
|
110 |
+
elif func == "cos":
|
111 |
+
sol1_deg = n * 360 + angle_deg_principal
|
112 |
+
sol2_deg = n * 360 - angle_deg_principal
|
113 |
+
if min_interval_deg <= sol1_deg <= max_interval_deg:
|
114 |
+
solutions_deg.append(sol1_deg)
|
115 |
+
if min_interval_deg <= sol2_deg <= max_interval_deg:
|
116 |
+
solutions_deg.append(sol2_deg)
|
117 |
+
elif func == "tan":
|
118 |
+
# For tan, general solution is n*180 + alpha
|
119 |
+
for n_tan in range(int(min_interval_deg / 180) - 2, int(max_interval_deg / 180) + 3):
|
120 |
+
sol_deg = n_tan * 180 + angle_deg_principal
|
121 |
+
if min_interval_deg <= sol_deg <= max_interval_deg:
|
122 |
+
solutions_deg.append(sol_deg)
|
123 |
+
|
124 |
+
# Remove duplicates and sort
|
125 |
+
unique_solutions = sorted(list(set(f"{s:.2f}" for s in solutions_deg))) # Format to avoid floating point issues
|
126 |
+
|
127 |
+
if not unique_solutions:
|
128 |
+
return f"No solutions found for {a}*{func}(x) + {b} = {c} in the interval [{min_interval_deg}, {max_interval_deg}] degrees."
|
129 |
+
|
130 |
+
return f"Solutions for x in [{min_interval_deg}, {max_interval_deg}] degrees: {', '.join(unique_solutions)}"
|
131 |
+
|
132 |
+
|
133 |
+
def trig_identities(angle_degrees: float, identity_name: str = "pythagorean1") -> str:
|
134 |
+
"""
|
135 |
+
Demonstrates common trigonometric identities for a given angle (in degrees).
|
136 |
+
|
137 |
+
Args:
|
138 |
+
angle_degrees: The angle in degrees to evaluate the identities for.
|
139 |
+
identity_name: Name of the identity to demonstrate.
|
140 |
+
"pythagorean1": sin^2(x) + cos^2(x) = 1
|
141 |
+
"pythagorean2": 1 + tan^2(x) = sec^2(x)
|
142 |
+
"pythagorean3": 1 + cot^2(x) = csc^2(x)
|
143 |
+
"all": Show all Pythagorean identities.
|
144 |
+
More can be added.
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
A string demonstrating the identity.
|
148 |
+
"""
|
149 |
+
x_rad = math.radians(angle_degrees)
|
150 |
+
sinx = math.sin(x_rad)
|
151 |
+
cosx = math.cos(x_rad)
|
152 |
+
|
153 |
+
# Avoid division by zero for tan, sec, cot, csc
|
154 |
+
# Check if cosx is very close to zero
|
155 |
+
if abs(cosx) < 1e-9: # cos(90), cos(270) etc.
|
156 |
+
tanx = float('inf') if sinx > 0 else float('-inf') if sinx < 0 else 0 # tan is undefined or 0 if sinx is also 0
|
157 |
+
secx = float('inf') if cosx >= 0 else float('-inf') # sec is undefined
|
158 |
+
else:
|
159 |
+
tanx = sinx / cosx
|
160 |
+
secx = 1 / cosx
|
161 |
+
|
162 |
+
# Check if sinx is very close to zero
|
163 |
+
if abs(sinx) < 1e-9: # sin(0), sin(180) etc.
|
164 |
+
cotx = float('inf') if cosx > 0 else float('-inf') if cosx < 0 else 0 # cot is undefined or 0 if cosx is also 0
|
165 |
+
cscx = float('inf') if sinx >= 0 else float('-inf') # csc is undefined
|
166 |
+
else:
|
167 |
+
cotx = cosx / sinx
|
168 |
+
cscx = 1 / sinx
|
169 |
+
|
170 |
+
results = []
|
171 |
+
name = identity_name.lower()
|
172 |
+
|
173 |
+
if name == "pythagorean1" or name == "all":
|
174 |
+
lhs = sinx**2 + cosx**2
|
175 |
+
results.append(f"Pythagorean Identity 1: sin^2({angle_degrees}) + cos^2({angle_degrees}) = {sinx**2:.4f} + {cosx**2:.4f} = {lhs:.4f} (Expected: 1)")
|
176 |
+
|
177 |
+
if name == "pythagorean2" or name == "all":
|
178 |
+
if abs(cosx) < 1e-9:
|
179 |
+
results.append(f"Pythagorean Identity 2 (1 + tan^2(x) = sec^2(x)): Not well-defined for x={angle_degrees} degrees as cos(x) is near zero (tan(x) and sec(x) are undefined or infinite).")
|
180 |
+
else:
|
181 |
+
lhs = 1 + tanx**2
|
182 |
+
rhs = secx**2
|
183 |
+
results.append(f"Pythagorean Identity 2: 1 + tan^2({angle_degrees}) = 1 + {tanx**2:.4f} = {lhs:.4f}. sec^2({angle_degrees}) = {secx**2:.4f}. (Expected LHS = RHS)")
|
184 |
+
|
185 |
+
if name == "pythagorean3" or name == "all":
|
186 |
+
if abs(sinx) < 1e-9:
|
187 |
+
results.append(f"Pythagorean Identity 3 (1 + cot^2(x) = csc^2(x)): Not well-defined for x={angle_degrees} degrees as sin(x) is near zero (cot(x) and csc(x) are undefined or infinite).")
|
188 |
+
else:
|
189 |
+
lhs = 1 + cotx**2
|
190 |
+
rhs = cscx**2
|
191 |
+
results.append(f"Pythagorean Identity 3: 1 + cot^2({angle_degrees}) = 1 + {cotx**2:.4f} = {lhs:.4f}. csc^2({angle_degrees}) = {cscx**2:.4f}. (Expected LHS = RHS)")
|
192 |
+
|
193 |
+
if not results:
|
194 |
+
return f"Unknown identity: {identity_name}. Available: pythagorean1, pythagorean2, pythagorean3, all."
|
195 |
+
|
196 |
+
return "\n".join(results)
|
maths/highschool/trigonometry_interface.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
-
from maths.highschool.trigonometry import sin_degrees, cos_degrees, tan_degrees
|
3 |
|
4 |
# High School Math Tab
|
5 |
trig_interface = gr.Interface(
|
@@ -16,3 +16,59 @@ trig_interface = gr.Interface(
|
|
16 |
title="Trigonometry Calculator",
|
17 |
description="Calculate trigonometric functions for angles in degrees"
|
18 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from maths.highschool.trigonometry import sin_degrees, cos_degrees, tan_degrees, inverse_trig_functions, solve_trig_equations, trig_identities
|
3 |
|
4 |
# High School Math Tab
|
5 |
trig_interface = gr.Interface(
|
|
|
16 |
title="Trigonometry Calculator",
|
17 |
description="Calculate trigonometric functions for angles in degrees"
|
18 |
)
|
19 |
+
|
20 |
+
inverse_trig_interface = gr.Interface(
|
21 |
+
fn=inverse_trig_functions,
|
22 |
+
inputs=[
|
23 |
+
gr.Number(label="Value"),
|
24 |
+
gr.Radio(["asin", "acos", "atan"], label="Inverse Function")
|
25 |
+
],
|
26 |
+
outputs="text",
|
27 |
+
title="Inverse Trigonometry Calculator",
|
28 |
+
description="Calculate inverse trigonometric functions. Output in degrees."
|
29 |
+
)
|
30 |
+
|
31 |
+
def parse_interval(interval_str: str) -> tuple[float, float]:
|
32 |
+
"""Helper to parse comma-separated interval string (e.g., "0,360") into a tuple."""
|
33 |
+
try:
|
34 |
+
parts = [float(x.strip()) for x in interval_str.split(',')]
|
35 |
+
if len(parts) == 2 and parts[0] <= parts[1]:
|
36 |
+
return parts[0], parts[1]
|
37 |
+
raise ValueError("Interval must be two numbers, min,max.")
|
38 |
+
except Exception:
|
39 |
+
# Return a default or raise specific error for Gradio
|
40 |
+
raise gr.Error("Invalid interval format. Use 'min,max' (e.g., '0,360').")
|
41 |
+
|
42 |
+
|
43 |
+
solve_trig_equations_interface = gr.Interface(
|
44 |
+
fn=lambda a, b, c, func, interval_str: solve_trig_equations(a,b,c,func, parse_interval(interval_str)),
|
45 |
+
inputs=[
|
46 |
+
gr.Number(label="a (coefficient of function)"),
|
47 |
+
gr.Number(label="b (constant added to function part)"),
|
48 |
+
gr.Number(label="c (constant on other side of equation)"),
|
49 |
+
gr.Radio(["sin", "cos", "tan"], label="Trigonometric Function"),
|
50 |
+
gr.Textbox(label="Interval for x (degrees, comma-separated, e.g., 0,360)", value="0,360")
|
51 |
+
],
|
52 |
+
outputs="text",
|
53 |
+
title="Trigonometric Equation Solver",
|
54 |
+
description="Solves equations like a * func(x) + b = c for x in a given interval (degrees)."
|
55 |
+
)
|
56 |
+
|
57 |
+
trig_identities_interface = gr.Interface(
|
58 |
+
fn=trig_identities,
|
59 |
+
inputs=[
|
60 |
+
gr.Number(label="Angle (degrees)"),
|
61 |
+
gr.Radio(
|
62 |
+
choices=[
|
63 |
+
("sin²(x) + cos²(x) = 1", "pythagorean1"),
|
64 |
+
("1 + tan²(x) = sec²(x)", "pythagorean2"),
|
65 |
+
("1 + cot²(x) = csc²(x)", "pythagorean3"),
|
66 |
+
("All Pythagorean", "all")
|
67 |
+
],
|
68 |
+
label="Trigonometric Identity to Demonstrate"
|
69 |
+
)
|
70 |
+
],
|
71 |
+
outputs="text",
|
72 |
+
title="Trigonometric Identities Demonstrator",
|
73 |
+
description="Show common trigonometric identities for a given angle."
|
74 |
+
)
|
maths/middleschool/algebra.py
CHANGED
@@ -18,3 +18,157 @@ def evaluate_expression(a, b, c, x):
|
|
18 |
Evaluate the expression ax² + bx + c for a given value of x.
|
19 |
"""
|
20 |
return a * (x ** 2) + b * x + c
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
Evaluate the expression ax² + bx + c for a given value of x.
|
19 |
"""
|
20 |
return a * (x ** 2) + b * x + c
|
21 |
+
|
22 |
+
|
23 |
+
def solve_quadratic(a: float, b: float, c: float) -> str:
|
24 |
+
"""
|
25 |
+
Solve the quadratic equation ax^2 + bx + c = 0.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
a: Coefficient of x^2.
|
29 |
+
b: Coefficient of x.
|
30 |
+
c: Constant term.
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
A string representing the solutions.
|
34 |
+
"""
|
35 |
+
if a == 0:
|
36 |
+
if b == 0:
|
37 |
+
return "Not a valid equation (a and b cannot both be zero)." if c != 0 else "Infinite solutions (0 = 0)"
|
38 |
+
# Linear equation: bx + c = 0 -> x = -c/b
|
39 |
+
return f"Linear equation: x = {-c/b}"
|
40 |
+
|
41 |
+
delta = b**2 - 4*a*c
|
42 |
+
|
43 |
+
if delta > 0:
|
44 |
+
x1 = (-b + delta**0.5) / (2*a)
|
45 |
+
x2 = (-b - delta**0.5) / (2*a)
|
46 |
+
return f"Two distinct real roots: x1 = {x1}, x2 = {x2}"
|
47 |
+
elif delta == 0:
|
48 |
+
x1 = -b / (2*a)
|
49 |
+
return f"One real root (repeated): x = {x1}"
|
50 |
+
else: # delta < 0
|
51 |
+
real_part = -b / (2*a)
|
52 |
+
imag_part = (-delta)**0.5 / (2*a)
|
53 |
+
return f"Two complex roots: x1 = {real_part} + {imag_part}i, x2 = {real_part} - {imag_part}i"
|
54 |
+
|
55 |
+
|
56 |
+
def simplify_radical(number: int) -> str:
|
57 |
+
"""
|
58 |
+
Simplifies a radical (square root) to its simplest form (e.g., sqrt(12) -> 2*sqrt(3)).
|
59 |
+
|
60 |
+
Args:
|
61 |
+
number: The number under the radical.
|
62 |
+
|
63 |
+
Returns:
|
64 |
+
A string representing the simplified radical.
|
65 |
+
"""
|
66 |
+
if not isinstance(number, int):
|
67 |
+
return "Input must be an integer."
|
68 |
+
if number < 0:
|
69 |
+
return "Cannot simplify the square root of a negative number with this function."
|
70 |
+
if number == 0:
|
71 |
+
return "0"
|
72 |
+
|
73 |
+
i = 2
|
74 |
+
factor = 1
|
75 |
+
remaining = number
|
76 |
+
while i * i <= remaining:
|
77 |
+
if remaining % (i * i) == 0:
|
78 |
+
factor *= i
|
79 |
+
remaining //= (i*i)
|
80 |
+
# Restart checking with the same i in case of factors like i^4, i^6 etc.
|
81 |
+
continue
|
82 |
+
i += 1
|
83 |
+
|
84 |
+
if factor == 1:
|
85 |
+
return f"sqrt({remaining})"
|
86 |
+
if remaining == 1:
|
87 |
+
return str(factor)
|
88 |
+
return f"{factor}*sqrt({remaining})"
|
89 |
+
|
90 |
+
|
91 |
+
def polynomial_operations(poly1_coeffs: list[float], poly2_coeffs: list[float], operation: str) -> str:
|
92 |
+
"""
|
93 |
+
Performs addition, subtraction, or multiplication of two polynomials.
|
94 |
+
Polynomials are represented by lists of coefficients in descending order of power.
|
95 |
+
Example: [1, -2, 3] represents x^2 - 2x + 3.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
poly1_coeffs: Coefficients of the first polynomial.
|
99 |
+
poly2_coeffs: Coefficients of the second polynomial.
|
100 |
+
operation: "add", "subtract", or "multiply".
|
101 |
+
|
102 |
+
Returns:
|
103 |
+
A string representing the resulting polynomial or an error message.
|
104 |
+
"""
|
105 |
+
if not all(isinstance(c, (int, float)) for c in poly1_coeffs) or \
|
106 |
+
not all(isinstance(c, (int, float)) for c in poly2_coeffs):
|
107 |
+
return "Error: All coefficients must be numbers."
|
108 |
+
|
109 |
+
if not poly1_coeffs:
|
110 |
+
poly1_coeffs = [0]
|
111 |
+
if not poly2_coeffs:
|
112 |
+
poly2_coeffs = [0]
|
113 |
+
|
114 |
+
op = operation.lower()
|
115 |
+
|
116 |
+
if op == "add":
|
117 |
+
len1, len2 = len(poly1_coeffs), len(poly2_coeffs)
|
118 |
+
max_len = max(len1, len2)
|
119 |
+
p1 = [0]*(max_len - len1) + poly1_coeffs
|
120 |
+
p2 = [0]*(max_len - len2) + poly2_coeffs
|
121 |
+
result_coeffs = [p1[i] + p2[i] for i in range(max_len)]
|
122 |
+
elif op == "subtract":
|
123 |
+
len1, len2 = len(poly1_coeffs), len(poly2_coeffs)
|
124 |
+
max_len = max(len1, len2)
|
125 |
+
p1 = [0]*(max_len - len1) + poly1_coeffs
|
126 |
+
p2 = [0]*(max_len - len2) + poly2_coeffs
|
127 |
+
result_coeffs = [p1[i] - p2[i] for i in range(max_len)]
|
128 |
+
elif op == "multiply":
|
129 |
+
len1, len2 = len(poly1_coeffs), len(poly2_coeffs)
|
130 |
+
result_coeffs = [0] * (len1 + len2 - 1)
|
131 |
+
for i in range(len1):
|
132 |
+
for j in range(len2):
|
133 |
+
result_coeffs[i+j] += poly1_coeffs[i] * poly2_coeffs[j]
|
134 |
+
else:
|
135 |
+
return "Error: Invalid operation. Choose 'add', 'subtract', or 'multiply'."
|
136 |
+
|
137 |
+
# Format the result string
|
138 |
+
if not result_coeffs or all(c == 0 for c in result_coeffs):
|
139 |
+
return "0"
|
140 |
+
|
141 |
+
terms = []
|
142 |
+
degree = len(result_coeffs) - 1
|
143 |
+
for i, coeff in enumerate(result_coeffs):
|
144 |
+
power = degree - i
|
145 |
+
if coeff == 0:
|
146 |
+
continue
|
147 |
+
|
148 |
+
term_coeff = ""
|
149 |
+
if coeff == 1 and power != 0:
|
150 |
+
term_coeff = ""
|
151 |
+
elif coeff == -1 and power != 0:
|
152 |
+
term_coeff = "-"
|
153 |
+
else:
|
154 |
+
term_coeff = str(coeff)
|
155 |
+
if isinstance(coeff, float) and coeff.is_integer():
|
156 |
+
term_coeff = str(int(coeff))
|
157 |
+
|
158 |
+
|
159 |
+
if power == 0:
|
160 |
+
terms.append(term_coeff)
|
161 |
+
elif power == 1:
|
162 |
+
terms.append(f"{term_coeff}x")
|
163 |
+
else:
|
164 |
+
terms.append(f"{term_coeff}x^{power}")
|
165 |
+
|
166 |
+
# Join terms, handling signs
|
167 |
+
result_str = terms[0]
|
168 |
+
for term in terms[1:]:
|
169 |
+
if term.startswith("-"):
|
170 |
+
result_str += f" - {term[1:]}"
|
171 |
+
else:
|
172 |
+
result_str += f" + {term}"
|
173 |
+
|
174 |
+
return result_str
|
maths/middleschool/algebra_interface.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1 |
import gradio as gr
|
2 |
-
from maths.middleschool.algebra import solve_linear_equation, evaluate_expression
|
3 |
|
4 |
# Middle School Math Tab
|
5 |
solve_linear_equation_interface = gr.Interface(
|
@@ -25,3 +25,45 @@ evaluate_expression_interface = gr.Interface(
|
|
25 |
title="Quadratic Expression Evaluator",
|
26 |
description="Evaluate ax² + bx + c for a given value of x"
|
27 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from maths.middleschool.algebra import solve_linear_equation, evaluate_expression, solve_quadratic, simplify_radical, polynomial_operations
|
3 |
|
4 |
# Middle School Math Tab
|
5 |
solve_linear_equation_interface = gr.Interface(
|
|
|
25 |
title="Quadratic Expression Evaluator",
|
26 |
description="Evaluate ax² + bx + c for a given value of x"
|
27 |
)
|
28 |
+
|
29 |
+
solve_quadratic_interface = gr.Interface(
|
30 |
+
fn=solve_quadratic,
|
31 |
+
inputs=[
|
32 |
+
gr.Number(label="a (coefficient of x²)"),
|
33 |
+
gr.Number(label="b (coefficient of x)"),
|
34 |
+
gr.Number(label="c (constant)")
|
35 |
+
],
|
36 |
+
outputs="text",
|
37 |
+
title="Quadratic Equation Solver",
|
38 |
+
description="Solve ax² + bx + c = 0"
|
39 |
+
)
|
40 |
+
|
41 |
+
simplify_radical_interface = gr.Interface(
|
42 |
+
fn=simplify_radical,
|
43 |
+
inputs=gr.Number(label="Number under radical", precision=0),
|
44 |
+
outputs="text",
|
45 |
+
title="Radical Simplifier",
|
46 |
+
description="Simplify a square root (e.g., √12 → 2√3)"
|
47 |
+
)
|
48 |
+
|
49 |
+
def parse_coeffs(coeff_str: str) -> list[float]:
|
50 |
+
"""Helper to parse comma-separated coefficients from string input."""
|
51 |
+
if not coeff_str.strip():
|
52 |
+
return [0.0]
|
53 |
+
try:
|
54 |
+
return [float(x.strip()) for x in coeff_str.split(',') if x.strip() != '']
|
55 |
+
except ValueError:
|
56 |
+
# Return a value that indicates error, Gradio will show the function's error message
|
57 |
+
raise gr.Error("Invalid coefficient input. Ensure all coefficients are numbers.")
|
58 |
+
|
59 |
+
polynomial_interface = gr.Interface(
|
60 |
+
fn=lambda p1_str, p2_str, op: polynomial_operations(parse_coeffs(p1_str), parse_coeffs(p2_str), op),
|
61 |
+
inputs=[
|
62 |
+
gr.Textbox(label="Polynomial 1 Coefficients (comma-separated, highest power first, e.g., 1,-2,3 for x²-2x+3)"),
|
63 |
+
gr.Textbox(label="Polynomial 2 Coefficients (comma-separated, e.g., 2,5 for 2x+5)"),
|
64 |
+
gr.Radio(choices=["add", "subtract", "multiply"], label="Operation")
|
65 |
+
],
|
66 |
+
outputs="text",
|
67 |
+
title="Polynomial Operations",
|
68 |
+
description="Add, subtract, or multiply two polynomials. Enter coefficients in descending order of power."
|
69 |
+
)
|
maths/university/__init__.py
CHANGED
@@ -1,3 +1,9 @@
|
|
1 |
"""
|
2 |
University level mathematics tools.
|
3 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"""
|
2 |
University level mathematics tools.
|
3 |
"""
|
4 |
+
|
5 |
+
from . import calculus
|
6 |
+
from . import linear_algebra
|
7 |
+
from . import linear_algebra_interface
|
8 |
+
from . import differential_equations
|
9 |
+
from . import differential_equations_interface
|
maths/university/calculus.py
CHANGED
@@ -1,6 +1,13 @@
|
|
1 |
"""
|
2 |
Calculus operations for university level.
|
3 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
def derivative_polynomial(coefficients):
|
6 |
"""
|
@@ -40,3 +47,219 @@ def integral_polynomial(coefficients, c=0):
|
|
40 |
|
41 |
result.append(c) # Add integration constant
|
42 |
return result
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
"""
|
2 |
Calculus operations for university level.
|
3 |
"""
|
4 |
+
import sympy
|
5 |
+
from sympy.parsing.mathematica import parse_mathematica # For robust expression parsing
|
6 |
+
import numpy as np # For numerical integration if needed
|
7 |
+
from typing import Callable, Union, List, Tuple
|
8 |
+
|
9 |
+
# Define common symbols for Sympy expressions
|
10 |
+
x, y, z, n = sympy.symbols('x y z n')
|
11 |
|
12 |
def derivative_polynomial(coefficients):
|
13 |
"""
|
|
|
47 |
|
48 |
result.append(c) # Add integration constant
|
49 |
return result
|
50 |
+
|
51 |
+
|
52 |
+
def calculate_limit(expression_str: str, variable_str: str, point_str: str, direction: str = '+-') -> str:
|
53 |
+
"""
|
54 |
+
Calculates the limit of an expression as a variable approaches a point.
|
55 |
+
|
56 |
+
Args:
|
57 |
+
expression_str: The mathematical expression as a string (e.g., "sin(x)/x").
|
58 |
+
variable_str: The variable in the expression (e.g., "x").
|
59 |
+
point_str: The point the variable is approaching (e.g., "0", "oo" for infinity).
|
60 |
+
direction: The direction of the limit ('+', '-', or '+-' for both sides).
|
61 |
+
|
62 |
+
Returns:
|
63 |
+
A string representing the limit or an error message.
|
64 |
+
"""
|
65 |
+
try:
|
66 |
+
var = sympy.Symbol(variable_str)
|
67 |
+
# Using parse_mathematica for more robust parsing than sympify alone
|
68 |
+
# It handles expressions like "x^2" or "sin(x)" more reliably.
|
69 |
+
# We need to ensure the variable is available in the parsing context.
|
70 |
+
local_dict = {variable_str: var}
|
71 |
+
expr = sympy.parse_expr(expression_str, local_dict=local_dict, transformations='all')
|
72 |
+
|
73 |
+
|
74 |
+
if point_str.lower() == 'oo':
|
75 |
+
point = sympy.oo
|
76 |
+
elif point_str.lower() == '-oo':
|
77 |
+
point = -sympy.oo
|
78 |
+
else:
|
79 |
+
point = sympy.sympify(point_str) # Can handle numbers or symbolic constants like pi
|
80 |
+
|
81 |
+
if direction == '+':
|
82 |
+
limit_val = sympy.limit(expr, var, point, dir='+')
|
83 |
+
elif direction == '-':
|
84 |
+
limit_val = sympy.limit(expr, var, point, dir='-')
|
85 |
+
else: # '+-' default
|
86 |
+
limit_val = sympy.limit(expr, var, point)
|
87 |
+
|
88 |
+
return str(limit_val)
|
89 |
+
except (sympy.SympifyError, TypeError, SyntaxError) as e:
|
90 |
+
return f"Error parsing expression or point: {e}. Ensure valid Sympy syntax (e.g. use 'oo' for infinity, ensure variables match)."
|
91 |
+
except Exception as e:
|
92 |
+
return f"An unexpected error occurred: {e}"
|
93 |
+
|
94 |
+
|
95 |
+
def taylor_series_expansion(expression_str: str, variable_str: str = 'x', point: float = 0, order: int = 5) -> str:
|
96 |
+
"""
|
97 |
+
Computes the Taylor series expansion of an expression around a point.
|
98 |
+
|
99 |
+
Args:
|
100 |
+
expression_str: The mathematical expression (e.g., "exp(x)").
|
101 |
+
variable_str: The variable (default 'x').
|
102 |
+
point: The point around which to expand (default 0).
|
103 |
+
order: The order of the Taylor polynomial (default 5).
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
String representation of the Taylor series or an error message.
|
107 |
+
"""
|
108 |
+
try:
|
109 |
+
var = sympy.Symbol(variable_str)
|
110 |
+
expr = sympy.parse_expr(expression_str, local_dict={variable_str: var}, transformations='all')
|
111 |
+
|
112 |
+
series = expr.series(var, x0=point, n=order).removeO() # removeO() removes the O(x^n) term
|
113 |
+
return str(series)
|
114 |
+
except Exception as e:
|
115 |
+
return f"Error calculating Taylor series: {e}"
|
116 |
+
|
117 |
+
|
118 |
+
def fourier_series_example(function_type: str = "sawtooth", n_terms: int = 5) -> str:
|
119 |
+
"""
|
120 |
+
Provides an example of a Fourier series for a predefined function.
|
121 |
+
|
122 |
+
Args:
|
123 |
+
function_type: "sawtooth" or "square" wave.
|
124 |
+
n_terms: Number of terms to compute in the series.
|
125 |
+
|
126 |
+
Returns:
|
127 |
+
String representation of the Fourier series.
|
128 |
+
"""
|
129 |
+
try:
|
130 |
+
L = sympy.pi # Periodicity, assuming 2L = 2*pi for standard examples
|
131 |
+
x = sympy.Symbol('x')
|
132 |
+
|
133 |
+
if function_type == "sawtooth":
|
134 |
+
# Sawtooth wave: f(x) = x for -pi < x < pi
|
135 |
+
# a0 = 0
|
136 |
+
# an = 0
|
137 |
+
# bn = 2/L * integral(x*sin(n*pi*x/L), (x, 0, L))
|
138 |
+
# = 2/pi * integral(x*sin(nx), (x, 0, pi))
|
139 |
+
# = 2/pi * [-x*cos(nx)/n + sin(nx)/n^2]_0^pi
|
140 |
+
# = 2/pi * [-pi*cos(n*pi)/n] = -2*(-1)^n / n
|
141 |
+
series = sympy.S(0) # Start with zero, as a0 is 0
|
142 |
+
for i in range(1, n_terms + 1):
|
143 |
+
bn_coeff = -2 * ((-1)**i) / i
|
144 |
+
series += bn_coeff * sympy.sin(i * x)
|
145 |
+
return f"Fourier series for sawtooth wave (f(x)=x on [-pi,pi]): {str(series)}"
|
146 |
+
|
147 |
+
elif function_type == "square":
|
148 |
+
# Square wave: f(x) = -1 for -pi < x < 0, f(x) = 1 for 0 < x < pi
|
149 |
+
# a0 = 0
|
150 |
+
# an = 0
|
151 |
+
# bn = 2/L * integral(f(x)*sin(n*pi*x/L), (x, 0, L)) where f(x)=1 for (0,L)
|
152 |
+
# = 2/pi * integral(sin(nx), (x, 0, pi))
|
153 |
+
# = 2/pi * [-cos(nx)/n]_0^pi
|
154 |
+
# = 2/(n*pi) * (1 - cos(n*pi)) = 2/(n*pi) * (1 - (-1)^n)
|
155 |
+
# This means bn is 4/(n*pi) if n is odd, and 0 if n is even.
|
156 |
+
series = sympy.S(0)
|
157 |
+
for i in range(1, n_terms + 1):
|
158 |
+
if i % 2 != 0: # n is odd
|
159 |
+
bn_coeff = 4 / (i * sympy.pi)
|
160 |
+
series += bn_coeff * sympy.sin(i * x)
|
161 |
+
return f"Fourier series for square wave (f(x)=1 on [0,pi], -1 on [-pi,0]): {str(series)}"
|
162 |
+
else:
|
163 |
+
return "Error: Unknown function type for Fourier series. Choose 'sawtooth' or 'square'."
|
164 |
+
except Exception as e:
|
165 |
+
return f"Error generating Fourier series: {e}"
|
166 |
+
|
167 |
+
|
168 |
+
def partial_derivative(expression_str: str, variables_str: List[str]) -> str:
|
169 |
+
"""
|
170 |
+
Computes partial derivatives of an expression with respect to specified variables.
|
171 |
+
Example: expression_str="x**2*y**3", variables_str=["x", "y"] will compute d/dx then d/dy.
|
172 |
+
If you want d^2f/dx^2, use variables_str=["x", "x"].
|
173 |
+
|
174 |
+
Args:
|
175 |
+
expression_str: The mathematical expression (e.g., "x**2*y + y*z**2").
|
176 |
+
variables_str: A list of variable names (strings) to differentiate by, in order.
|
177 |
+
(e.g., ["x", "y"] for d/dy(d/dx(expr)) ).
|
178 |
+
|
179 |
+
Returns:
|
180 |
+
String representation of the partial derivative or an error message.
|
181 |
+
"""
|
182 |
+
try:
|
183 |
+
# Create symbols for all variables that might appear in the expression
|
184 |
+
# For simplicity, we'll assume 'x', 'y', 'z' are common, but user must specify which to derive by.
|
185 |
+
# A more robust way would be to parse expression_str to find all symbols.
|
186 |
+
symbols_in_expr = {s.name: s for s in sympy.parse_expr(expression_str, transformations='all').free_symbols}
|
187 |
+
|
188 |
+
# Ensure variables to differentiate by are symbols
|
189 |
+
diff_vars_symbols = []
|
190 |
+
for var_name in variables_str:
|
191 |
+
if var_name in symbols_in_expr:
|
192 |
+
diff_vars_symbols.append(symbols_in_expr[var_name])
|
193 |
+
else:
|
194 |
+
# If a variable to differentiate by is not in free_symbols, it might be an error
|
195 |
+
# or the expression doesn't depend on it. Sympy handles differentiation by non-present vars as 0.
|
196 |
+
# We'll create the symbol anyway to pass to diff.
|
197 |
+
diff_vars_symbols.append(sympy.Symbol(var_name))
|
198 |
+
|
199 |
+
if not diff_vars_symbols:
|
200 |
+
return "Error: No variables specified for differentiation."
|
201 |
+
|
202 |
+
expr = sympy.parse_expr(expression_str, local_dict=symbols_in_expr, transformations='all')
|
203 |
+
|
204 |
+
# Compute partial derivatives iteratively
|
205 |
+
current_expr = expr
|
206 |
+
for var_sym in diff_vars_symbols:
|
207 |
+
current_expr = sympy.diff(current_expr, var_sym)
|
208 |
+
|
209 |
+
return str(current_expr)
|
210 |
+
except (sympy.SympifyError, TypeError, SyntaxError) as e:
|
211 |
+
return f"Error parsing expression or variables: {e}"
|
212 |
+
except Exception as e:
|
213 |
+
return f"An unexpected error occurred: {e}"
|
214 |
+
|
215 |
+
|
216 |
+
def multiple_integral(expression_str: str, integration_vars: List[Tuple[str, Union[str, float], Union[str, float]]]) -> str:
|
217 |
+
"""
|
218 |
+
Computes definite multiple integrals.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
expression_str: The mathematical expression (e.g., "x*y**2").
|
222 |
+
integration_vars: A list of tuples, where each tuple contains:
|
223 |
+
(variable_name_str, lower_bound_str_or_float, upper_bound_str_or_float).
|
224 |
+
Example: [("x", "0", "1"), ("y", "0", "x")] for integral from 0 to 1 dx of integral from 0 to x dy of (x*y**2).
|
225 |
+
The order in the list is the order of integration (inner to outer).
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
String representation of the integral result or an error message.
|
229 |
+
"""
|
230 |
+
try:
|
231 |
+
# Create symbols for all variables that might appear in the expression
|
232 |
+
# and in integration bounds.
|
233 |
+
# A more robust approach might involve parsing the expression and bounds to identify all symbols.
|
234 |
+
# For now, we'll default to x, y, z if not explicitly mentioned.
|
235 |
+
present_symbols = {sym.name: sym for sym in sympy.parse_expr(expression_str, transformations='all').free_symbols}
|
236 |
+
|
237 |
+
integration_params_sympy = []
|
238 |
+
for var_name, lower_bound, upper_bound in integration_vars:
|
239 |
+
var_sym = present_symbols.get(var_name, sympy.Symbol(var_name))
|
240 |
+
if var_sym.name not in present_symbols: # Add if it wasn't in expression but is an integration var
|
241 |
+
present_symbols[var_sym.name] = var_sym
|
242 |
+
|
243 |
+
# Bounds can be numbers or expressions involving other variables
|
244 |
+
# We need to ensure these other variables are available in the context for sympify
|
245 |
+
lower_s = sympy.sympify(lower_bound, locals=present_symbols)
|
246 |
+
upper_s = sympy.sympify(upper_bound, locals=present_symbols)
|
247 |
+
integration_params_sympy.append((var_sym, lower_s, upper_s))
|
248 |
+
|
249 |
+
if not integration_params_sympy:
|
250 |
+
return "Error: No integration variables specified."
|
251 |
+
|
252 |
+
expr = sympy.parse_expr(expression_str, local_dict=present_symbols, transformations='all')
|
253 |
+
|
254 |
+
# Compute integral iteratively (from inner to outer)
|
255 |
+
# Sympy's integrate function takes tuples like (x, 0, 1)
|
256 |
+
# The order of integration_params_sympy should be from inner-most integral to outer-most.
|
257 |
+
# For sympy.integrate, the order of variables in the list is outer to inner. So we reverse.
|
258 |
+
|
259 |
+
integral_val = sympy.integrate(expr, *integration_params_sympy)
|
260 |
+
|
261 |
+
return str(integral_val)
|
262 |
+
except (sympy.SympifyError, TypeError, SyntaxError) as e:
|
263 |
+
return f"Error parsing expression, bounds, or variables: {e}. Ensure bounds are numbers or valid expressions."
|
264 |
+
except Exception as e:
|
265 |
+
return f"An unexpected error occurred during integration: {e}"
|
maths/university/calculus_interface.py
CHANGED
@@ -1,5 +1,10 @@
|
|
1 |
import gradio as gr
|
2 |
-
from maths.university.calculus import
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
# University Math Tab
|
5 |
derivative_interface = gr.Interface(
|
@@ -20,3 +25,101 @@ integral_interface = gr.Interface(
|
|
20 |
title="Polynomial Integration",
|
21 |
description="Find the indefinite integral of a polynomial function"
|
22 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from maths.university.calculus import (
|
3 |
+
derivative_polynomial, integral_polynomial,
|
4 |
+
calculate_limit, taylor_series_expansion, fourier_series_example,
|
5 |
+
partial_derivative, multiple_integral
|
6 |
+
)
|
7 |
+
import json # For parsing list of tuples for multiple integrals
|
8 |
|
9 |
# University Math Tab
|
10 |
derivative_interface = gr.Interface(
|
|
|
25 |
title="Polynomial Integration",
|
26 |
description="Find the indefinite integral of a polynomial function"
|
27 |
)
|
28 |
+
|
29 |
+
limit_interface = gr.Interface(
|
30 |
+
fn=calculate_limit,
|
31 |
+
inputs=[
|
32 |
+
gr.Textbox(label="Expression (e.g., sin(x)/x, x**2, exp(-x))", value="sin(x)/x"),
|
33 |
+
gr.Textbox(label="Variable (e.g., x)", value="x"),
|
34 |
+
gr.Textbox(label="Point (e.g., 0, 1, oo, -oo, pi)", value="0"),
|
35 |
+
gr.Radio(choices=["+-", "+", "-"], label="Direction", value="+-")
|
36 |
+
],
|
37 |
+
outputs="text",
|
38 |
+
title="Limit Calculator",
|
39 |
+
description="Calculate the limit of an expression. Use 'oo' for infinity. Ensure variable in expression matches variable input."
|
40 |
+
)
|
41 |
+
|
42 |
+
taylor_series_interface = gr.Interface(
|
43 |
+
fn=taylor_series_expansion,
|
44 |
+
inputs=[
|
45 |
+
gr.Textbox(label="Expression (e.g., exp(x), cos(x))", value="exp(x)"),
|
46 |
+
gr.Textbox(label="Variable (e.g., x)", value="x"),
|
47 |
+
gr.Number(label="Point of Expansion (x0)", value=0),
|
48 |
+
gr.Number(label="Order of Taylor Polynomial", value=5, precision=0)
|
49 |
+
],
|
50 |
+
outputs="text",
|
51 |
+
title="Taylor Series Expansion",
|
52 |
+
description="Compute the Taylor series of a function around a point."
|
53 |
+
)
|
54 |
+
|
55 |
+
fourier_series_interface = gr.Interface(
|
56 |
+
fn=fourier_series_example,
|
57 |
+
inputs=[
|
58 |
+
gr.Radio(choices=["sawtooth", "square"], label="Function Type", value="sawtooth"),
|
59 |
+
gr.Number(label="Number of Terms (n)", value=5, precision=0)
|
60 |
+
],
|
61 |
+
outputs="text",
|
62 |
+
title="Fourier Series Examples",
|
63 |
+
description="Generate Fourier series for predefined periodic functions (e.g., sawtooth wave, square wave)."
|
64 |
+
)
|
65 |
+
|
66 |
+
partial_derivative_interface = gr.Interface(
|
67 |
+
fn=lambda expr_str, vars_str: partial_derivative(expr_str, [v.strip() for v in vars_str.split(',') if v.strip()]),
|
68 |
+
inputs=[
|
69 |
+
gr.Textbox(label="Expression (e.g., x**2*y + z*sin(x))", value="x**2*y**3 + y*sin(z)"),
|
70 |
+
gr.Textbox(label="Variables to differentiate by (comma-separated, in order, e.g., x,y or x,x for d²f/dx²)", value="x,y")
|
71 |
+
],
|
72 |
+
outputs="text",
|
73 |
+
title="Partial Derivative Calculator",
|
74 |
+
description="Compute partial derivatives. For d/dy(d/dx(f)), enter 'x,y'. For d²f/dx², enter 'x,x'."
|
75 |
+
)
|
76 |
+
|
77 |
+
|
78 |
+
def parse_integration_variables(vars_json_str: str):
|
79 |
+
"""
|
80 |
+
Parses a JSON string representing a list of integration variables and their bounds.
|
81 |
+
Expected format: '[["var1", "lower1", "upper1"], ["var2", "lower2", "upper2"]]'
|
82 |
+
"""
|
83 |
+
try:
|
84 |
+
parsed_list = json.loads(vars_json_str)
|
85 |
+
if not isinstance(parsed_list, list):
|
86 |
+
raise ValueError("Input must be a JSON list.")
|
87 |
+
|
88 |
+
integration_vars = []
|
89 |
+
for item in parsed_list:
|
90 |
+
if not (isinstance(item, list) and len(item) == 3):
|
91 |
+
raise ValueError("Each item in the list must be a sub-list of three elements: [variable_name, lower_bound, upper_bound].")
|
92 |
+
|
93 |
+
var, low, upp = item
|
94 |
+
if not isinstance(var, str):
|
95 |
+
raise ValueError(f"Variable name '{var}' must be a string.")
|
96 |
+
|
97 |
+
# Bounds can be numbers or strings (for symbolic bounds like 'x')
|
98 |
+
if not (isinstance(low, (str, int, float))) or not (isinstance(upp, (str, int, float))):
|
99 |
+
raise ValueError(f"Bounds for variable '{var}' must be numbers or strings. Got '{low}' and '{upp}'.")
|
100 |
+
|
101 |
+
integration_vars.append((str(var), str(low), str(upp))) # Ensure all elements are strings for sympy processing later if needed
|
102 |
+
|
103 |
+
return integration_vars
|
104 |
+
except json.JSONDecodeError:
|
105 |
+
raise gr.Error("Invalid JSON format for integration variables.")
|
106 |
+
except ValueError as ve:
|
107 |
+
raise gr.Error(str(ve))
|
108 |
+
except Exception as e:
|
109 |
+
raise gr.Error(f"Unexpected error parsing integration variables: {str(e)}")
|
110 |
+
|
111 |
+
|
112 |
+
multiple_integral_interface = gr.Interface(
|
113 |
+
fn=lambda expr_str, vars_json_str: multiple_integral(expr_str, parse_integration_variables(vars_json_str)),
|
114 |
+
inputs=[
|
115 |
+
gr.Textbox(label="Expression (e.g., x*y**2, 1)", value="x*y"),
|
116 |
+
gr.Textbox(
|
117 |
+
label="Integration Variables and Bounds (JSON list of [var, low, upp] tuples, inner to outer)",
|
118 |
+
value='[["y", "0", "x"], ["x", "0", "1"]]' ,
|
119 |
+
info="Example: integral_0^1 dx integral_0^x dy (x*y) is [['y', '0', 'x'], ['x', '0', '1']]"
|
120 |
+
)
|
121 |
+
],
|
122 |
+
outputs="text",
|
123 |
+
title="Definite Multiple Integral Calculator",
|
124 |
+
description="Compute multiple integrals. Order of variables in list: inner-most integral first."
|
125 |
+
)
|
maths/university/differential_equations.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Differential Equations solvers for university level, using SciPy.
|
3 |
+
"""
|
4 |
+
import numpy as np
|
5 |
+
from scipy.integrate import solve_ivp
|
6 |
+
from typing import Callable, List, Tuple, Dict, Any, Union
|
7 |
+
import matplotlib.pyplot as plt
|
8 |
+
|
9 |
+
# Type hint for the function defining the ODE system
|
10 |
+
ODEFunc = Callable[[float, Union[np.ndarray, List[float]]], Union[np.ndarray, List[float]]]
|
11 |
+
|
12 |
+
def solve_first_order_ode(
|
13 |
+
ode_func: ODEFunc,
|
14 |
+
t_span: Tuple[float, float],
|
15 |
+
y0: List[float],
|
16 |
+
t_eval_count: int = 100,
|
17 |
+
method: str = 'RK45',
|
18 |
+
**kwargs: Any
|
19 |
+
) -> Dict[str, Union[np.ndarray, str, bool]]:
|
20 |
+
"""
|
21 |
+
Solves a first-order ordinary differential equation (or system of first-order ODEs)
|
22 |
+
dy/dt = f(t, y) with initial condition y(t0) = y0.
|
23 |
+
|
24 |
+
Args:
|
25 |
+
ode_func: Callable f(t, y). `t` is a scalar, `y` is an ndarray of shape (n,).
|
26 |
+
It should return an array-like object of shape (n,).
|
27 |
+
Example for dy/dt = -2*y: lambda t, y: -2 * y
|
28 |
+
Example for system dy1/dt = y2, dy2/dt = -y1: lambda t, y: [y[1], -y[0]]
|
29 |
+
t_span: Tuple (t_start, t_end) for the integration interval.
|
30 |
+
y0: List or NumPy array of initial conditions y(t_start).
|
31 |
+
t_eval_count: Number of points at which to store the computed solution.
|
32 |
+
method: Integration method to use (e.g., 'RK45', 'LSODA', 'BDF').
|
33 |
+
**kwargs: Additional keyword arguments to pass to `solve_ivp` (e.g., `rtol`, `atol`).
|
34 |
+
|
35 |
+
Returns:
|
36 |
+
A dictionary containing:
|
37 |
+
't': NumPy array of time points.
|
38 |
+
'y': NumPy array of solution values at each time point. (y[i] corresponds to t[i])
|
39 |
+
For systems, y is an array with n_equations rows and n_timepoints columns.
|
40 |
+
'message': Solver message.
|
41 |
+
'success': Boolean indicating if the solver was successful.
|
42 |
+
'plot_path': Path to a saved plot of the solution (if successful), else None.
|
43 |
+
"""
|
44 |
+
try:
|
45 |
+
y0_np = np.array(y0, dtype=float)
|
46 |
+
t_eval = np.linspace(t_span[0], t_span[1], t_eval_count)
|
47 |
+
|
48 |
+
sol = solve_ivp(ode_func, t_span, y0_np, method=method, t_eval=t_eval, **kwargs)
|
49 |
+
|
50 |
+
plot_path = None
|
51 |
+
if sol.success:
|
52 |
+
try:
|
53 |
+
plt.figure(figsize=(10, 6))
|
54 |
+
if y0_np.ndim == 0 or len(y0_np) == 1 : # Single equation
|
55 |
+
plt.plot(sol.t, sol.y[0], label=f'y(t), y0={y0_np[0] if y0_np.ndim > 0 else y0_np}')
|
56 |
+
else: # System of equations
|
57 |
+
for i in range(sol.y.shape[0]):
|
58 |
+
plt.plot(sol.t, sol.y[i], label=f'y_{i+1}(t), y0_{i+1}={y0_np[i]}')
|
59 |
+
plt.xlabel("Time (t)")
|
60 |
+
plt.ylabel("Solution y(t)")
|
61 |
+
plt.title(f"Solution of First-Order ODE ({method})")
|
62 |
+
plt.legend()
|
63 |
+
plt.grid(True)
|
64 |
+
plot_path = "ode_solution_plot.png"
|
65 |
+
plt.savefig(plot_path)
|
66 |
+
plt.close() # Close the plot to free memory
|
67 |
+
except Exception as e_plot:
|
68 |
+
print(f"Warning: Could not generate plot: {e_plot}")
|
69 |
+
plot_path = None
|
70 |
+
|
71 |
+
|
72 |
+
return {
|
73 |
+
't': sol.t,
|
74 |
+
'y': sol.y,
|
75 |
+
'message': sol.message,
|
76 |
+
'success': sol.success,
|
77 |
+
'plot_path': plot_path
|
78 |
+
}
|
79 |
+
except Exception as e:
|
80 |
+
return {
|
81 |
+
't': np.array([]),
|
82 |
+
'y': np.array([]),
|
83 |
+
'message': f"Error during ODE solving: {str(e)}",
|
84 |
+
'success': False,
|
85 |
+
'plot_path': None
|
86 |
+
}
|
87 |
+
|
88 |
+
def solve_second_order_ode(
|
89 |
+
ode_func_second_order: Callable[[float, float, float], float],
|
90 |
+
t_span: Tuple[float, float],
|
91 |
+
y0: float,
|
92 |
+
dy0_dt: float,
|
93 |
+
t_eval_count: int = 100,
|
94 |
+
method: str = 'RK45',
|
95 |
+
**kwargs: Any
|
96 |
+
) -> Dict[str, Union[np.ndarray, str, bool]]:
|
97 |
+
"""
|
98 |
+
Solves a single second-order ordinary differential equation of the form
|
99 |
+
d²y/dt² = f(t, y, dy/dt) with initial conditions y(t0)=y0 and dy/dt(t0)=dy0_dt.
|
100 |
+
|
101 |
+
This is done by converting the second-order ODE into a system of two first-order ODEs:
|
102 |
+
Let z1 = y and z2 = dy/dt.
|
103 |
+
Then dz1/dt = z2
|
104 |
+
And dz2/dt = d²y/dt² = f(t, z1, z2)
|
105 |
+
|
106 |
+
Args:
|
107 |
+
ode_func_second_order: Callable f(t, y, dy_dt). `t` is scalar, `y` is scalar, `dy_dt` is scalar.
|
108 |
+
It should return the value of d²y/dt².
|
109 |
+
Example for d²y/dt² = -0.1*(dy/dt) - y: lambda t, y, dy_dt: -0.1 * dy_dt - y
|
110 |
+
t_span: Tuple (t_start, t_end) for the integration interval.
|
111 |
+
y0: Initial value of y at t_start.
|
112 |
+
dy0_dt: Initial value of dy/dt at t_start.
|
113 |
+
t_eval_count: Number of points at which to store the computed solution.
|
114 |
+
method: Integration method to use.
|
115 |
+
**kwargs: Additional keyword arguments for `solve_ivp`.
|
116 |
+
|
117 |
+
Returns:
|
118 |
+
A dictionary containing:
|
119 |
+
't': NumPy array of time points.
|
120 |
+
'y': NumPy array of solution values y(t) at each time point.
|
121 |
+
'dy_dt': NumPy array of solution values dy/dt(t) at each time point.
|
122 |
+
'message': Solver message.
|
123 |
+
'success': Boolean indicating if the solver was successful.
|
124 |
+
'plot_path': Path to a saved plot of y(t) and dy/dt(t) (if successful), else None.
|
125 |
+
"""
|
126 |
+
# Define the system of first-order ODEs
|
127 |
+
def system_func(t: float, z: np.ndarray) -> List[float]:
|
128 |
+
y_val, dy_dt_val = z[0], z[1]
|
129 |
+
d2y_dt2_val = ode_func_second_order(t, y_val, dy_dt_val)
|
130 |
+
return [dy_dt_val, d2y_dt2_val]
|
131 |
+
|
132 |
+
initial_conditions_system = [y0, dy0_dt]
|
133 |
+
|
134 |
+
try:
|
135 |
+
t_eval = np.linspace(t_span[0], t_span[1], t_eval_count)
|
136 |
+
sol = solve_ivp(system_func, t_span, initial_conditions_system, method=method, t_eval=t_eval, **kwargs)
|
137 |
+
|
138 |
+
plot_path = None
|
139 |
+
if sol.success:
|
140 |
+
try:
|
141 |
+
plt.figure(figsize=(12, 7))
|
142 |
+
|
143 |
+
plt.subplot(2,1,1)
|
144 |
+
plt.plot(sol.t, sol.y[0], label=f'y(t), y0={y0}')
|
145 |
+
plt.xlabel("Time (t)")
|
146 |
+
plt.ylabel("y(t)")
|
147 |
+
plt.title(f"Solution of Second-Order ODE: y(t) ({method})")
|
148 |
+
plt.legend()
|
149 |
+
plt.grid(True)
|
150 |
+
|
151 |
+
plt.subplot(2,1,2)
|
152 |
+
plt.plot(sol.t, sol.y[1], label=f'dy/dt(t), dy0/dt={dy0_dt}', color='orange')
|
153 |
+
plt.xlabel("Time (t)")
|
154 |
+
plt.ylabel("dy/dt(t)")
|
155 |
+
plt.title(f"Solution of Second-Order ODE: dy/dt(t) ({method})")
|
156 |
+
plt.legend()
|
157 |
+
plt.grid(True)
|
158 |
+
|
159 |
+
plt.tight_layout()
|
160 |
+
plot_path = "ode_second_order_solution_plot.png"
|
161 |
+
plt.savefig(plot_path)
|
162 |
+
plt.close()
|
163 |
+
except Exception as e_plot:
|
164 |
+
print(f"Warning: Could not generate plot: {e_plot}")
|
165 |
+
plot_path = None
|
166 |
+
|
167 |
+
return {
|
168 |
+
't': sol.t,
|
169 |
+
'y': sol.y[0], # First component of the system's solution
|
170 |
+
'dy_dt': sol.y[1], # Second component of the system's solution
|
171 |
+
'message': sol.message,
|
172 |
+
'success': sol.success,
|
173 |
+
'plot_path': plot_path
|
174 |
+
}
|
175 |
+
except Exception as e:
|
176 |
+
return {
|
177 |
+
't': np.array([]),
|
178 |
+
'y': np.array([]),
|
179 |
+
'dy_dt': np.array([]),
|
180 |
+
'message': f"Error during ODE solving: {str(e)}",
|
181 |
+
'success': False,
|
182 |
+
'plot_path': None
|
183 |
+
}
|
184 |
+
|
185 |
+
# Example Usage (can be removed or commented out)
|
186 |
+
if __name__ == '__main__':
|
187 |
+
# --- First-order ODE example: dy/dt = -y*t with y(0)=1 ---
|
188 |
+
def first_order_example(t, y):
|
189 |
+
return -y * t
|
190 |
+
|
191 |
+
print("Solving dy/dt = -y*t, y(0)=1 from t=0 to t=5")
|
192 |
+
solution1 = solve_first_order_ode(first_order_example, (0, 5), [1], t_eval_count=50)
|
193 |
+
if solution1['success']:
|
194 |
+
print(f"First-order ODE solved. Message: {solution1['message']}")
|
195 |
+
# print("t:", solution1['t'])
|
196 |
+
# print("y:", solution1['y'])
|
197 |
+
if solution1['plot_path']:
|
198 |
+
print(f"Plot saved to {solution1['plot_path']}")
|
199 |
+
else:
|
200 |
+
print(f"First-order ODE failed. Message: {solution1['message']}")
|
201 |
+
|
202 |
+
# --- First-order system example: Lotka-Volterra ---
|
203 |
+
# dy1/dt = a*y1 - b*y1*y2 (prey)
|
204 |
+
# dy2/dt = c*y1*y2 - d*y2 (predator)
|
205 |
+
a, b, c, d = 1.5, 0.8, 0.5, 0.9
|
206 |
+
def lotka_volterra(t, y):
|
207 |
+
prey, predator = y[0], y[1]
|
208 |
+
d_prey_dt = a * prey - b * prey * predator
|
209 |
+
d_predator_dt = c * prey * predator - d * predator
|
210 |
+
return [d_prey_dt, d_predator_dt]
|
211 |
+
|
212 |
+
print("\nSolving Lotka-Volterra system from t=0 to t=20 with y0=[10, 5]")
|
213 |
+
solution_lv = solve_first_order_ode(lotka_volterra, (0, 20), [10, 5], t_eval_count=200)
|
214 |
+
if solution_lv['success']:
|
215 |
+
print(f"Lotka-Volterra solved. Plot saved to {solution_lv['plot_path']}")
|
216 |
+
else:
|
217 |
+
print(f"Lotka-Volterra failed. Message: {solution_lv['message']}")
|
218 |
+
|
219 |
+
|
220 |
+
# --- Second-order ODE example: d²y/dt² = -sin(y) (simple pendulum) ---
|
221 |
+
# y is theta, dy/dt is omega. d²y/dt² = -g/L * sin(y)
|
222 |
+
g_L = 9.81 / 1.0 # Example: g/L = 9.81
|
223 |
+
def pendulum_ode(t, y_angle, dy_dt_angular_velocity):
|
224 |
+
return -g_L * np.sin(y_angle)
|
225 |
+
|
226 |
+
print("\nSolving d²y/dt² = -g/L*sin(y), y(0)=pi/4, dy/dt(0)=0 from t=0 to t=10")
|
227 |
+
solution2 = solve_second_order_ode(pendulum_ode, (0, 10), y0=np.pi/4, dy0_dt=0, t_eval_count=100)
|
228 |
+
if solution2['success']:
|
229 |
+
print(f"Second-order ODE solved. Message: {solution2['message']}")
|
230 |
+
# print("t:", solution2['t'])
|
231 |
+
# print("y:", solution2['y'])
|
232 |
+
# print("dy/dt:", solution2['dy_dt'])
|
233 |
+
if solution2['plot_path']:
|
234 |
+
print(f"Plot saved to {solution2['plot_path']}")
|
235 |
+
else:
|
236 |
+
print(f"Second-order ODE failed. Message: {solution2['message']}")
|
maths/university/differential_equations_interface.py
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Gradio Interface for Differential Equations solvers.
|
3 |
+
|
4 |
+
SECURITY WARNING: These interfaces use eval() to parse user-provided Python lambda strings
|
5 |
+
for defining ODE functions. This is a potential security risk if arbitrary code is entered.
|
6 |
+
Use with caution and only with trusted inputs, especially if deploying this application
|
7 |
+
outside a controlled environment.
|
8 |
+
"""
|
9 |
+
import gradio as gr
|
10 |
+
import numpy as np
|
11 |
+
import matplotlib.pyplot as plt # For displaying plots if not returned directly by solver
|
12 |
+
from typing import Callable, Tuple, List, Dict, Any, Union
|
13 |
+
import ast # For safer string literal evaluation if needed for parameters
|
14 |
+
import math # To make math functions available in eval scope for ODEs
|
15 |
+
|
16 |
+
# Import the solver functions
|
17 |
+
from maths.university.differential_equations import solve_first_order_ode, solve_second_order_ode
|
18 |
+
|
19 |
+
# --- Helper Functions ---
|
20 |
+
|
21 |
+
def parse_float_list(input_str: str, expected_len: int = 0) -> List[float]:
|
22 |
+
"""Parses a comma-separated string of floats into a list."""
|
23 |
+
try:
|
24 |
+
if not input_str.strip():
|
25 |
+
if expected_len > 0: # If expecting specific number, empty is error
|
26 |
+
raise ValueError("Input string is empty.")
|
27 |
+
return [] # Allow empty list if not expecting specific length
|
28 |
+
|
29 |
+
parts = [float(p.strip()) for p in input_str.split(',') if p.strip()]
|
30 |
+
if expected_len > 0 and len(parts) != expected_len:
|
31 |
+
raise ValueError(f"Expected {expected_len} values, but got {len(parts)}.")
|
32 |
+
return parts
|
33 |
+
except ValueError as e:
|
34 |
+
raise gr.Error(f"Invalid format for list of numbers. Use comma-separated floats. Error: {e}")
|
35 |
+
|
36 |
+
def parse_time_span(time_span_str: str) -> Tuple[float, float]:
|
37 |
+
"""Parses a comma-separated string for time span (t_start, t_end)."""
|
38 |
+
parts = parse_float_list(time_span_str, expected_len=2)
|
39 |
+
if parts[0] >= parts[1]:
|
40 |
+
raise gr.Error("t_start must be less than t_end for the time span.")
|
41 |
+
return (parts[0], parts[1])
|
42 |
+
|
43 |
+
def string_to_ode_func(lambda_str: str, expected_args: Tuple[str, ...]) -> Callable:
|
44 |
+
"""
|
45 |
+
Converts a string representation of a Python lambda function into a callable.
|
46 |
+
Includes a basic check for 'lambda' keyword and argument count.
|
47 |
+
|
48 |
+
Args:
|
49 |
+
lambda_str: The string, e.g., "lambda t, y: -2*y" or "lambda t, y, dy_dt: -0.1*dy_dt -y".
|
50 |
+
expected_args: A tuple of expected argument names, e.g., ('t', 'y') or ('t', 'y', 'dy_dt').
|
51 |
+
|
52 |
+
Returns:
|
53 |
+
A callable function.
|
54 |
+
Raises:
|
55 |
+
gr.Error: If the string is not a valid lambda or has wrong argument structure.
|
56 |
+
"""
|
57 |
+
lambda_str = lambda_str.strip()
|
58 |
+
if not lambda_str.startswith("lambda"):
|
59 |
+
raise gr.Error("ODE function must be a Python lambda string (e.g., 'lambda t, y: -y').")
|
60 |
+
|
61 |
+
try:
|
62 |
+
# Basic check for argument names within the lambda definition part (before ':')
|
63 |
+
# This is a heuristic and not a full AST parse for safety here, as eval is the main concern.
|
64 |
+
lambda_def_part = lambda_str.split(":")[0]
|
65 |
+
for arg_name in expected_args:
|
66 |
+
if arg_name not in lambda_def_part:
|
67 |
+
raise gr.Error(f"Lambda function string does not seem to contain expected argument: '{arg_name}'. Expected args: {expected_args}")
|
68 |
+
|
69 |
+
# The eval environment will have access to common math functions and numpy (np)
|
70 |
+
# This is where the security risk lies.
|
71 |
+
safe_eval_globals = {
|
72 |
+
"np": np,
|
73 |
+
"math": math,
|
74 |
+
"sin": math.sin, "cos": math.cos, "tan": math.tan,
|
75 |
+
"exp": math.exp, "log": math.log, "log10": math.log10,
|
76 |
+
"sqrt": math.sqrt, "fabs": math.fabs, "pow": math.pow,
|
77 |
+
"pi": math.pi, "e": math.e
|
78 |
+
}
|
79 |
+
# User must use 'math.func' or 'np.func' for most things not listed above.
|
80 |
+
|
81 |
+
func = eval(lambda_str, safe_eval_globals, {})
|
82 |
+
if not callable(func):
|
83 |
+
raise gr.Error("The provided string did not evaluate to a callable function.")
|
84 |
+
return func
|
85 |
+
except SyntaxError as se:
|
86 |
+
raise gr.Error(f"Syntax error in lambda function: {se}. Ensure it's valid Python syntax.")
|
87 |
+
except Exception as e:
|
88 |
+
raise gr.Error(f"Error evaluating lambda function string: {e}. Ensure it's a valid lambda returning the derivative(s).")
|
89 |
+
|
90 |
+
|
91 |
+
# --- Gradio Interface for First-Order ODEs ---
|
92 |
+
first_order_ode_interface = gr.Interface(
|
93 |
+
fn=lambda ode_str, t_span_str, y0_str, t_eval_count, method: solve_first_order_ode(
|
94 |
+
string_to_ode_func(ode_str, ('t', 'y')),
|
95 |
+
parse_time_span(t_span_str),
|
96 |
+
parse_float_list(y0_str), # y0 can be a list for systems
|
97 |
+
int(t_eval_count),
|
98 |
+
method
|
99 |
+
),
|
100 |
+
inputs=[
|
101 |
+
gr.Textbox(label="ODE Function (lambda t, y: ...)",
|
102 |
+
placeholder="e.g., lambda t, y: -y*t OR for system lambda t, y: [y[1], -0.1*y[1] - y[0]]",
|
103 |
+
info="Define dy/dt or a system [dy1/dt, dy2/dt,...]. `y` is a list/array for systems."),
|
104 |
+
gr.Textbox(label="Time Span (t_start, t_end)", placeholder="e.g., 0,10"),
|
105 |
+
gr.Textbox(label="Initial Condition(s) y(t_start)", placeholder="e.g., 1 OR for system 1,0"),
|
106 |
+
gr.Slider(minimum=10, maximum=1000, value=100, step=10, label="Evaluation Points Count"),
|
107 |
+
gr.Radio(choices=['RK45', 'LSODA', 'BDF', 'RK23', 'DOP853'], value='RK45', label="Solver Method")
|
108 |
+
],
|
109 |
+
outputs=[
|
110 |
+
gr.Image(label="Solution Plot", type="filepath", show_label=True, visible=lambda res: res['success'] and res['plot_path'] is not None),
|
111 |
+
gr.Textbox(label="Solver Message"),
|
112 |
+
gr.Textbox(label="Success Status"),
|
113 |
+
gr.JSON(label="Raw Data (t, y values)", visible=lambda res: res['success']) # For users to copy if needed
|
114 |
+
],
|
115 |
+
title="First-Order ODE Solver",
|
116 |
+
description="Solves dy/dt = f(t,y) or a system of first-order ODEs. " \
|
117 |
+
"WARNING: Uses eval() for the ODE function string - potential security risk. " \
|
118 |
+
"For systems, `y` in lambda is `[y1, y2, ...]`, return `[dy1/dt, dy2/dt, ...]`. " \
|
119 |
+
"Example (Damped Oscillator): ODE: lambda t, y: [y[1], -0.5*y[1] - y[0]], y0: 1,0, Timespan: 0,20",
|
120 |
+
allow_flagging="never"
|
121 |
+
)
|
122 |
+
|
123 |
+
# --- Gradio Interface for Second-Order ODEs ---
|
124 |
+
second_order_ode_interface = gr.Interface(
|
125 |
+
fn=lambda ode_str, t_span_str, y0_val_str, dy0_dt_val_str, t_eval_count, method: solve_second_order_ode(
|
126 |
+
string_to_ode_func(ode_str, ('t', 'y', 'dy_dt')), # Note: dy_dt is one variable name here
|
127 |
+
parse_time_span(t_span_str),
|
128 |
+
parse_float_list(y0_val_str, expected_len=1)[0], # y0 is a single float
|
129 |
+
parse_float_list(dy0_dt_val_str, expected_len=1)[0], # dy0_dt is a single float
|
130 |
+
int(t_eval_count),
|
131 |
+
method
|
132 |
+
),
|
133 |
+
inputs=[
|
134 |
+
gr.Textbox(label="ODE Function (lambda t, y, dy_dt: ...)",
|
135 |
+
placeholder="e.g., lambda t, y, dy_dt: -0.1*dy_dt - math.sin(y)",
|
136 |
+
info="Define d²y/dt². `y` is current value, `dy_dt` is current first derivative."),
|
137 |
+
gr.Textbox(label="Time Span (t_start, t_end)", placeholder="e.g., 0,20"),
|
138 |
+
gr.Textbox(label="Initial y(t_start)", placeholder="e.g., 1.0"),
|
139 |
+
gr.Textbox(label="Initial dy/dt(t_start)", placeholder="e.g., 0.0"),
|
140 |
+
gr.Slider(minimum=10, maximum=1000, value=100, step=10, label="Evaluation Points Count"),
|
141 |
+
gr.Radio(choices=['RK45', 'LSODA', 'BDF', 'RK23', 'DOP853'], value='RK45', label="Solver Method")
|
142 |
+
],
|
143 |
+
outputs=[
|
144 |
+
gr.Image(label="Solution Plot (y(t) and dy/dt(t))", type="filepath", show_label=True, visible=lambda res: res['success'] and res['plot_path'] is not None),
|
145 |
+
gr.Textbox(label="Solver Message"),
|
146 |
+
gr.Textbox(label="Success Status"),
|
147 |
+
gr.JSON(label="Raw Data (t, y, dy_dt values)", visible=lambda res: res['success'])
|
148 |
+
],
|
149 |
+
title="Second-Order ODE Solver",
|
150 |
+
description="Solves d²y/dt² = f(t, y, dy/dt). " \
|
151 |
+
"WARNING: Uses eval() for the ODE function string - potential security risk. " \
|
152 |
+
"Example (Pendulum): ODE: lambda t, y, dy_dt: -9.81/1.0 * math.sin(y), y0: math.pi/4, dy0/dt: 0, Timespan: 0,10",
|
153 |
+
allow_flagging="never"
|
154 |
+
)
|
155 |
+
|
156 |
+
# Example usage for testing (can be removed)
|
157 |
+
if __name__ == '__main__':
|
158 |
+
# Test first order
|
159 |
+
# result1 = solve_first_order_ode(lambda t,y: -y*t, (0,5), [1])
|
160 |
+
# print(result1['success'], result1['message'])
|
161 |
+
# if result1['plot_path']: import os; os.system(f"open {result1['plot_path']}")
|
162 |
+
|
163 |
+
|
164 |
+
# Test second order
|
165 |
+
# result2 = solve_second_order_ode(lambda t,y,dydt: -0.1*dydt - y, (0,20), 1, 0)
|
166 |
+
# print(result2['success'], result2['message'])
|
167 |
+
# if result2['plot_path']: import os; os.system(f"open {result2['plot_path']}")
|
168 |
+
|
169 |
+
# To run one of these interfaces for testing:
|
170 |
+
# first_order_ode_interface.launch()
|
171 |
+
# second_order_ode_interface.launch()
|
172 |
+
pass
|
maths/university/linear_algebra.py
ADDED
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Linear Algebra operations for university level, using NumPy.
|
3 |
+
"""
|
4 |
+
import numpy as np
|
5 |
+
from typing import List, Union, Tuple
|
6 |
+
|
7 |
+
Matrix = Union[List[List[float]], np.ndarray]
|
8 |
+
Vector = Union[List[float], np.ndarray]
|
9 |
+
|
10 |
+
def matrix_add(matrix1: Matrix, matrix2: Matrix) -> np.ndarray:
|
11 |
+
"""
|
12 |
+
Adds two matrices.
|
13 |
+
|
14 |
+
Args:
|
15 |
+
matrix1: The first matrix.
|
16 |
+
matrix2: The second matrix.
|
17 |
+
|
18 |
+
Returns:
|
19 |
+
The resulting matrix as a NumPy array.
|
20 |
+
Raises:
|
21 |
+
ValueError: If matrices have incompatible shapes.
|
22 |
+
"""
|
23 |
+
m1 = np.array(matrix1)
|
24 |
+
m2 = np.array(matrix2)
|
25 |
+
if m1.shape != m2.shape:
|
26 |
+
raise ValueError("Matrices must have the same dimensions for addition.")
|
27 |
+
return np.add(m1, m2)
|
28 |
+
|
29 |
+
def matrix_subtract(matrix1: Matrix, matrix2: Matrix) -> np.ndarray:
|
30 |
+
"""
|
31 |
+
Subtracts the second matrix from the first.
|
32 |
+
|
33 |
+
Args:
|
34 |
+
matrix1: The first matrix.
|
35 |
+
matrix2: The second matrix.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
The resulting matrix as a NumPy array.
|
39 |
+
Raises:
|
40 |
+
ValueError: If matrices have incompatible shapes.
|
41 |
+
"""
|
42 |
+
m1 = np.array(matrix1)
|
43 |
+
m2 = np.array(matrix2)
|
44 |
+
if m1.shape != m2.shape:
|
45 |
+
raise ValueError("Matrices must have the same dimensions for subtraction.")
|
46 |
+
return np.subtract(m1, m2)
|
47 |
+
|
48 |
+
def matrix_multiply(matrix1: Matrix, matrix2: Matrix) -> np.ndarray:
|
49 |
+
"""
|
50 |
+
Multiplies two matrices.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
matrix1: The first matrix.
|
54 |
+
matrix2: The second matrix.
|
55 |
+
|
56 |
+
Returns:
|
57 |
+
The resulting matrix as a NumPy array.
|
58 |
+
Raises:
|
59 |
+
ValueError: If matrices have incompatible shapes for multiplication.
|
60 |
+
"""
|
61 |
+
m1 = np.array(matrix1)
|
62 |
+
m2 = np.array(matrix2)
|
63 |
+
if m1.shape[1] != m2.shape[0]:
|
64 |
+
raise ValueError("Number of columns in the first matrix must equal number of rows in the second for multiplication.")
|
65 |
+
return np.dot(m1, m2)
|
66 |
+
|
67 |
+
def matrix_determinant(matrix: Matrix) -> float:
|
68 |
+
"""
|
69 |
+
Calculates the determinant of a square matrix.
|
70 |
+
|
71 |
+
Args:
|
72 |
+
matrix: The matrix (must be square).
|
73 |
+
|
74 |
+
Returns:
|
75 |
+
The determinant of the matrix.
|
76 |
+
Raises:
|
77 |
+
ValueError: If the matrix is not square.
|
78 |
+
"""
|
79 |
+
m = np.array(matrix)
|
80 |
+
if m.shape[0] != m.shape[1]:
|
81 |
+
raise ValueError("Matrix must be square to calculate its determinant.")
|
82 |
+
return np.linalg.det(m)
|
83 |
+
|
84 |
+
def matrix_inverse(matrix: Matrix) -> np.ndarray:
|
85 |
+
"""
|
86 |
+
Calculates the inverse of a square matrix.
|
87 |
+
|
88 |
+
Args:
|
89 |
+
matrix: The matrix (must be square and invertible).
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
The inverse of the matrix as a NumPy array.
|
93 |
+
Raises:
|
94 |
+
ValueError: If the matrix is not square.
|
95 |
+
np.linalg.LinAlgError: If the matrix is singular (not invertible).
|
96 |
+
"""
|
97 |
+
m = np.array(matrix)
|
98 |
+
if m.shape[0] != m.shape[1]:
|
99 |
+
raise ValueError("Matrix must be square to calculate its inverse.")
|
100 |
+
return np.linalg.inv(m)
|
101 |
+
|
102 |
+
def vector_add(vector1: Vector, vector2: Vector) -> np.ndarray:
|
103 |
+
"""
|
104 |
+
Adds two vectors.
|
105 |
+
|
106 |
+
Args:
|
107 |
+
vector1: The first vector.
|
108 |
+
vector2: The second vector.
|
109 |
+
|
110 |
+
Returns:
|
111 |
+
The resulting vector as a NumPy array.
|
112 |
+
Raises:
|
113 |
+
ValueError: If vectors have incompatible shapes.
|
114 |
+
"""
|
115 |
+
v1 = np.array(vector1)
|
116 |
+
v2 = np.array(vector2)
|
117 |
+
if v1.shape != v2.shape:
|
118 |
+
raise ValueError("Vectors must have the same dimensions for addition.")
|
119 |
+
return np.add(v1, v2)
|
120 |
+
|
121 |
+
def vector_subtract(vector1: Vector, vector2: Vector) -> np.ndarray:
|
122 |
+
"""
|
123 |
+
Subtracts the second vector from the first.
|
124 |
+
|
125 |
+
Args:
|
126 |
+
vector1: The first vector.
|
127 |
+
vector2: The second vector.
|
128 |
+
|
129 |
+
Returns:
|
130 |
+
The resulting vector as a NumPy array.
|
131 |
+
Raises:
|
132 |
+
ValueError: If vectors have incompatible shapes.
|
133 |
+
"""
|
134 |
+
v1 = np.array(vector1)
|
135 |
+
v2 = np.array(vector2)
|
136 |
+
if v1.shape != v2.shape:
|
137 |
+
raise ValueError("Vectors must have the same dimensions for subtraction.")
|
138 |
+
return np.subtract(v1, v2)
|
139 |
+
|
140 |
+
def vector_dot_product(vector1: Vector, vector2: Vector) -> float:
|
141 |
+
"""
|
142 |
+
Calculates the dot product of two vectors.
|
143 |
+
|
144 |
+
Args:
|
145 |
+
vector1: The first vector.
|
146 |
+
vector2: The second vector.
|
147 |
+
|
148 |
+
Returns:
|
149 |
+
The dot product of the vectors.
|
150 |
+
Raises:
|
151 |
+
ValueError: If vectors have incompatible shapes.
|
152 |
+
"""
|
153 |
+
v1 = np.array(vector1)
|
154 |
+
v2 = np.array(vector2)
|
155 |
+
if v1.shape != v2.shape:
|
156 |
+
raise ValueError("Vectors must have the same dimensions for dot product.")
|
157 |
+
return np.dot(v1, v2)
|
158 |
+
|
159 |
+
def vector_cross_product(vector1: Vector, vector2: Vector) -> np.ndarray:
|
160 |
+
"""
|
161 |
+
Calculates the cross product of two 3D vectors.
|
162 |
+
|
163 |
+
Args:
|
164 |
+
vector1: The first 3D vector.
|
165 |
+
vector2: The second 3D vector.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
The resulting 3D vector as a NumPy array.
|
169 |
+
Raises:
|
170 |
+
ValueError: If vectors are not 3-dimensional.
|
171 |
+
"""
|
172 |
+
v1 = np.array(vector1)
|
173 |
+
v2 = np.array(vector2)
|
174 |
+
if v1.shape != (3,) or v2.shape != (3,):
|
175 |
+
raise ValueError("Cross product is only defined for 3-dimensional vectors.")
|
176 |
+
return np.cross(v1, v2)
|
177 |
+
|
178 |
+
def solve_linear_system(coefficients_matrix: Matrix, constants_vector: Vector) -> np.ndarray:
|
179 |
+
"""
|
180 |
+
Solves a system of linear equations Ax = B.
|
181 |
+
|
182 |
+
Args:
|
183 |
+
coefficients_matrix (A): The matrix of coefficients.
|
184 |
+
constants_vector (B): The vector of constants.
|
185 |
+
|
186 |
+
Returns:
|
187 |
+
The solution vector (x) as a NumPy array.
|
188 |
+
Raises:
|
189 |
+
ValueError: If the coefficient matrix is not square or dimensions are incompatible.
|
190 |
+
np.linalg.LinAlgError: If the system has no unique solution (e.g., singular matrix).
|
191 |
+
"""
|
192 |
+
A = np.array(coefficients_matrix)
|
193 |
+
B = np.array(constants_vector)
|
194 |
+
|
195 |
+
if A.shape[0] != A.shape[1]:
|
196 |
+
raise ValueError("Coefficient matrix must be square.")
|
197 |
+
if A.shape[0] != B.shape[0]:
|
198 |
+
raise ValueError("Number of rows in coefficient matrix must match number of elements in constants vector.")
|
199 |
+
|
200 |
+
return np.linalg.solve(A, B)
|
201 |
+
|
202 |
+
# Example Usage (can be removed or commented out for production)
|
203 |
+
if __name__ == '__main__':
|
204 |
+
# Matrix examples
|
205 |
+
m_a = [[1, 2], [3, 4]]
|
206 |
+
m_b = [[5, 6], [7, 8]]
|
207 |
+
print("Matrix A+B:", matrix_add(m_a, m_b))
|
208 |
+
print("Matrix A*B:", matrix_multiply(m_a, m_b))
|
209 |
+
m_c = [[1,2,3],[4,5,6],[7,8,9]] # Singular
|
210 |
+
m_d = [[1,2,3],[0,1,4],[5,6,0]]
|
211 |
+
print("Determinant of D:", matrix_determinant(m_d))
|
212 |
+
try:
|
213 |
+
print("Inverse of D:", matrix_inverse(m_d))
|
214 |
+
except np.linalg.LinAlgError as e:
|
215 |
+
print("Error inverting D:", e)
|
216 |
+
try:
|
217 |
+
print("Inverse of C (singular):", matrix_inverse(m_c)) # Should raise error
|
218 |
+
except np.linalg.LinAlgError as e:
|
219 |
+
print("Error inverting C:", e)
|
220 |
+
|
221 |
+
|
222 |
+
# Vector examples
|
223 |
+
v_a = [1, 2, 3]
|
224 |
+
v_b = [4, 5, 6]
|
225 |
+
print("Vector A+B:", vector_add(v_a, v_b))
|
226 |
+
print("Vector A.B (dot):", vector_dot_product(v_a, v_b))
|
227 |
+
print("Vector AxB (cross):", vector_cross_product(v_a, v_b))
|
228 |
+
|
229 |
+
v_c = [1,2]
|
230 |
+
v_d = [3,4]
|
231 |
+
try:
|
232 |
+
print(vector_cross_product(v_c,v_d)) # Should error
|
233 |
+
except ValueError as e:
|
234 |
+
print("Error cross product:", e)
|
235 |
+
|
236 |
+
# Solve linear system example
|
237 |
+
# 2x + y = 1
|
238 |
+
# x + y = 1
|
239 |
+
coeffs = [[2, 1], [1, 1]]
|
240 |
+
consts = [1, 1]
|
241 |
+
print("Solution to 2x+y=1, x+y=1 is:", solve_linear_system(coeffs, consts)) # Expected: [0, 1]
|
242 |
+
|
243 |
+
# Singular system
|
244 |
+
coeffs_singular = [[1, 1], [1, 1]]
|
245 |
+
consts_singular = [1, 2] # No solution
|
246 |
+
try:
|
247 |
+
print("Solution to singular system:", solve_linear_system(coeffs_singular, consts_singular))
|
248 |
+
except np.linalg.LinAlgError as e:
|
249 |
+
print("Error solving singular system:", e)
|
maths/university/linear_algebra_interface.py
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
Gradio Interface for Linear Algebra operations.
|
3 |
+
"""
|
4 |
+
import gradio as gr
|
5 |
+
import numpy as np
|
6 |
+
import json # For parsing matrices and vectors easily
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
from maths.university.linear_algebra import (
|
10 |
+
matrix_add, matrix_subtract, matrix_multiply,
|
11 |
+
matrix_determinant, matrix_inverse,
|
12 |
+
vector_add, vector_subtract, vector_dot_product,
|
13 |
+
vector_cross_product, solve_linear_system
|
14 |
+
)
|
15 |
+
|
16 |
+
# Helper functions to parse inputs
|
17 |
+
def parse_matrix(matrix_str: str, allow_empty_rows_cols=False) -> np.ndarray:
|
18 |
+
"""Parses a string representation of a matrix (JSON format or comma/semicolon separated) into a NumPy array."""
|
19 |
+
try:
|
20 |
+
# Try JSON parsing first for structure like [[1,2],[3,4]]
|
21 |
+
m = json.loads(matrix_str)
|
22 |
+
arr = np.array(m, dtype=float)
|
23 |
+
if not allow_empty_rows_cols and (arr.shape[0] == 0 or (arr.ndim > 1 and arr.shape[1] == 0)):
|
24 |
+
raise ValueError("Matrix cannot have zero rows or columns.")
|
25 |
+
if arr.ndim == 1 and arr.shape[0] > 0: # Handles single row matrix like "[1,2,3]"
|
26 |
+
arr = arr.reshape(1, -1)
|
27 |
+
elif arr.ndim == 0 and not allow_empty_rows_cols : # Handles scalar input like "5" if not allowed
|
28 |
+
raise ValueError("Input is a scalar, not a matrix.")
|
29 |
+
return arr
|
30 |
+
except (json.JSONDecodeError, TypeError, ValueError) as e_json:
|
31 |
+
# Fallback for simpler comma/semicolon separated format, e.g. "1,2;3,4"
|
32 |
+
try:
|
33 |
+
rows = [list(map(float, row.split(','))) for row in matrix_str.split(';') if row.strip()]
|
34 |
+
if not rows and not allow_empty_rows_cols:
|
35 |
+
raise ValueError("Matrix input is empty.")
|
36 |
+
if not rows and allow_empty_rows_cols: # e.g. for an empty matrix if needed
|
37 |
+
return np.array([])
|
38 |
+
|
39 |
+
# Check for consistent row lengths
|
40 |
+
if len(rows) > 1:
|
41 |
+
first_row_len = len(rows[0])
|
42 |
+
if not all(len(r) == first_row_len for r in rows):
|
43 |
+
raise ValueError("All rows must have the same number of columns.")
|
44 |
+
|
45 |
+
arr = np.array(rows, dtype=float)
|
46 |
+
if not allow_empty_rows_cols and (arr.shape[0] == 0 or (arr.ndim > 1 and arr.shape[1] == 0)):
|
47 |
+
raise ValueError("Matrix cannot have zero rows or columns after parsing.")
|
48 |
+
return arr
|
49 |
+
|
50 |
+
except ValueError as e_csv:
|
51 |
+
raise gr.Error(f"Invalid matrix format. Use JSON (e.g., [[1,2],[3,4]]) or comma/semicolon (e.g., 1,2;3,4). Error: {e_csv} (Original JSON error: {e_json})")
|
52 |
+
except Exception as e_gen: # Catch any other parsing errors
|
53 |
+
raise gr.Error(f"General error parsing matrix: {e_gen}")
|
54 |
+
|
55 |
+
|
56 |
+
def parse_vector(vector_str: str) -> np.ndarray:
|
57 |
+
"""Parses a string representation of a vector (JSON or comma-separated) into a NumPy array."""
|
58 |
+
try:
|
59 |
+
# Try JSON parsing for lists like [1,2,3]
|
60 |
+
v = json.loads(vector_str)
|
61 |
+
arr = np.array(v, dtype=float)
|
62 |
+
if arr.ndim != 1:
|
63 |
+
raise ValueError("Vector must be 1-dimensional.")
|
64 |
+
if arr.shape[0] == 0:
|
65 |
+
raise ValueError("Vector cannot be empty.")
|
66 |
+
return arr
|
67 |
+
except (json.JSONDecodeError, TypeError, ValueError) as e_json:
|
68 |
+
# Fallback for simpler comma-separated format e.g. "1,2,3"
|
69 |
+
try:
|
70 |
+
if not vector_str.strip():
|
71 |
+
raise ValueError("Vector input is empty.")
|
72 |
+
arr = np.array([float(x.strip()) for x in vector_str.split(',') if x.strip()], dtype=float)
|
73 |
+
if arr.shape[0] == 0: # case like "," or " , "
|
74 |
+
raise ValueError("Vector cannot be empty after parsing.")
|
75 |
+
return arr
|
76 |
+
except ValueError as e_csv:
|
77 |
+
raise gr.Error(f"Invalid vector format. Use JSON (e.g., [1,2,3]) or comma-separated (e.g., 1,2,3). Error: {e_csv} (Original JSON error: {e_json})")
|
78 |
+
except Exception as e_gen:
|
79 |
+
raise gr.Error(f"General error parsing vector: {e_gen}")
|
80 |
+
|
81 |
+
def format_output(data: Union[np.ndarray, float, str]) -> str:
|
82 |
+
"""Formats NumPy array or float for display, handling errors."""
|
83 |
+
if isinstance(data, np.ndarray):
|
84 |
+
return str(data.tolist()) # Convert to list for cleaner display
|
85 |
+
elif isinstance(data, (float, int, str)):
|
86 |
+
return str(data)
|
87 |
+
return "Output type not recognized."
|
88 |
+
|
89 |
+
# --- Matrix Interfaces ---
|
90 |
+
matrix_add_interface = gr.Interface(
|
91 |
+
fn=lambda m1_str, m2_str: format_output(matrix_add(parse_matrix(m1_str), parse_matrix(m2_str))),
|
92 |
+
inputs=[
|
93 |
+
gr.Textbox(label="Matrix 1 (JSON or CSV format)", placeholder="e.g., [[1,2],[3,4]] or 1,2;3,4"),
|
94 |
+
gr.Textbox(label="Matrix 2 (JSON or CSV format)", placeholder="e.g., [[5,6],[7,8]] or 5,6;7,8")
|
95 |
+
],
|
96 |
+
outputs=gr.Textbox(label="Resulting Matrix"),
|
97 |
+
title="Matrix Addition",
|
98 |
+
description="Adds two matrices. Ensure they have the same dimensions."
|
99 |
+
)
|
100 |
+
|
101 |
+
matrix_subtract_interface = gr.Interface(
|
102 |
+
fn=lambda m1_str, m2_str: format_output(matrix_subtract(parse_matrix(m1_str), parse_matrix(m2_str))),
|
103 |
+
inputs=[
|
104 |
+
gr.Textbox(label="Matrix 1 (Minuend)", placeholder="e.g., [[5,6],[7,8]]"),
|
105 |
+
gr.Textbox(label="Matrix 2 (Subtrahend)", placeholder="e.g., [[1,2],[3,4]]")
|
106 |
+
],
|
107 |
+
outputs=gr.Textbox(label="Resulting Matrix"),
|
108 |
+
title="Matrix Subtraction",
|
109 |
+
description="Subtracts the second matrix from the first. Ensure they have the same dimensions."
|
110 |
+
)
|
111 |
+
|
112 |
+
matrix_multiply_interface = gr.Interface(
|
113 |
+
fn=lambda m1_str, m2_str: format_output(matrix_multiply(parse_matrix(m1_str), parse_matrix(m2_str))),
|
114 |
+
inputs=[
|
115 |
+
gr.Textbox(label="Matrix 1", placeholder="e.g., [[1,2],[3,4]]"),
|
116 |
+
gr.Textbox(label="Matrix 2", placeholder="e.g., [[5,6],[7,8]]")
|
117 |
+
],
|
118 |
+
outputs=gr.Textbox(label="Resulting Matrix"),
|
119 |
+
title="Matrix Multiplication",
|
120 |
+
description="Multiplies two matrices. Columns of Matrix 1 must equal rows of Matrix 2."
|
121 |
+
)
|
122 |
+
|
123 |
+
matrix_determinant_interface = gr.Interface(
|
124 |
+
fn=lambda m_str: format_output(matrix_determinant(parse_matrix(m_str))),
|
125 |
+
inputs=gr.Textbox(label="Matrix (must be square)", placeholder="e.g., [[1,2],[3,4]]"),
|
126 |
+
outputs=gr.Textbox(label="Determinant"),
|
127 |
+
title="Matrix Determinant",
|
128 |
+
description="Calculates the determinant of a square matrix."
|
129 |
+
)
|
130 |
+
|
131 |
+
matrix_inverse_interface = gr.Interface(
|
132 |
+
fn=lambda m_str: format_output(matrix_inverse(parse_matrix(m_str))),
|
133 |
+
inputs=gr.Textbox(label="Matrix (must be square and invertible)", placeholder="e.g., [[1,2],[3,7]]"),
|
134 |
+
outputs=gr.Textbox(label="Inverse Matrix"),
|
135 |
+
title="Matrix Inverse",
|
136 |
+
description="Calculates the inverse of a square matrix. Matrix must be invertible (non-singular)."
|
137 |
+
)
|
138 |
+
|
139 |
+
# --- Vector Interfaces ---
|
140 |
+
vector_add_interface = gr.Interface(
|
141 |
+
fn=lambda v1_str, v2_str: format_output(vector_add(parse_vector(v1_str), parse_vector(v2_str))),
|
142 |
+
inputs=[
|
143 |
+
gr.Textbox(label="Vector 1 (JSON or CSV format)", placeholder="e.g., [1,2,3] or 1,2,3"),
|
144 |
+
gr.Textbox(label="Vector 2 (JSON or CSV format)", placeholder="e.g., [4,5,6] or 4,5,6")
|
145 |
+
],
|
146 |
+
outputs=gr.Textbox(label="Resulting Vector"),
|
147 |
+
title="Vector Addition",
|
148 |
+
description="Adds two vectors. Ensure they have the same dimensions."
|
149 |
+
)
|
150 |
+
|
151 |
+
vector_subtract_interface = gr.Interface(
|
152 |
+
fn=lambda v1_str, v2_str: format_output(vector_subtract(parse_vector(v1_str), parse_vector(v2_str))),
|
153 |
+
inputs=[
|
154 |
+
gr.Textbox(label="Vector 1 (Minuend)", placeholder="e.g., [4,5,6]"),
|
155 |
+
gr.Textbox(label="Vector 2 (Subtrahend)", placeholder="e.g., [1,2,3]")
|
156 |
+
],
|
157 |
+
outputs=gr.Textbox(label="Resulting Vector"),
|
158 |
+
title="Vector Subtraction",
|
159 |
+
description="Subtracts the second vector from the first. Ensure they have the same dimensions."
|
160 |
+
)
|
161 |
+
|
162 |
+
vector_dot_product_interface = gr.Interface(
|
163 |
+
fn=lambda v1_str, v2_str: format_output(vector_dot_product(parse_vector(v1_str), parse_vector(v2_str))),
|
164 |
+
inputs=[
|
165 |
+
gr.Textbox(label="Vector 1", placeholder="e.g., [1,2,3]"),
|
166 |
+
gr.Textbox(label="Vector 2", placeholder="e.g., [4,5,6]")
|
167 |
+
],
|
168 |
+
outputs=gr.Textbox(label="Dot Product (Scalar)"),
|
169 |
+
title="Vector Dot Product",
|
170 |
+
description="Calculates the dot product of two vectors. Ensure they have the same dimensions."
|
171 |
+
)
|
172 |
+
|
173 |
+
vector_cross_product_interface = gr.Interface(
|
174 |
+
fn=lambda v1_str, v2_str: format_output(vector_cross_product(parse_vector(v1_str), parse_vector(v2_str))),
|
175 |
+
inputs=[
|
176 |
+
gr.Textbox(label="Vector 1 (3D)", placeholder="e.g., [1,2,3]"),
|
177 |
+
gr.Textbox(label="Vector 2 (3D)", placeholder="e.g., [4,5,6]")
|
178 |
+
],
|
179 |
+
outputs=gr.Textbox(label="Resulting Vector (3D)"),
|
180 |
+
title="Vector Cross Product",
|
181 |
+
description="Calculates the cross product of two 3D vectors."
|
182 |
+
)
|
183 |
+
|
184 |
+
# --- Solve Linear System Interface ---
|
185 |
+
solve_linear_system_interface = gr.Interface(
|
186 |
+
fn=lambda A_str, B_str: format_output(solve_linear_system(parse_matrix(A_str), parse_vector(B_str))),
|
187 |
+
inputs=[
|
188 |
+
gr.Textbox(label="Coefficients Matrix (A)", placeholder="For Ax=B. e.g., [[2,1],[1,1]] for 2x+y, x+y"),
|
189 |
+
gr.Textbox(label="Constants Vector (B)", placeholder="For Ax=B. e.g., [1,1] for =1, =1")
|
190 |
+
],
|
191 |
+
outputs=gr.Textbox(label="Solution Vector (x)"),
|
192 |
+
title="Linear System Solver (Ax = B)",
|
193 |
+
description="Solves a system of linear equations Ax = B. Matrix A must be square and invertible."
|
194 |
+
)
|
195 |
+
|
196 |
+
# It might be useful to group these interfaces in a Gradio TabbedInterface or Blocks layout in a main app file.
|
197 |
+
# For now, this file just defines the individual interfaces.
|
maths/university/operations_research/BranchAndBoundSolver.py
ADDED
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cvxpy as cp
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
from queue import PriorityQueue
|
5 |
+
import networkx as nx
|
6 |
+
import matplotlib.pyplot as plt
|
7 |
+
from tabulate import tabulate
|
8 |
+
from scipy.optimize import linprog
|
9 |
+
|
10 |
+
|
11 |
+
class BranchAndBoundSolver:
|
12 |
+
def __init__(self, c, A, b, integer_vars=None, binary_vars=None, maximize=True):
|
13 |
+
"""
|
14 |
+
Initialize the Branch and Bound solver
|
15 |
+
|
16 |
+
Parameters:
|
17 |
+
- c: Objective coefficients (for max c'x)
|
18 |
+
- A, b: Constraints Ax <= b
|
19 |
+
- integer_vars: Indices of variables that must be integers
|
20 |
+
- binary_vars: Indices of variables that must be binary (0 or 1)
|
21 |
+
- maximize: True for maximization, False for minimization
|
22 |
+
"""
|
23 |
+
self.c = c
|
24 |
+
self.A = A
|
25 |
+
self.b = b
|
26 |
+
self.n = len(c)
|
27 |
+
|
28 |
+
# Process binary and integer variables
|
29 |
+
self.binary_vars = [] if binary_vars is None else binary_vars
|
30 |
+
|
31 |
+
# If integer_vars not specified, assume all non-binary variables are integers
|
32 |
+
if integer_vars is None:
|
33 |
+
self.integer_vars = list(range(self.n))
|
34 |
+
else:
|
35 |
+
self.integer_vars = integer_vars.copy()
|
36 |
+
|
37 |
+
# Add binary variables to integer variables list if they're not already there
|
38 |
+
for idx in self.binary_vars:
|
39 |
+
if idx not in self.integer_vars:
|
40 |
+
self.integer_vars.append(idx)
|
41 |
+
|
42 |
+
# Best solution found so far
|
43 |
+
self.best_solution = None
|
44 |
+
self.best_objective = float('-inf') if maximize else float('inf')
|
45 |
+
self.maximize = maximize
|
46 |
+
|
47 |
+
# Track nodes explored
|
48 |
+
self.nodes_explored = 0
|
49 |
+
|
50 |
+
# Graph for visualization
|
51 |
+
self.graph = nx.DiGraph()
|
52 |
+
self.node_id = 0
|
53 |
+
|
54 |
+
# For tabular display of steps
|
55 |
+
self.steps_table = []
|
56 |
+
|
57 |
+
# Set of active nodes
|
58 |
+
self.active_nodes = set()
|
59 |
+
|
60 |
+
def is_integer_feasible(self, x):
|
61 |
+
"""Check if the solution satisfies integer constraints"""
|
62 |
+
if x is None:
|
63 |
+
return False
|
64 |
+
|
65 |
+
for idx in self.integer_vars:
|
66 |
+
if abs(round(x[idx]) - x[idx]) > 1e-6:
|
67 |
+
return False
|
68 |
+
return True
|
69 |
+
|
70 |
+
def get_branching_variable(self, x):
|
71 |
+
"""Select most fractional variable to branch on"""
|
72 |
+
max_fractional = -1
|
73 |
+
branching_var = -1
|
74 |
+
|
75 |
+
for idx in self.integer_vars:
|
76 |
+
fractional_part = abs(x[idx] - round(x[idx]))
|
77 |
+
if fractional_part > max_fractional and fractional_part > 1e-6:
|
78 |
+
max_fractional = fractional_part
|
79 |
+
branching_var = idx
|
80 |
+
|
81 |
+
return branching_var
|
82 |
+
|
83 |
+
def solve_relaxation(self, lower_bounds, upper_bounds):
|
84 |
+
"""Solve the continuous relaxation with given bounds"""
|
85 |
+
x = cp.Variable(self.n)
|
86 |
+
|
87 |
+
# Set the objective - maximize c'x or minimize -c'x
|
88 |
+
if self.maximize:
|
89 |
+
objective = cp.Maximize(self.c @ x)
|
90 |
+
else:
|
91 |
+
objective = cp.Minimize(self.c @ x)
|
92 |
+
|
93 |
+
# Basic constraints Ax <= b
|
94 |
+
constraints = [self.A @ x <= self.b]
|
95 |
+
|
96 |
+
# Add bounds
|
97 |
+
for i in range(self.n):
|
98 |
+
if lower_bounds[i] is not None:
|
99 |
+
constraints.append(x[i] >= lower_bounds[i])
|
100 |
+
if upper_bounds[i] is not None:
|
101 |
+
constraints.append(x[i] <= upper_bounds[i])
|
102 |
+
|
103 |
+
prob = cp.Problem(objective, constraints)
|
104 |
+
|
105 |
+
try:
|
106 |
+
objective_value = prob.solve()
|
107 |
+
return x.value, objective_value
|
108 |
+
except:
|
109 |
+
return None, float('-inf') if self.maximize else float('inf')
|
110 |
+
|
111 |
+
def add_node_to_graph(self, node_name, objective_value, x_value, parent=None, branch_var=None, branch_cond=None):
|
112 |
+
"""Add a node to the branch and bound graph"""
|
113 |
+
self.graph.add_node(node_name, obj=objective_value, x=x_value,
|
114 |
+
branch_var=branch_var, branch_cond=branch_cond)
|
115 |
+
|
116 |
+
if parent is not None:
|
117 |
+
# Use branch_var + 1 to show 1-indexed variables in the display
|
118 |
+
label = f"x_{branch_var + 1} {branch_cond}"
|
119 |
+
self.graph.add_edge(parent, node_name, label=label)
|
120 |
+
|
121 |
+
return node_name
|
122 |
+
|
123 |
+
def visualize_graph(self):
|
124 |
+
"""Visualize the branch and bound graph"""
|
125 |
+
fig = plt.figure(figsize=(20, 8))
|
126 |
+
pos = nx.spring_layout(self.graph) # Use spring layout instead of graphviz
|
127 |
+
|
128 |
+
# Node labels: Node name, Objective value and solution
|
129 |
+
labels = {}
|
130 |
+
for node, data in self.graph.nodes(data=True):
|
131 |
+
if data.get('x') is not None:
|
132 |
+
x_str = ', '.join([f"{x:.2f}" for x in data['x']])
|
133 |
+
labels[node] = f"{node}\n({data['obj']:.2f}, ({x_str}))"
|
134 |
+
else:
|
135 |
+
labels[node] = f"{node}\nInfeasible"
|
136 |
+
|
137 |
+
# Edge labels: Branching conditions
|
138 |
+
edge_labels = nx.get_edge_attributes(self.graph, 'label')
|
139 |
+
|
140 |
+
# Draw nodes
|
141 |
+
nx.draw_networkx_nodes(self.graph, pos, node_size=2000, node_color='skyblue')
|
142 |
+
|
143 |
+
# Draw edges
|
144 |
+
nx.draw_networkx_edges(self.graph, pos, width=1.5, arrowsize=20, edge_color='gray')
|
145 |
+
|
146 |
+
# Draw labels
|
147 |
+
nx.draw_networkx_labels(self.graph, pos, labels, font_size=10, font_family='sans-serif')
|
148 |
+
nx.draw_networkx_edge_labels(self.graph, pos, edge_labels=edge_labels, font_size=10, font_family='sans-serif')
|
149 |
+
|
150 |
+
plt.title("Branch and Bound Tree", fontsize=14)
|
151 |
+
plt.axis('off')
|
152 |
+
plt.tight_layout()
|
153 |
+
return fig # Return the figure instead of showing it
|
154 |
+
|
155 |
+
|
156 |
+
def display_steps_table(self):
|
157 |
+
"""Display the steps in tabular format"""
|
158 |
+
headers = ["Node", "z", "x", "z*", "x*", "UB", "LB", "Z at end of stage"]
|
159 |
+
print(tabulate(self.steps_table, headers=headers, tablefmt="grid"))
|
160 |
+
|
161 |
+
def solve(self, verbose=True):
|
162 |
+
"""Solve the problem using branch and bound"""
|
163 |
+
# Initialize bounds
|
164 |
+
lower_bounds = [0] * self.n
|
165 |
+
upper_bounds = [None] * self.n # None means unbounded
|
166 |
+
|
167 |
+
# Set upper bounds for binary variables
|
168 |
+
for idx in self.binary_vars:
|
169 |
+
upper_bounds[idx] = 1
|
170 |
+
|
171 |
+
# Create a priority queue for nodes (max heap for maximization, min heap for minimization)
|
172 |
+
# We use negative values for maximization to simulate max heap with Python's min heap
|
173 |
+
node_queue = PriorityQueue()
|
174 |
+
|
175 |
+
# Solve the root relaxation
|
176 |
+
print("Step 1: Solving root relaxation (continuous problem)")
|
177 |
+
x_root, obj_root = self.solve_relaxation(lower_bounds, upper_bounds)
|
178 |
+
|
179 |
+
if x_root is None:
|
180 |
+
print("Root problem infeasible")
|
181 |
+
return None, float('-inf') if self.maximize else float('inf')
|
182 |
+
|
183 |
+
# Add root node to the graph
|
184 |
+
root_node = "S0"
|
185 |
+
self.add_node_to_graph(root_node, obj_root, x_root)
|
186 |
+
|
187 |
+
print(f"Root relaxation objective: {obj_root:.6f}")
|
188 |
+
print(f"Root solution: {x_root}")
|
189 |
+
|
190 |
+
# Initial upper bound is the root objective
|
191 |
+
upper_bound = obj_root
|
192 |
+
|
193 |
+
# Check if the root solution is already integer-feasible
|
194 |
+
if self.is_integer_feasible(x_root):
|
195 |
+
print("Root solution is integer-feasible! No need for branching.")
|
196 |
+
self.best_solution = x_root
|
197 |
+
self.best_objective = obj_root
|
198 |
+
|
199 |
+
# Add to steps table
|
200 |
+
active_nodes_str = "∅" if not self.active_nodes else "{" + ", ".join(self.active_nodes) + "}"
|
201 |
+
self.steps_table.append([
|
202 |
+
root_node, f"{obj_root:.2f}", f"({', '.join([f'{x:.2f}' for x in x_root])})",
|
203 |
+
f"{self.best_objective:.2f}", f"({', '.join([f'{x:.2f}' for x in self.best_solution])})",
|
204 |
+
f"{upper_bound:.2f}", f"{self.best_objective:.2f}", active_nodes_str
|
205 |
+
])
|
206 |
+
|
207 |
+
self.display_steps_table()
|
208 |
+
self.visualize_graph()
|
209 |
+
return x_root, obj_root
|
210 |
+
|
211 |
+
# Add root node to the queue and active nodes set
|
212 |
+
priority = -obj_root if self.maximize else obj_root
|
213 |
+
node_queue.put((priority, self.nodes_explored, root_node, lower_bounds.copy(), upper_bounds.copy()))
|
214 |
+
self.active_nodes.add(root_node)
|
215 |
+
|
216 |
+
# Add entry to steps table for root node
|
217 |
+
active_nodes_str = "{" + ", ".join(self.active_nodes) + "}"
|
218 |
+
lb_str = "-" if self.best_objective == float('-inf') else f"{self.best_objective:.2f}"
|
219 |
+
x_star_str = "-" if self.best_solution is None else f"({', '.join([f'{x:.2f}' for x in self.best_solution])})"
|
220 |
+
|
221 |
+
self.steps_table.append([
|
222 |
+
root_node, f"{obj_root:.2f}", f"({', '.join([f'{x:.2f}' for x in x_root])})",
|
223 |
+
lb_str, x_star_str, f"{upper_bound:.2f}", lb_str, active_nodes_str
|
224 |
+
])
|
225 |
+
|
226 |
+
print("\nStarting branch and bound process:")
|
227 |
+
node_counter = 1
|
228 |
+
|
229 |
+
while not node_queue.empty():
|
230 |
+
# Get the node with the highest objective (for maximization)
|
231 |
+
priority, _, node_name, node_lower_bounds, node_upper_bounds = node_queue.get()
|
232 |
+
self.nodes_explored += 1
|
233 |
+
|
234 |
+
print(f"\nStep {self.nodes_explored + 1}: Exploring node {node_name}")
|
235 |
+
|
236 |
+
# Remove from active nodes
|
237 |
+
self.active_nodes.remove(node_name)
|
238 |
+
|
239 |
+
# Branch on most fractional variable
|
240 |
+
branch_var = self.get_branching_variable(self.graph.nodes[node_name]['x'])
|
241 |
+
branch_val = self.graph.nodes[node_name]['x'][branch_var]
|
242 |
+
|
243 |
+
# For binary variables, always branch with x=0 and x=1
|
244 |
+
if branch_var in self.binary_vars:
|
245 |
+
floor_val = 0
|
246 |
+
ceil_val = 1
|
247 |
+
print(f" Branching on binary variable x_{branch_var + 1} with value {branch_val:.6f}")
|
248 |
+
print(f" Creating two branches: x_{branch_var + 1} = 0 and x_{branch_var + 1} = 1")
|
249 |
+
else:
|
250 |
+
floor_val = math.floor(branch_val)
|
251 |
+
ceil_val = math.ceil(branch_val)
|
252 |
+
print(f" Branching on variable x_{branch_var + 1} with value {branch_val:.6f}")
|
253 |
+
print(f" Creating two branches: x_{branch_var + 1} ≤ {floor_val} and x_{branch_var + 1} ≥ {ceil_val}")
|
254 |
+
|
255 |
+
# Process left branch (floor)
|
256 |
+
left_node = f"S{node_counter}"
|
257 |
+
node_counter += 1
|
258 |
+
|
259 |
+
# Create the "floor" branch
|
260 |
+
floor_lower_bounds = node_lower_bounds.copy()
|
261 |
+
floor_upper_bounds = node_upper_bounds.copy()
|
262 |
+
|
263 |
+
# For binary variables, set both bounds to 0 (x=0)
|
264 |
+
if branch_var in self.binary_vars:
|
265 |
+
floor_lower_bounds[branch_var] = 0
|
266 |
+
floor_upper_bounds[branch_var] = 0
|
267 |
+
branch_cond = f"= 0"
|
268 |
+
else:
|
269 |
+
floor_upper_bounds[branch_var] = floor_val
|
270 |
+
branch_cond = f"≤ {floor_val}"
|
271 |
+
|
272 |
+
# Solve the relaxation for this node
|
273 |
+
x_floor, obj_floor = self.solve_relaxation(floor_lower_bounds, floor_upper_bounds)
|
274 |
+
|
275 |
+
# Add node to graph
|
276 |
+
self.add_node_to_graph(left_node, obj_floor if x_floor is not None else float('-inf'),
|
277 |
+
x_floor, node_name, branch_var, branch_cond)
|
278 |
+
|
279 |
+
# Process the floor branch
|
280 |
+
if x_floor is None:
|
281 |
+
print(f" {left_node} is infeasible")
|
282 |
+
else:
|
283 |
+
print(f" {left_node} relaxation objective: {obj_floor:.6f}")
|
284 |
+
print(f" {left_node} solution: {x_floor}")
|
285 |
+
|
286 |
+
# Check if integer feasible and update best solution if needed
|
287 |
+
if self.is_integer_feasible(x_floor) and ((self.maximize and obj_floor > self.best_objective) or
|
288 |
+
(not self.maximize and obj_floor < self.best_objective)):
|
289 |
+
self.best_solution = x_floor.copy()
|
290 |
+
self.best_objective = obj_floor
|
291 |
+
print(f" Found new best integer solution with objective {self.best_objective:.6f}")
|
292 |
+
|
293 |
+
# Add to queue if not fathomed
|
294 |
+
if ((self.maximize and obj_floor > self.best_objective) or
|
295 |
+
(not self.maximize and obj_floor < self.best_objective)):
|
296 |
+
if not self.is_integer_feasible(x_floor): # Only branch if not integer feasible
|
297 |
+
priority = -obj_floor if self.maximize else obj_floor
|
298 |
+
node_queue.put((priority, self.nodes_explored, left_node,
|
299 |
+
floor_lower_bounds.copy(), floor_upper_bounds.copy()))
|
300 |
+
self.active_nodes.add(left_node)
|
301 |
+
|
302 |
+
# Process right branch (ceil)
|
303 |
+
right_node = f"S{node_counter}"
|
304 |
+
node_counter += 1
|
305 |
+
|
306 |
+
# Create the "ceil" branch
|
307 |
+
ceil_lower_bounds = node_lower_bounds.copy()
|
308 |
+
ceil_upper_bounds = node_upper_bounds.copy()
|
309 |
+
|
310 |
+
# For binary variables, set both bounds to 1 (x=1)
|
311 |
+
if branch_var in self.binary_vars:
|
312 |
+
ceil_lower_bounds[branch_var] = 1
|
313 |
+
ceil_upper_bounds[branch_var] = 1
|
314 |
+
branch_cond = f"= 1"
|
315 |
+
else:
|
316 |
+
ceil_lower_bounds[branch_var] = ceil_val
|
317 |
+
branch_cond = f"≥ {ceil_val}"
|
318 |
+
|
319 |
+
# Solve the relaxation for this node
|
320 |
+
x_ceil, obj_ceil = self.solve_relaxation(ceil_lower_bounds, ceil_upper_bounds)
|
321 |
+
|
322 |
+
# Add node to graph
|
323 |
+
self.add_node_to_graph(right_node, obj_ceil if x_ceil is not None else float('-inf'),
|
324 |
+
x_ceil, node_name, branch_var, branch_cond)
|
325 |
+
|
326 |
+
# Process the ceil branch
|
327 |
+
if x_ceil is None:
|
328 |
+
print(f" {right_node} is infeasible")
|
329 |
+
else:
|
330 |
+
print(f" {right_node} relaxation objective: {obj_ceil:.6f}")
|
331 |
+
print(f" {right_node} solution: {x_ceil}")
|
332 |
+
|
333 |
+
# Check if integer feasible and update best solution if needed
|
334 |
+
if self.is_integer_feasible(x_ceil) and ((self.maximize and obj_ceil > self.best_objective) or
|
335 |
+
(not self.maximize and obj_ceil < self.best_objective)):
|
336 |
+
self.best_solution = x_ceil.copy()
|
337 |
+
self.best_objective = obj_ceil
|
338 |
+
print(f" Found new best integer solution with objective {self.best_objective:.6f}")
|
339 |
+
|
340 |
+
# Add to queue if not fathomed
|
341 |
+
if ((self.maximize and obj_ceil > self.best_objective) or
|
342 |
+
(not self.maximize and obj_ceil < self.best_objective)):
|
343 |
+
if not self.is_integer_feasible(x_ceil): # Only branch if not integer feasible
|
344 |
+
priority = -obj_ceil if self.maximize else obj_ceil
|
345 |
+
node_queue.put((priority, self.nodes_explored, right_node,
|
346 |
+
ceil_lower_bounds.copy(), ceil_upper_bounds.copy()))
|
347 |
+
self.active_nodes.add(right_node)
|
348 |
+
|
349 |
+
# Update upper bound as the best objective in the remaining nodes
|
350 |
+
if not node_queue.empty():
|
351 |
+
# Upper bound is the best possible objective in the remaining nodes
|
352 |
+
next_priority = node_queue.queue[0][0]
|
353 |
+
upper_bound = -next_priority if self.maximize else next_priority
|
354 |
+
else:
|
355 |
+
upper_bound = self.best_objective
|
356 |
+
|
357 |
+
# Add to steps table
|
358 |
+
active_nodes_str = "∅" if not self.active_nodes else "{" + ", ".join(self.active_nodes) + "}"
|
359 |
+
lb_str = f"{self.best_objective:.2f}" if self.best_objective != float('-inf') else "-"
|
360 |
+
x_star_str = "-" if self.best_solution is None else f"({', '.join([f'{x:.2f}' for x in self.best_solution])})"
|
361 |
+
|
362 |
+
self.steps_table.append([
|
363 |
+
node_name,
|
364 |
+
f"{self.graph.nodes[node_name]['obj']:.2f}",
|
365 |
+
f"({', '.join([f'{x:.2f}' for x in self.graph.nodes[node_name]['x']])})",
|
366 |
+
lb_str, x_star_str, f"{upper_bound:.2f}", lb_str, active_nodes_str
|
367 |
+
])
|
368 |
+
|
369 |
+
print("\nBranch and bound completed!")
|
370 |
+
print(f"Nodes explored: {self.nodes_explored}")
|
371 |
+
|
372 |
+
if self.best_solution is not None:
|
373 |
+
print(f"Optimal objective: {self.best_objective:.6f}")
|
374 |
+
print(f"Optimal solution: {self.best_solution}")
|
375 |
+
else:
|
376 |
+
print("No feasible integer solution found")
|
377 |
+
|
378 |
+
# Display steps table
|
379 |
+
self.display_steps_table()
|
380 |
+
|
381 |
+
# Visualize the graph
|
382 |
+
self.visualize_graph()
|
383 |
+
|
384 |
+
return self.best_solution, self.best_objective
|
maths/university/operations_research/DualSimplexSolver.py
ADDED
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import sys
|
3 |
+
from .solve_lp_via_dual import solve_lp_via_dual
|
4 |
+
from .solve_primal_directly import solve_primal_directly
|
5 |
+
|
6 |
+
TOLERANCE = 1e-9
|
7 |
+
|
8 |
+
class DualSimplexSolver:
|
9 |
+
"""
|
10 |
+
Solves a Linear Programming problem using the Dual Simplex Method.
|
11 |
+
|
12 |
+
Assumes the problem is provided in the form:
|
13 |
+
Maximize/Minimize c^T * x
|
14 |
+
Subject to:
|
15 |
+
A * x <= / >= / = b
|
16 |
+
x >= 0
|
17 |
+
|
18 |
+
The algorithm works best when the initial tableau (after converting all
|
19 |
+
constraints to <=) is dual feasible (objective row coefficients >= 0 for Max)
|
20 |
+
but primal infeasible (some RHS values are negative).
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self, objective_type, c, A, relations, b):
|
24 |
+
"""
|
25 |
+
Initializes the solver.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
objective_type (str): 'max' or 'min'.
|
29 |
+
c (list or np.array): Coefficients of the objective function.
|
30 |
+
A (list of lists or np.array): Coefficients of the constraints LHS.
|
31 |
+
relations (list): List of strings ('<=', '>=', '=') for each constraint.
|
32 |
+
b (list or np.array): RHS values of the constraints.
|
33 |
+
"""
|
34 |
+
self.objective_type = objective_type.lower()
|
35 |
+
self.original_c = np.array(c, dtype=float)
|
36 |
+
self.original_A = np.array(A, dtype=float)
|
37 |
+
self.original_relations = relations
|
38 |
+
self.original_b = np.array(b, dtype=float)
|
39 |
+
|
40 |
+
self.num_original_vars = len(c)
|
41 |
+
self.num_constraints = len(b)
|
42 |
+
|
43 |
+
self.tableau = None
|
44 |
+
self.basic_vars = [] # Indices of basic variables (column index)
|
45 |
+
self.var_names = [] # Names like 'x1', 's1', etc.
|
46 |
+
self.is_minimized_problem = False # Flag to adjust final Z
|
47 |
+
|
48 |
+
self._preprocess()
|
49 |
+
|
50 |
+
def _preprocess(self):
|
51 |
+
"""
|
52 |
+
Converts the problem to the standard form for Dual Simplex:
|
53 |
+
- Maximization objective
|
54 |
+
- All constraints are <=
|
55 |
+
- Adds slack variables
|
56 |
+
- Builds the initial tableau
|
57 |
+
"""
|
58 |
+
# --- 1. Handle Objective Function ---
|
59 |
+
if self.objective_type == 'min':
|
60 |
+
self.is_minimized_problem = True
|
61 |
+
current_c = -self.original_c
|
62 |
+
else:
|
63 |
+
current_c = self.original_c
|
64 |
+
|
65 |
+
# --- 2. Handle Constraints and Slack Variables ---
|
66 |
+
num_slacks_added = 0
|
67 |
+
processed_A = []
|
68 |
+
processed_b = []
|
69 |
+
self.basic_vars = [] # Will store column indices of basic vars
|
70 |
+
|
71 |
+
# Create variable names
|
72 |
+
self.var_names = [f'x{i+1}' for i in range(self.num_original_vars)]
|
73 |
+
slack_var_names = []
|
74 |
+
|
75 |
+
for i in range(self.num_constraints):
|
76 |
+
A_row = self.original_A[i]
|
77 |
+
b_val = self.original_b[i]
|
78 |
+
relation = self.original_relations[i]
|
79 |
+
|
80 |
+
if relation == '>=':
|
81 |
+
# Multiply by -1 to convert to <=
|
82 |
+
processed_A.append(-A_row)
|
83 |
+
processed_b.append(-b_val)
|
84 |
+
elif relation == '=':
|
85 |
+
# Convert Ax = b into Ax <= b and Ax >= b
|
86 |
+
# First: Ax <= b
|
87 |
+
processed_A.append(A_row)
|
88 |
+
processed_b.append(b_val)
|
89 |
+
# Second: Ax >= b --> -Ax <= -b
|
90 |
+
processed_A.append(-A_row)
|
91 |
+
processed_b.append(-b_val)
|
92 |
+
elif relation == '<=':
|
93 |
+
processed_A.append(A_row)
|
94 |
+
processed_b.append(b_val)
|
95 |
+
else:
|
96 |
+
raise ValueError(f"Invalid relation symbol: {relation}")
|
97 |
+
|
98 |
+
# Update number of effective constraints after handling '='
|
99 |
+
effective_num_constraints = len(processed_b)
|
100 |
+
|
101 |
+
# Add slack variables for all processed constraints (which are now all <=)
|
102 |
+
num_slack_vars = effective_num_constraints
|
103 |
+
final_A = np.zeros((effective_num_constraints, self.num_original_vars + num_slack_vars))
|
104 |
+
final_b = np.array(processed_b, dtype=float)
|
105 |
+
|
106 |
+
# Populate original variable coefficients
|
107 |
+
final_A[:, :self.num_original_vars] = np.array(processed_A, dtype=float)
|
108 |
+
|
109 |
+
# Add slack variable identity matrix part and names
|
110 |
+
for i in range(effective_num_constraints):
|
111 |
+
slack_col_index = self.num_original_vars + i
|
112 |
+
final_A[i, slack_col_index] = 1
|
113 |
+
slack_var_names.append(f's{i+1}')
|
114 |
+
self.basic_vars.append(slack_col_index) # Initially, slacks are basic
|
115 |
+
|
116 |
+
self.var_names.extend(slack_var_names)
|
117 |
+
|
118 |
+
# --- 3. Build the Tableau ---
|
119 |
+
num_total_vars = self.num_original_vars + num_slack_vars
|
120 |
+
# Rows: 1 for objective + number of constraints
|
121 |
+
# Cols: 1 for Z + number of total vars + 1 for RHS
|
122 |
+
self.tableau = np.zeros((effective_num_constraints + 1, num_total_vars + 2))
|
123 |
+
|
124 |
+
# Row 0 (Objective Z): [1, -c, 0_slacks, 0_rhs]
|
125 |
+
self.tableau[0, 0] = 1 # Z coefficient
|
126 |
+
self.tableau[0, 1:self.num_original_vars + 1] = -current_c
|
127 |
+
# Slack coefficients in objective are 0 initially
|
128 |
+
# RHS of objective row is 0 initially
|
129 |
+
|
130 |
+
# Rows 1 to m (Constraints): [0, A_final, b_final]
|
131 |
+
self.tableau[1:, 1:num_total_vars + 1] = final_A
|
132 |
+
self.tableau[1:, -1] = final_b
|
133 |
+
|
134 |
+
# Ensure the initial objective row is dual feasible (non-negative coeffs for Max)
|
135 |
+
# We rely on the user providing a problem where this holds after conversion.
|
136 |
+
if np.any(self.tableau[0, 1:-1] < -TOLERANCE):
|
137 |
+
print("\nWarning: Initial tableau is not dual feasible (objective row has negative coefficients).")
|
138 |
+
print("The standard Dual Simplex method might not apply directly or may require Phase I.")
|
139 |
+
# For this implementation, we'll proceed, but it might fail if assumption is violated.
|
140 |
+
|
141 |
+
|
142 |
+
def _print_tableau(self, iteration):
|
143 |
+
"""Prints the current state of the tableau."""
|
144 |
+
print(f"\n--- Iteration {iteration} ---")
|
145 |
+
header = ["BV"] + ["Z"] + self.var_names + ["RHS"]
|
146 |
+
print(" ".join(f"{h:>8}" for h in header))
|
147 |
+
print("-" * (len(header) * 9))
|
148 |
+
|
149 |
+
basic_var_map = {idx: name for idx, name in enumerate(self.var_names)}
|
150 |
+
row_basic_vars = ["Z"] + [basic_var_map.get(bv_idx, f'col{bv_idx}') for bv_idx in self.basic_vars]
|
151 |
+
|
152 |
+
for i, row_bv_name in enumerate(row_basic_vars):
|
153 |
+
row_str = [f"{row_bv_name:>8}"]
|
154 |
+
row_str.extend([f"{val: >8.3f}" for val in self.tableau[i]])
|
155 |
+
print(" ".join(row_str))
|
156 |
+
print("-" * (len(header) * 9))
|
157 |
+
|
158 |
+
|
159 |
+
def _find_pivot_row(self):
|
160 |
+
"""Finds the index of the leaving variable (pivot row)."""
|
161 |
+
rhs_values = self.tableau[1:, -1]
|
162 |
+
# Find the index of the most negative RHS value (among constraints)
|
163 |
+
if np.all(rhs_values >= -TOLERANCE):
|
164 |
+
return -1 # All RHS non-negative, current solution is feasible (and optimal)
|
165 |
+
|
166 |
+
pivot_row_index = np.argmin(rhs_values) + 1 # +1 because we skip obj row 0
|
167 |
+
# Check if the minimum value is actually negative
|
168 |
+
if self.tableau[pivot_row_index, -1] >= -TOLERANCE:
|
169 |
+
return -1 # Should not happen if np.all check passed, but safety check
|
170 |
+
|
171 |
+
print(f"\nStep: Select Pivot Row (Leaving Variable)")
|
172 |
+
print(f" RHS values (b): {rhs_values}")
|
173 |
+
leaving_var_idx = self.basic_vars[pivot_row_index - 1]
|
174 |
+
leaving_var_name = self.var_names[leaving_var_idx]
|
175 |
+
print(f" Most negative RHS is {self.tableau[pivot_row_index, -1]:.3f} in Row {pivot_row_index} (Basic Var: {leaving_var_name}).")
|
176 |
+
print(f" Leaving Variable: {leaving_var_name} (Row {pivot_row_index})")
|
177 |
+
return pivot_row_index
|
178 |
+
|
179 |
+
def _find_pivot_col(self, pivot_row_index):
|
180 |
+
"""Finds the index of the entering variable (pivot column)."""
|
181 |
+
pivot_row = self.tableau[pivot_row_index, 1:-1] # Exclude Z and RHS cols
|
182 |
+
objective_row = self.tableau[0, 1:-1] # Exclude Z and RHS cols
|
183 |
+
|
184 |
+
ratios = {}
|
185 |
+
min_ratio = float('inf')
|
186 |
+
pivot_col_index = -1
|
187 |
+
|
188 |
+
print(f"\nStep: Select Pivot Column (Entering Variable) using Ratio Test")
|
189 |
+
print(f" Pivot Row (Row {pivot_row_index}) coefficients (excluding Z, RHS): {pivot_row}")
|
190 |
+
print(f" Objective Row coefficients (excluding Z, RHS): {objective_row}")
|
191 |
+
print(f" Calculating ratios = ObjCoeff / abs(PivotRowCoeff) for PivotRowCoeff < 0:")
|
192 |
+
|
193 |
+
found_negative_coeff = False
|
194 |
+
for j, coeff in enumerate(pivot_row):
|
195 |
+
col_var_index = j # This is the index within the var_names list
|
196 |
+
col_tableau_index = j + 1 # This is the index in the full tableau row
|
197 |
+
|
198 |
+
if coeff < -TOLERANCE: # Must be strictly negative
|
199 |
+
found_negative_coeff = True
|
200 |
+
obj_coeff = objective_row[j]
|
201 |
+
# Ratio calculation: obj_coeff / abs(coeff) or obj_coeff / -coeff
|
202 |
+
ratio = obj_coeff / (-coeff)
|
203 |
+
ratios[col_var_index] = ratio
|
204 |
+
print(f" Var {self.var_names[col_var_index]} (Col {col_tableau_index}): Coeff={coeff:.3f}, ObjCoeff={obj_coeff:.3f}, Ratio = {obj_coeff:.3f} / {-coeff:.3f} = {ratio:.3f}")
|
205 |
+
|
206 |
+
# Update minimum ratio
|
207 |
+
if ratio < min_ratio:
|
208 |
+
min_ratio = ratio
|
209 |
+
pivot_col_index = col_tableau_index # Store the tableau column index
|
210 |
+
|
211 |
+
if not found_negative_coeff:
|
212 |
+
print(" No negative coefficients found in the pivot row.")
|
213 |
+
return -1 # Indicates primal infeasibility (dual unboundedness)
|
214 |
+
|
215 |
+
# Handle potential ties in minimum ratio (choose smallest column index - Bland's rule simplified)
|
216 |
+
min_ratio_vars = [idx for idx, r in ratios.items() if abs(r - min_ratio) < TOLERANCE]
|
217 |
+
if len(min_ratio_vars) > 1:
|
218 |
+
print(f" Tie detected for minimum ratio ({min_ratio:.3f}) among variables: {[self.var_names[idx] for idx in min_ratio_vars]}.")
|
219 |
+
# Apply Bland's rule: choose the variable with the smallest index
|
220 |
+
pivot_col_index = min(min_ratio_vars) + 1 # +1 for tableau index
|
221 |
+
print(f" Applying Bland's rule: Choosing variable with smallest index: {self.var_names[pivot_col_index - 1]}.")
|
222 |
+
elif pivot_col_index != -1:
|
223 |
+
entering_var_name = self.var_names[pivot_col_index - 1] # -1 to get var_name index
|
224 |
+
print(f" Minimum ratio is {min_ratio:.3f} for variable {entering_var_name} (Column {pivot_col_index}).")
|
225 |
+
print(f" Entering Variable: {entering_var_name} (Column {pivot_col_index})")
|
226 |
+
else:
|
227 |
+
# This case should technically not be reached if found_negative_coeff was true
|
228 |
+
print("Error in ratio calculation or tie-breaking.")
|
229 |
+
return -2 # Error indicator
|
230 |
+
|
231 |
+
return pivot_col_index
|
232 |
+
|
233 |
+
|
234 |
+
def _pivot(self, pivot_row_index, pivot_col_index):
|
235 |
+
"""Performs the pivot operation."""
|
236 |
+
pivot_element = self.tableau[pivot_row_index, pivot_col_index]
|
237 |
+
|
238 |
+
print(f"\nStep: Pivot Operation")
|
239 |
+
print(f" Pivot Element: {pivot_element:.3f} at (Row {pivot_row_index}, Col {pivot_col_index})")
|
240 |
+
|
241 |
+
if abs(pivot_element) < TOLERANCE:
|
242 |
+
print("Error: Pivot element is zero. Cannot proceed.")
|
243 |
+
# This might indicate an issue with the problem formulation or numerical instability.
|
244 |
+
raise ZeroDivisionError("Pivot element is too close to zero.")
|
245 |
+
|
246 |
+
# 1. Normalize the pivot row
|
247 |
+
print(f" Normalizing Pivot Row {pivot_row_index} by dividing by {pivot_element:.3f}")
|
248 |
+
self.tableau[pivot_row_index, :] /= pivot_element
|
249 |
+
|
250 |
+
# 2. Eliminate other entries in the pivot column
|
251 |
+
print(f" Eliminating other entries in Pivot Column {pivot_col_index}:")
|
252 |
+
for i in range(self.tableau.shape[0]):
|
253 |
+
if i != pivot_row_index:
|
254 |
+
factor = self.tableau[i, pivot_col_index]
|
255 |
+
if abs(factor) > TOLERANCE: # Only perform if factor is non-zero
|
256 |
+
print(f" Row {i} = Row {i} - ({factor:.3f}) * (New Row {pivot_row_index})")
|
257 |
+
self.tableau[i, :] -= factor * self.tableau[pivot_row_index, :]
|
258 |
+
|
259 |
+
# 3. Update basic variables list
|
260 |
+
# The variable corresponding to pivot_col_index becomes basic for pivot_row_index
|
261 |
+
old_basic_var_index = self.basic_vars[pivot_row_index - 1]
|
262 |
+
new_basic_var_index = pivot_col_index - 1 # Convert tableau col index to var_names index
|
263 |
+
self.basic_vars[pivot_row_index - 1] = new_basic_var_index
|
264 |
+
print(f" Updating Basic Variables: {self.var_names[new_basic_var_index]} replaces {self.var_names[old_basic_var_index]} in the basis for Row {pivot_row_index}.")
|
265 |
+
|
266 |
+
|
267 |
+
def solve(self, use_fallbacks=True):
|
268 |
+
"""
|
269 |
+
Executes the Dual Simplex algorithm.
|
270 |
+
|
271 |
+
Args:
|
272 |
+
use_fallbacks (bool): If True, will attempt to use alternative solvers
|
273 |
+
when the dual simplex method encounters issues
|
274 |
+
|
275 |
+
Returns:
|
276 |
+
tuple: (tableau, basic_vars) if successful using dual simplex,
|
277 |
+
or a dictionary of results if fallback solvers were used
|
278 |
+
"""
|
279 |
+
print("--- Starting Dual Simplex Method ---")
|
280 |
+
if self.tableau is None:
|
281 |
+
print("Error: Tableau not initialized.")
|
282 |
+
return None
|
283 |
+
|
284 |
+
iteration = 0
|
285 |
+
self._print_tableau(iteration)
|
286 |
+
|
287 |
+
while iteration < 100: # Safety break for too many iterations
|
288 |
+
iteration += 1
|
289 |
+
|
290 |
+
# 1. Check for Optimality (Primal Feasibility)
|
291 |
+
pivot_row_index = self._find_pivot_row()
|
292 |
+
if pivot_row_index == -1:
|
293 |
+
print("\n--- Optimal Solution Found ---")
|
294 |
+
print(" All RHS values are non-negative.")
|
295 |
+
self._print_results()
|
296 |
+
return self.tableau, self.basic_vars
|
297 |
+
|
298 |
+
# 2. Select Entering Variable (Pivot Column)
|
299 |
+
pivot_col_index = self._find_pivot_col(pivot_row_index)
|
300 |
+
|
301 |
+
# 3. Check for Primal Infeasibility (Dual Unboundedness)
|
302 |
+
if pivot_col_index == -1:
|
303 |
+
print("\n--- Primal Problem Infeasible ---")
|
304 |
+
print(f" All coefficients in Pivot Row {pivot_row_index} are non-negative, but RHS is negative.")
|
305 |
+
print(" The dual problem is unbounded, implying the primal problem has no feasible solution.")
|
306 |
+
|
307 |
+
if use_fallbacks:
|
308 |
+
return self._try_fallback_solvers("primal_infeasible")
|
309 |
+
return None, None # Indicate infeasibility
|
310 |
+
|
311 |
+
elif pivot_col_index == -2:
|
312 |
+
# Error during pivot column selection
|
313 |
+
print("\n--- Error during pivot column selection ---")
|
314 |
+
|
315 |
+
if use_fallbacks:
|
316 |
+
return self._try_fallback_solvers("pivot_error")
|
317 |
+
return None, None
|
318 |
+
|
319 |
+
# 4. Perform Pivot Operation
|
320 |
+
try:
|
321 |
+
self._pivot(pivot_row_index, pivot_col_index)
|
322 |
+
except ZeroDivisionError as e:
|
323 |
+
print(f"\n--- Error during pivot operation: {e} ---")
|
324 |
+
|
325 |
+
if use_fallbacks:
|
326 |
+
return self._try_fallback_solvers("numerical_instability")
|
327 |
+
return None, None
|
328 |
+
|
329 |
+
# Print the tableau after pivoting
|
330 |
+
self._print_tableau(iteration)
|
331 |
+
|
332 |
+
print("\n--- Maximum Iterations Reached ---")
|
333 |
+
print(" The algorithm did not converge within the iteration limit.")
|
334 |
+
print(" This might indicate cycling or a very large problem.")
|
335 |
+
|
336 |
+
if use_fallbacks:
|
337 |
+
return self._try_fallback_solvers("iteration_limit")
|
338 |
+
return None, None # Indicate non-convergence
|
339 |
+
|
340 |
+
def _try_fallback_solvers(self, error_type):
|
341 |
+
"""
|
342 |
+
Tries alternative solvers when the dual simplex method fails.
|
343 |
+
|
344 |
+
Args:
|
345 |
+
error_type (str): Type of error encountered in the dual simplex method
|
346 |
+
|
347 |
+
Returns:
|
348 |
+
dict: Results from fallback solvers
|
349 |
+
"""
|
350 |
+
print(f"\n--- Using Fallback Solvers due to '{error_type}' ---")
|
351 |
+
|
352 |
+
results = {
|
353 |
+
"error_type": error_type,
|
354 |
+
"dual_simplex_result": None,
|
355 |
+
"dual_approach_result": None,
|
356 |
+
"direct_solver_result": None
|
357 |
+
}
|
358 |
+
|
359 |
+
# First try using solve_lp_via_dual (which uses complementary slackness)
|
360 |
+
print("\n=== Attempting to solve via Dual Approach with Complementary Slackness ===")
|
361 |
+
status, message, primal_sol, dual_sol, obj_val = solve_lp_via_dual(
|
362 |
+
self.objective_type,
|
363 |
+
self.original_c,
|
364 |
+
self.original_A,
|
365 |
+
self.original_relations,
|
366 |
+
self.original_b
|
367 |
+
)
|
368 |
+
|
369 |
+
results["dual_approach_result"] = {
|
370 |
+
"status": status,
|
371 |
+
"message": message,
|
372 |
+
"primal_solution": primal_sol,
|
373 |
+
"dual_solution": dual_sol,
|
374 |
+
"objective_value": obj_val
|
375 |
+
}
|
376 |
+
|
377 |
+
print(f"Dual Approach Result: {message}")
|
378 |
+
if status == 0 and primal_sol:
|
379 |
+
print(f"Objective Value: {obj_val}")
|
380 |
+
return results
|
381 |
+
|
382 |
+
# If that fails, try direct method (most robust)
|
383 |
+
print("\n=== Attempting direct solution using SciPy's linprog solver ===")
|
384 |
+
status, message, primal_sol, _, obj_val = solve_primal_directly(
|
385 |
+
self.objective_type,
|
386 |
+
self.original_c,
|
387 |
+
self.original_A,
|
388 |
+
self.original_relations,
|
389 |
+
self.original_b
|
390 |
+
)
|
391 |
+
|
392 |
+
results["direct_solver_result"] = {
|
393 |
+
"status": status,
|
394 |
+
"message": message,
|
395 |
+
"primal_solution": primal_sol,
|
396 |
+
"objective_value": obj_val
|
397 |
+
}
|
398 |
+
|
399 |
+
print(f"Direct Solver Result: {message}")
|
400 |
+
if status == 0 and primal_sol:
|
401 |
+
print(f"Objective Value: {obj_val}")
|
402 |
+
|
403 |
+
return results
|
404 |
+
|
405 |
+
def _print_results(self):
|
406 |
+
"""Prints the final solution."""
|
407 |
+
print("\n--- Final Solution ---")
|
408 |
+
self._print_tableau("Final")
|
409 |
+
|
410 |
+
# Objective Value
|
411 |
+
final_obj_value = self.tableau[0, -1]
|
412 |
+
if self.is_minimized_problem:
|
413 |
+
final_obj_value = -final_obj_value # Correct for Min Z = -Max(-Z)
|
414 |
+
print(f"Optimal Objective Value (Min Z): {final_obj_value:.6f}")
|
415 |
+
else:
|
416 |
+
print(f"Optimal Objective Value (Max Z): {final_obj_value:.6f}")
|
417 |
+
|
418 |
+
# Variable Values
|
419 |
+
solution = {}
|
420 |
+
num_total_vars = len(self.var_names)
|
421 |
+
final_solution_vector = np.zeros(num_total_vars)
|
422 |
+
|
423 |
+
for i, basis_col_idx in enumerate(self.basic_vars):
|
424 |
+
# basis_col_idx is the index in the var_names list
|
425 |
+
# The corresponding tableau row is i + 1
|
426 |
+
final_solution_vector[basis_col_idx] = self.tableau[i + 1, -1]
|
427 |
+
|
428 |
+
print("Optimal Variable Values:")
|
429 |
+
for i in range(self.num_original_vars):
|
430 |
+
var_name = self.var_names[i]
|
431 |
+
value = final_solution_vector[i]
|
432 |
+
print(f" {var_name}: {value:.6f}")
|
433 |
+
solution[var_name] = value
|
434 |
+
|
435 |
+
# Optionally print slack variable values
|
436 |
+
print("Slack/Surplus Variable Values:")
|
437 |
+
for i in range(self.num_original_vars, num_total_vars):
|
438 |
+
var_name = self.var_names[i]
|
439 |
+
value = final_solution_vector[i]
|
440 |
+
# Only print non-zero slacks for brevity, or all if needed
|
441 |
+
if abs(value) > TOLERANCE:
|
442 |
+
print(f" {var_name}: {value:.6f}")
|
443 |
+
|
maths/university/operations_research/bnb.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
maths/university/operations_research/dual.ipynb
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": 3,
|
6 |
+
"metadata": {},
|
7 |
+
"outputs": [
|
8 |
+
{
|
9 |
+
"name": "stdout",
|
10 |
+
"output_type": "stream",
|
11 |
+
"text": [
|
12 |
+
"\n",
|
13 |
+
"Warning: Initial tableau is not dual feasible (objective row has negative coefficients).\n",
|
14 |
+
"The standard Dual Simplex method might not apply directly or may require Phase I.\n",
|
15 |
+
"--- Starting Dual Simplex Method ---\n",
|
16 |
+
"\n",
|
17 |
+
"--- Iteration 0 ---\n",
|
18 |
+
" BV Z x1 x2 s1 s2 s3 RHS\n",
|
19 |
+
"------------------------------------------------------------------------\n",
|
20 |
+
" Z 1.000 1.000 -2.000 0.000 0.000 0.000 0.000\n",
|
21 |
+
" s1 0.000 -5.000 -4.000 1.000 0.000 0.000 -20.000\n",
|
22 |
+
" s2 0.000 1.000 5.000 0.000 1.000 0.000 10.000\n",
|
23 |
+
" s3 0.000 -1.000 -5.000 0.000 0.000 1.000 -10.000\n",
|
24 |
+
"------------------------------------------------------------------------\n",
|
25 |
+
"\n",
|
26 |
+
"Step: Select Pivot Row (Leaving Variable)\n",
|
27 |
+
" RHS values (b): [-20. 10. -10.]\n",
|
28 |
+
" Most negative RHS is -20.000 in Row 1 (Basic Var: s1).\n",
|
29 |
+
" Leaving Variable: s1 (Row 1)\n",
|
30 |
+
"\n",
|
31 |
+
"Step: Select Pivot Column (Entering Variable) using Ratio Test\n",
|
32 |
+
" Pivot Row (Row 1) coefficients (excluding Z, RHS): [-5. -4. 1. 0. 0.]\n",
|
33 |
+
" Objective Row coefficients (excluding Z, RHS): [ 1. -2. 0. 0. 0.]\n",
|
34 |
+
" Calculating ratios = ObjCoeff / abs(PivotRowCoeff) for PivotRowCoeff < 0:\n",
|
35 |
+
" Var x1 (Col 1): Coeff=-5.000, ObjCoeff=1.000, Ratio = 1.000 / 5.000 = 0.200\n",
|
36 |
+
" Var x2 (Col 2): Coeff=-4.000, ObjCoeff=-2.000, Ratio = -2.000 / 4.000 = -0.500\n",
|
37 |
+
" Minimum ratio is -0.500 for variable x2 (Column 2).\n",
|
38 |
+
" Entering Variable: x2 (Column 2)\n",
|
39 |
+
"\n",
|
40 |
+
"Step: Pivot Operation\n",
|
41 |
+
" Pivot Element: -4.000 at (Row 1, Col 2)\n",
|
42 |
+
" Normalizing Pivot Row 1 by dividing by -4.000\n",
|
43 |
+
" Eliminating other entries in Pivot Column 2:\n",
|
44 |
+
" Row 0 = Row 0 - (-2.000) * (New Row 1)\n",
|
45 |
+
" Row 2 = Row 2 - (5.000) * (New Row 1)\n",
|
46 |
+
" Row 3 = Row 3 - (-5.000) * (New Row 1)\n",
|
47 |
+
" Updating Basic Variables: x2 replaces s1 in the basis for Row 1.\n",
|
48 |
+
"\n",
|
49 |
+
"--- Iteration 1 ---\n",
|
50 |
+
" BV Z x1 x2 s1 s2 s3 RHS\n",
|
51 |
+
"------------------------------------------------------------------------\n",
|
52 |
+
" Z 1.000 3.500 0.000 -0.500 0.000 0.000 10.000\n",
|
53 |
+
" x2 -0.000 1.250 1.000 -0.250 -0.000 -0.000 5.000\n",
|
54 |
+
" s2 0.000 -5.250 0.000 1.250 1.000 0.000 -15.000\n",
|
55 |
+
" s3 0.000 5.250 0.000 -1.250 0.000 1.000 15.000\n",
|
56 |
+
"------------------------------------------------------------------------\n",
|
57 |
+
"\n",
|
58 |
+
"Step: Select Pivot Row (Leaving Variable)\n",
|
59 |
+
" RHS values (b): [ 5. -15. 15.]\n",
|
60 |
+
" Most negative RHS is -15.000 in Row 2 (Basic Var: s2).\n",
|
61 |
+
" Leaving Variable: s2 (Row 2)\n",
|
62 |
+
"\n",
|
63 |
+
"Step: Select Pivot Column (Entering Variable) using Ratio Test\n",
|
64 |
+
" Pivot Row (Row 2) coefficients (excluding Z, RHS): [-5.25 0. 1.25 1. 0. ]\n",
|
65 |
+
" Objective Row coefficients (excluding Z, RHS): [ 3.5 0. -0.5 0. 0. ]\n",
|
66 |
+
" Calculating ratios = ObjCoeff / abs(PivotRowCoeff) for PivotRowCoeff < 0:\n",
|
67 |
+
" Var x1 (Col 1): Coeff=-5.250, ObjCoeff=3.500, Ratio = 3.500 / 5.250 = 0.667\n",
|
68 |
+
" Minimum ratio is 0.667 for variable x1 (Column 1).\n",
|
69 |
+
" Entering Variable: x1 (Column 1)\n",
|
70 |
+
"\n",
|
71 |
+
"Step: Pivot Operation\n",
|
72 |
+
" Pivot Element: -5.250 at (Row 2, Col 1)\n",
|
73 |
+
" Normalizing Pivot Row 2 by dividing by -5.250\n",
|
74 |
+
" Eliminating other entries in Pivot Column 1:\n",
|
75 |
+
" Row 0 = Row 0 - (3.500) * (New Row 2)\n",
|
76 |
+
" Row 1 = Row 1 - (1.250) * (New Row 2)\n",
|
77 |
+
" Row 3 = Row 3 - (5.250) * (New Row 2)\n",
|
78 |
+
" Updating Basic Variables: x1 replaces s2 in the basis for Row 2.\n",
|
79 |
+
"\n",
|
80 |
+
"--- Iteration 2 ---\n",
|
81 |
+
" BV Z x1 x2 s1 s2 s3 RHS\n",
|
82 |
+
"------------------------------------------------------------------------\n",
|
83 |
+
" Z 1.000 0.000 0.000 0.333 0.667 0.000 0.000\n",
|
84 |
+
" x2 0.000 0.000 1.000 0.048 0.238 0.000 1.429\n",
|
85 |
+
" x1 -0.000 1.000 -0.000 -0.238 -0.190 -0.000 2.857\n",
|
86 |
+
" s3 0.000 0.000 0.000 0.000 1.000 1.000 0.000\n",
|
87 |
+
"------------------------------------------------------------------------\n",
|
88 |
+
"\n",
|
89 |
+
"--- Optimal Solution Found ---\n",
|
90 |
+
" All RHS values are non-negative.\n",
|
91 |
+
"\n",
|
92 |
+
"--- Final Solution ---\n",
|
93 |
+
"\n",
|
94 |
+
"--- Iteration Final ---\n",
|
95 |
+
" BV Z x1 x2 s1 s2 s3 RHS\n",
|
96 |
+
"------------------------------------------------------------------------\n",
|
97 |
+
" Z 1.000 0.000 0.000 0.333 0.667 0.000 0.000\n",
|
98 |
+
" x2 0.000 0.000 1.000 0.048 0.238 0.000 1.429\n",
|
99 |
+
" x1 -0.000 1.000 -0.000 -0.238 -0.190 -0.000 2.857\n",
|
100 |
+
" s3 0.000 0.000 0.000 0.000 1.000 1.000 0.000\n",
|
101 |
+
"------------------------------------------------------------------------\n",
|
102 |
+
"Optimal Objective Value (Max Z): 0.000000\n",
|
103 |
+
"Optimal Variable Values:\n",
|
104 |
+
" x1: 2.857143\n",
|
105 |
+
" x2: 1.428571\n",
|
106 |
+
"Slack/Surplus Variable Values:\n"
|
107 |
+
]
|
108 |
+
}
|
109 |
+
],
|
110 |
+
"source": [
|
111 |
+
"from functions.DualSimplexSolver import DualSimplexSolver\n",
|
112 |
+
"if __name__ == \"__main__\":\n",
|
113 |
+
" try:\n",
|
114 |
+
" example_obj_type = 'max'\n",
|
115 |
+
" example_c = [-1,2]\n",
|
116 |
+
" example_A = [\n",
|
117 |
+
" [5,4],\n",
|
118 |
+
" [1,5],\n",
|
119 |
+
" ]\n",
|
120 |
+
" example_relations = ['>=', '=']\n",
|
121 |
+
" example_b = [20,10]\n",
|
122 |
+
"\n",
|
123 |
+
" solver = DualSimplexSolver(example_obj_type, example_c, example_A, example_relations, example_b)\n",
|
124 |
+
" solver.solve()\n",
|
125 |
+
" except Exception as e:\n",
|
126 |
+
" print(f\"\\nAn error occurred: {e}\")\n",
|
127 |
+
" import traceback\n",
|
128 |
+
" traceback.print_exc()"
|
129 |
+
]
|
130 |
+
}
|
131 |
+
],
|
132 |
+
"metadata": {
|
133 |
+
"kernelspec": {
|
134 |
+
"display_name": ".venv (3.13.2)",
|
135 |
+
"language": "python",
|
136 |
+
"name": "python3"
|
137 |
+
},
|
138 |
+
"language_info": {
|
139 |
+
"codemirror_mode": {
|
140 |
+
"name": "ipython",
|
141 |
+
"version": 3
|
142 |
+
},
|
143 |
+
"file_extension": ".py",
|
144 |
+
"mimetype": "text/x-python",
|
145 |
+
"name": "python",
|
146 |
+
"nbconvert_exporter": "python",
|
147 |
+
"pygments_lexer": "ipython3",
|
148 |
+
"version": "3.13.2"
|
149 |
+
}
|
150 |
+
},
|
151 |
+
"nbformat": 4,
|
152 |
+
"nbformat_minor": 2
|
153 |
+
}
|
maths/university/operations_research/get_user_input.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
def get_user_input():
|
2 |
+
"""Gets the LP problem details from the user."""
|
3 |
+
# (Identical to the input function in the previous Dual Simplex example)
|
4 |
+
print("Enter the Linear Programming Problem:")
|
5 |
+
while True:
|
6 |
+
obj_type = input("Is it a 'max' or 'min' problem? ").strip().lower()
|
7 |
+
if obj_type in ['max', 'min']:
|
8 |
+
break
|
9 |
+
print("Invalid input. Please enter 'max' or 'min'.")
|
10 |
+
c_str = input("Enter objective function coefficients (space-separated): ").strip()
|
11 |
+
c = list(map(float, c_str.split()))
|
12 |
+
num_vars = len(c)
|
13 |
+
A = []
|
14 |
+
relations = []
|
15 |
+
b = []
|
16 |
+
i = 1
|
17 |
+
print(f"Enter constraints (one per line). There are {num_vars} variables (x1, x2,...).")
|
18 |
+
print("Format: coeff1 coeff2 ... relation rhs (e.g., '3 1 >= 3')")
|
19 |
+
print("Type 'done' when finished.")
|
20 |
+
while True:
|
21 |
+
line = input(f"Constraint {i}: ").strip()
|
22 |
+
if line.lower() == 'done':
|
23 |
+
if i == 1:
|
24 |
+
print("Error: At least one constraint is required.")
|
25 |
+
continue
|
26 |
+
break
|
27 |
+
parts = line.split()
|
28 |
+
if len(parts) != num_vars + 2:
|
29 |
+
print(f"Error: Expected {num_vars} coefficients, 1 relation, and 1 RHS value.")
|
30 |
+
continue
|
31 |
+
try:
|
32 |
+
coeffs = list(map(float, parts[:num_vars]))
|
33 |
+
relation = parts[num_vars]
|
34 |
+
rhs = float(parts[num_vars + 1])
|
35 |
+
if relation not in ['<=', '>=', '=']:
|
36 |
+
print("Error: Invalid relation symbol. Use '<=', '>=', or '='.")
|
37 |
+
continue
|
38 |
+
A.append(coeffs)
|
39 |
+
relations.append(relation)
|
40 |
+
b.append(rhs)
|
41 |
+
i += 1
|
42 |
+
except ValueError:
|
43 |
+
print("Error: Invalid number format for coefficients or RHS.")
|
44 |
+
continue
|
45 |
+
return obj_type, c, A, relations, b
|
46 |
+
|
maths/university/operations_research/simplex_solver_with_steps.py
ADDED
@@ -0,0 +1,162 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import streamlit as st
|
3 |
+
from tabulate import tabulate
|
4 |
+
|
5 |
+
|
6 |
+
def simplex_solver_with_steps(c, A, b, bounds):
|
7 |
+
"""
|
8 |
+
Solve LP using simplex method and display full tableau at each step
|
9 |
+
|
10 |
+
Parameters:
|
11 |
+
- c: Objective coefficients (for maximizing c'x)
|
12 |
+
- A: Constraint coefficients matrix
|
13 |
+
- b: Right-hand side of constraints
|
14 |
+
- bounds: Variable bounds as [(lower_1, upper_1), (lower_2, upper_2), ...]
|
15 |
+
|
16 |
+
Returns:
|
17 |
+
- x: Optimal solution
|
18 |
+
- optimal_value: Optimal objective value
|
19 |
+
"""
|
20 |
+
st.markdown("\n--- Starting Simplex Method ---")
|
21 |
+
st.text(f"Objective: Maximize {' + '.join([f'{c[i]}x_{i}' for i in range(len(c))])}")
|
22 |
+
st.text(f"Constraints:")
|
23 |
+
for i in range(len(b)):
|
24 |
+
constraint_str = ' + '.join([f"{A[i,j]}x_{j}" for j in range(A.shape[1])])
|
25 |
+
st.text(f" {constraint_str} <= {b[i]}")
|
26 |
+
|
27 |
+
# Convert problem to standard form (for tableau method)
|
28 |
+
# First handle bounds by adding necessary constraints
|
29 |
+
A_with_bounds = A.copy()
|
30 |
+
b_with_bounds = b.copy()
|
31 |
+
|
32 |
+
for i, (lb, ub) in enumerate(bounds):
|
33 |
+
if lb is not None and lb > 0:
|
34 |
+
# For variables with lower bounds > 0, we'll substitute x_i = x_i' + lb
|
35 |
+
# This affects all constraints where x_i appears
|
36 |
+
for j in range(A.shape[0]):
|
37 |
+
b_with_bounds[j] -= A[j, i] * lb
|
38 |
+
|
39 |
+
# Number of variables and constraints
|
40 |
+
n_vars = len(c)
|
41 |
+
n_constraints = A.shape[0]
|
42 |
+
|
43 |
+
# Add slack variables to create standard form
|
44 |
+
# The tableau will have: [objective row | RHS]
|
45 |
+
# [-------------|----]
|
46 |
+
# [constraints | RHS]
|
47 |
+
|
48 |
+
# Initial tableau:
|
49 |
+
# First row is -c (negative of objective coefficients) and 0s for slack variables, then 0 (for max)
|
50 |
+
# The rest are constraint coefficients, then identity matrix for slack variables, then RHS
|
51 |
+
tableau = np.zeros((n_constraints + 1, n_vars + n_constraints + 1))
|
52 |
+
|
53 |
+
# Set the objective row (negated for maximization)
|
54 |
+
tableau[0, :n_vars] = -c
|
55 |
+
|
56 |
+
# Set the constraint coefficients
|
57 |
+
tableau[1:, :n_vars] = A_with_bounds
|
58 |
+
|
59 |
+
# Set the slack variable coefficients (identity matrix)
|
60 |
+
for i in range(n_constraints):
|
61 |
+
tableau[i + 1, n_vars + i] = 1
|
62 |
+
|
63 |
+
# Set the RHS
|
64 |
+
tableau[1:, -1] = b_with_bounds
|
65 |
+
|
66 |
+
# Base and non-base variables
|
67 |
+
base_vars = list(range(n_vars, n_vars + n_constraints)) # Slack variables are initially basic
|
68 |
+
|
69 |
+
# Function to print current tableau
|
70 |
+
def print_tableau(tableau, base_vars):
|
71 |
+
headers = [f"x_{j}" for j in range(n_vars)] + [f"s_{j}" for j in range(n_constraints)] + ["RHS"]
|
72 |
+
rows = []
|
73 |
+
row_labels = ["z"] + [f"eq_{i}" for i in range(n_constraints)]
|
74 |
+
|
75 |
+
for i, row in enumerate(tableau):
|
76 |
+
rows.append([row_labels[i]] + [f"{val:.3f}" for val in row])
|
77 |
+
|
78 |
+
st.text("\nCurrent Tableau:")
|
79 |
+
st.text(tabulate(rows, headers=headers, tablefmt="grid"))
|
80 |
+
st.text(f"Basic variables: {[f'x_{v}' if v < n_vars else f's_{v-n_vars}' for v in base_vars]}")
|
81 |
+
|
82 |
+
# Print initial tableau
|
83 |
+
st.text("\nInitial tableau:")
|
84 |
+
print_tableau(tableau, base_vars)
|
85 |
+
|
86 |
+
# Main simplex loop
|
87 |
+
iteration = 0
|
88 |
+
max_iterations = 100 # Prevent infinite loops
|
89 |
+
|
90 |
+
while iteration < max_iterations:
|
91 |
+
iteration += 1
|
92 |
+
st.text(f"\n--- Iteration {iteration} ---")
|
93 |
+
|
94 |
+
# Find the entering variable (most negative coefficient in objective row for maximization)
|
95 |
+
entering_col = np.argmin(tableau[0, :-1])
|
96 |
+
if tableau[0, entering_col] >= -1e-10: # Small negative numbers due to floating-point errors
|
97 |
+
st.text("Optimal solution reached - no negative coefficients in objective row")
|
98 |
+
break
|
99 |
+
|
100 |
+
st.text(f"Entering variable: {'x_' + str(entering_col) if entering_col < n_vars else 's_' + str(entering_col - n_vars)}")
|
101 |
+
|
102 |
+
# Find the leaving variable using min ratio test
|
103 |
+
ratios = []
|
104 |
+
for i in range(1, n_constraints + 1):
|
105 |
+
if tableau[i, entering_col] <= 0:
|
106 |
+
ratios.append(np.inf) # Avoid division by zero or negative
|
107 |
+
else:
|
108 |
+
ratios.append(tableau[i, -1] / tableau[i, entering_col])
|
109 |
+
|
110 |
+
if all(r == np.inf for r in ratios):
|
111 |
+
st.text("Unbounded solution - no leaving variable found")
|
112 |
+
return None, float('inf') # Problem is unbounded
|
113 |
+
|
114 |
+
# Find the row with minimum ratio
|
115 |
+
leaving_row = np.argmin(ratios) + 1 # +1 because we skip the objective row
|
116 |
+
leaving_var = base_vars[leaving_row - 1]
|
117 |
+
|
118 |
+
st.text(f"Leaving variable: {'x_' + str(leaving_var) if leaving_var < n_vars else 's_' + str(leaving_var - n_vars)}")
|
119 |
+
st.text(f"Pivot element: {tableau[leaving_row, entering_col]:.3f} at row {leaving_row}, column {entering_col}")
|
120 |
+
|
121 |
+
# Perform pivot operation
|
122 |
+
# First, normalize the pivot row
|
123 |
+
pivot = tableau[leaving_row, entering_col]
|
124 |
+
tableau[leaving_row] = tableau[leaving_row] / pivot
|
125 |
+
|
126 |
+
# Update other rows
|
127 |
+
for i in range(tableau.shape[0]):
|
128 |
+
if i != leaving_row:
|
129 |
+
factor = tableau[i, entering_col]
|
130 |
+
tableau[i] = tableau[i] - factor * tableau[leaving_row]
|
131 |
+
|
132 |
+
# Update basic variables
|
133 |
+
base_vars[leaving_row - 1] = entering_col
|
134 |
+
|
135 |
+
# Print updated tableau
|
136 |
+
st.text("\nAfter pivot:")
|
137 |
+
print_tableau(tableau, base_vars)
|
138 |
+
|
139 |
+
if iteration == max_iterations:
|
140 |
+
st.text("Max iterations reached without convergence")
|
141 |
+
return None, None
|
142 |
+
|
143 |
+
# Extract solution
|
144 |
+
x = np.zeros(n_vars)
|
145 |
+
for i, var in enumerate(base_vars):
|
146 |
+
if var < n_vars: # If it's an original variable and not a slack
|
147 |
+
x[var] = tableau[i + 1, -1]
|
148 |
+
|
149 |
+
# Account for variable substitutions (if lower bounds were applied)
|
150 |
+
for i, (lb, _) in enumerate(bounds):
|
151 |
+
if lb is not None and lb > 0:
|
152 |
+
x[i] += lb
|
153 |
+
|
154 |
+
# Calculate objective value
|
155 |
+
optimal_value = np.dot(c, x)
|
156 |
+
|
157 |
+
st.markdown("\n--- Simplex Method Complete ---")
|
158 |
+
st.text(f"Optimal solution found: {x}")
|
159 |
+
st.text(f"Optimal objective value: {optimal_value}")
|
160 |
+
|
161 |
+
return x, optimal_value
|
162 |
+
|
maths/university/operations_research/solve_lp_via_dual.py
ADDED
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
from scipy.optimize import linprog # Using SciPy for robust LP solving
|
4 |
+
import warnings
|
5 |
+
from .solve_primal_directly import solve_primal_directly
|
6 |
+
|
7 |
+
def solve_lp_via_dual(objective_type, c, A, relations, b, TOLERANCE=1e-6):
|
8 |
+
"""
|
9 |
+
Solves an LP problem by formulating and solving the dual, then using
|
10 |
+
complementary slackness.
|
11 |
+
|
12 |
+
Args:
|
13 |
+
objective_type (str): 'max' or 'min'.
|
14 |
+
c (list or np.array): Primal objective coefficients.
|
15 |
+
A (list of lists or np.array): Primal constraint matrix LHS.
|
16 |
+
relations (list): Primal constraint relations ('<=', '>=', '=').
|
17 |
+
b (list or np.array): Primal constraint RHS.
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
A tuple: (status, message, primal_solution, dual_solution, objective_value)
|
21 |
+
status: 0 for success, non-zero for failure.
|
22 |
+
message: Description of the outcome.
|
23 |
+
primal_solution: Dictionary of primal variable values (x).
|
24 |
+
dual_solution: Dictionary of dual variable values (p).
|
25 |
+
objective_value: Optimal objective value of the primal.
|
26 |
+
"""
|
27 |
+
primal_c = np.array(c, dtype=float)
|
28 |
+
primal_A = np.array(A, dtype=float)
|
29 |
+
primal_b = np.array(b, dtype=float)
|
30 |
+
primal_relations = relations
|
31 |
+
num_primal_vars = len(primal_c)
|
32 |
+
num_primal_constraints = len(primal_b)
|
33 |
+
|
34 |
+
print("--- Step 1: Formulate the Dual Problem ---")
|
35 |
+
|
36 |
+
# Standardize Primal for consistent dual formulation: Convert to Min problem
|
37 |
+
original_obj_type = objective_type.lower()
|
38 |
+
if original_obj_type == 'max':
|
39 |
+
print(" Primal is Max. Converting to Min by negating objective coefficients.")
|
40 |
+
primal_c_std = -primal_c
|
41 |
+
objective_sign_flip = -1.0
|
42 |
+
else:
|
43 |
+
print(" Primal is Min. Using original objective coefficients.")
|
44 |
+
primal_c_std = primal_c
|
45 |
+
objective_sign_flip = 1.0
|
46 |
+
|
47 |
+
# Handle constraint relations for dual formulation
|
48 |
+
# We'll formulate the dual based on a standard primal form:
|
49 |
+
# Min c'x s.t. Ax >= b, x >= 0
|
50 |
+
# Dual: Max p'b s.t. p'A <= c', p >= 0
|
51 |
+
# Let's adjust the input primal constraints to fit the Ax >= b form needed for this dual pair.
|
52 |
+
|
53 |
+
A_geq = []
|
54 |
+
b_geq = []
|
55 |
+
print(" Adjusting primal constraints to >= form for dual formulation:")
|
56 |
+
for i in range(num_primal_constraints):
|
57 |
+
rel = primal_relations[i]
|
58 |
+
a_row = primal_A[i]
|
59 |
+
b_val = primal_b[i]
|
60 |
+
|
61 |
+
if rel == '<=':
|
62 |
+
print(f" Constraint {i+1}: Multiplying by -1 ( {a_row} <= {b_val} -> {-a_row} >= {-b_val} )")
|
63 |
+
A_geq.append(-a_row)
|
64 |
+
b_geq.append(-b_val)
|
65 |
+
elif rel == '>=':
|
66 |
+
print(f" Constraint {i+1}: Keeping as >= ( {a_row} >= {b_val} )")
|
67 |
+
A_geq.append(a_row)
|
68 |
+
b_geq.append(b_val)
|
69 |
+
elif rel == '=':
|
70 |
+
# Represent equality as two inequalities: >= and <=
|
71 |
+
print(f" Constraint {i+1}: Splitting '=' into >= and <= ")
|
72 |
+
# >= part
|
73 |
+
print(f" Part 1: {a_row} >= {b_val}")
|
74 |
+
A_geq.append(a_row)
|
75 |
+
b_geq.append(b_val)
|
76 |
+
# <= part -> multiply by -1 to get >=
|
77 |
+
print(f" Part 2: {-a_row} >= {-b_val} (from {a_row} <= {b_val})")
|
78 |
+
A_geq.append(-a_row)
|
79 |
+
b_geq.append(-b_val)
|
80 |
+
else:
|
81 |
+
return 1, f"Invalid relation '{rel}' in constraint {i+1}", None, None, None
|
82 |
+
|
83 |
+
primal_A_std = np.array(A_geq, dtype=float)
|
84 |
+
primal_b_std = np.array(b_geq, dtype=float)
|
85 |
+
num_dual_vars = primal_A_std.shape[0] # One dual var per standardized constraint
|
86 |
+
|
87 |
+
# Now formulate the dual: Max p' * primal_b_std s.t. p' * primal_A_std <= primal_c_std, p >= 0
|
88 |
+
dual_c = primal_b_std # Coefficients for Maximize objective
|
89 |
+
dual_A = primal_A_std.T # Transpose A
|
90 |
+
dual_b = primal_c_std # RHS for dual constraints (<= type)
|
91 |
+
|
92 |
+
print("\n Dual Problem Formulation:")
|
93 |
+
print(f" Objective: Maximize p * [{', '.join(f'{bi:.2f}' for bi in dual_c)}]")
|
94 |
+
print(f" Subject to:")
|
95 |
+
for j in range(dual_A.shape[0]): # Iterate through dual constraints
|
96 |
+
print(f" {' + '.join(f'{dual_A[j, i]:.2f}*p{i+1}' for i in range(num_dual_vars))} <= {dual_b[j]:.2f}")
|
97 |
+
print(f" p_i >= 0 for i=1..{num_dual_vars}")
|
98 |
+
|
99 |
+
print("\n--- Step 2: Solve the Dual Problem using SciPy linprog ---")
|
100 |
+
# linprog solves Min problems, so we Max p*b by Min -p*b
|
101 |
+
# Constraints for linprog: A_ub @ x <= b_ub, A_eq @ x == b_eq
|
102 |
+
# Our dual is Max p*b s.t. p*A <= c. Let p be x for linprog.
|
103 |
+
# Maximize dual_c @ p => Minimize -dual_c @ p
|
104 |
+
# Subject to: dual_A @ p <= dual_b
|
105 |
+
# p >= 0 (default bounds)
|
106 |
+
|
107 |
+
c_linprog = -dual_c
|
108 |
+
A_ub_linprog = dual_A
|
109 |
+
b_ub_linprog = dual_b
|
110 |
+
|
111 |
+
# Use method='highs' which is the default and generally robust
|
112 |
+
# Options can be added for more control if needed
|
113 |
+
try:
|
114 |
+
result_dual = linprog(c_linprog, A_ub=A_ub_linprog, b_ub=b_ub_linprog, bounds=[(0, None)] * num_dual_vars, method='highs') # Using HiGHS solver
|
115 |
+
except ValueError as e:
|
116 |
+
# Sometimes specific solvers are not available or fail
|
117 |
+
print(f" SciPy linprog(method='highs') failed: {e}. Trying 'simplex'.")
|
118 |
+
try:
|
119 |
+
result_dual = linprog(c_linprog, A_ub=A_ub_linprog, b_ub=b_ub_linprog, bounds=[(0, None)] * num_dual_vars, method='simplex')
|
120 |
+
except Exception as e_simplex:
|
121 |
+
return 1, f"SciPy linprog failed for dual problem with both 'highs' and 'simplex': {e_simplex}", None, None, None
|
122 |
+
|
123 |
+
|
124 |
+
if not result_dual.success:
|
125 |
+
# Check status for specific reasons
|
126 |
+
if result_dual.status == 2: # Infeasible
|
127 |
+
msg = "Dual problem is infeasible. Primal problem is unbounded (or infeasible)."
|
128 |
+
elif result_dual.status == 3: # Unbounded
|
129 |
+
msg = "Dual problem is unbounded. Primal problem is infeasible."
|
130 |
+
else:
|
131 |
+
msg = f"Failed to solve the dual problem. Status: {result_dual.status} - {result_dual.message}"
|
132 |
+
return result_dual.status, msg, None, None, None
|
133 |
+
|
134 |
+
# Optimal dual solution found
|
135 |
+
optimal_dual_p = result_dual.x
|
136 |
+
optimal_dual_obj = -result_dual.fun # Negate back to get Max value
|
137 |
+
|
138 |
+
dual_solution_dict = {f'p{i+1}': optimal_dual_p[i] for i in range(num_dual_vars)}
|
139 |
+
|
140 |
+
print("\n Optimal Dual Solution Found:")
|
141 |
+
print(f" Dual Variables (p*):")
|
142 |
+
for i in range(num_dual_vars):
|
143 |
+
print(f" p{i+1} = {optimal_dual_p[i]:.6f}")
|
144 |
+
print(f" Optimal Dual Objective Value (Max p*b): {optimal_dual_obj:.6f}")
|
145 |
+
|
146 |
+
print("\n--- Step 3: Check Strong Duality ---")
|
147 |
+
# The optimal objective value of the dual should equal the optimal objective
|
148 |
+
# value of the primal (after adjusting for Min/Max conversion).
|
149 |
+
expected_primal_obj = optimal_dual_obj * objective_sign_flip
|
150 |
+
print(f" Strong duality implies the optimal primal objective value should be: {expected_primal_obj:.6f}")
|
151 |
+
|
152 |
+
print("\n--- Step 4 & 5: Use Complementary Slackness to find Primal Variables ---")
|
153 |
+
|
154 |
+
# Calculate Dual Slacks: dual_slack = dual_b - dual_A @ optimal_dual_p
|
155 |
+
# dual_b is primal_c_std
|
156 |
+
# dual_A is primal_A_std.T
|
157 |
+
# dual_slack_j = primal_c_std_j - (optimal_dual_p @ primal_A_std)_j
|
158 |
+
dual_slacks = dual_b - dual_A @ optimal_dual_p
|
159 |
+
print(" Calculating Dual Slacks (c'_j - p* A'_j):")
|
160 |
+
for j in range(num_primal_vars):
|
161 |
+
print(f" Dual Slack for primal var x{j+1}: {dual_slacks[j]:.6f}")
|
162 |
+
|
163 |
+
|
164 |
+
# Identify conditions from Complementary Slackness (CS)
|
165 |
+
# 1. Dual Slackness: If dual_slack_j > TOLERANCE, then primal x*_j = 0
|
166 |
+
# 2. Primal Slackness: If optimal_dual_p_i > TOLERANCE, then i-th standardized primal constraint is binding
|
167 |
+
# (primal_A_std[i] @ x* = primal_b_std[i])
|
168 |
+
|
169 |
+
binding_constraints_indices = []
|
170 |
+
zero_primal_vars_indices = []
|
171 |
+
|
172 |
+
print("\n Applying Complementary Slackness Conditions:")
|
173 |
+
# Dual Slackness
|
174 |
+
print(" From Dual Slackness (if c'_j - p* A'_j > 0, then x*_j = 0):")
|
175 |
+
for j in range(num_primal_vars):
|
176 |
+
if dual_slacks[j] > TOLERANCE:
|
177 |
+
print(f" Dual Slack for x{j+1} is {dual_slacks[j]:.4f} > 0 => x{j+1}* = 0")
|
178 |
+
zero_primal_vars_indices.append(j)
|
179 |
+
else:
|
180 |
+
print(f" Dual Slack for x{j+1} is {dual_slacks[j]:.4f} approx 0 => x{j+1}* may be non-zero")
|
181 |
+
|
182 |
+
|
183 |
+
# Primal Slackness
|
184 |
+
print(" From Primal Slackness (if p*_i > 0, then primal constraint i is binding):")
|
185 |
+
for i in range(num_dual_vars):
|
186 |
+
if optimal_dual_p[i] > TOLERANCE:
|
187 |
+
print(f" p*{i+1} = {optimal_dual_p[i]:.4f} > 0 => Primal constraint {i+1} (standardized) is binding.")
|
188 |
+
binding_constraints_indices.append(i)
|
189 |
+
else:
|
190 |
+
print(f" p*{i+1} = {optimal_dual_p[i]:.4f} approx 0 => Primal constraint {i+1} (standardized) may be non-binding.")
|
191 |
+
|
192 |
+
# Construct system of equations for non-zero primal variables
|
193 |
+
# Equations come from binding primal constraints.
|
194 |
+
# Variables are x*_j where j is NOT in zero_primal_vars_indices.
|
195 |
+
|
196 |
+
active_primal_vars_indices = [j for j in range(num_primal_vars) if j not in zero_primal_vars_indices]
|
197 |
+
num_active_primal_vars = len(active_primal_vars_indices)
|
198 |
+
|
199 |
+
print(f"\n Identifying system for active primal variables ({[f'x{j+1}' for j in active_primal_vars_indices]}):")
|
200 |
+
|
201 |
+
if num_active_primal_vars == 0:
|
202 |
+
# All primal vars are zero
|
203 |
+
primal_x_star = np.zeros(num_primal_vars)
|
204 |
+
print(" All primal variables determined to be 0 by dual slackness.")
|
205 |
+
elif len(binding_constraints_indices) < num_active_primal_vars:
|
206 |
+
print(f" Warning: Number of binding constraints ({len(binding_constraints_indices)}) identified is less than the number of potentially non-zero primal variables ({num_active_primal_vars}).")
|
207 |
+
print(" Complementary slackness alone might not be sufficient, or there might be degeneracy/multiple solutions.")
|
208 |
+
print(" Attempting to solve using available binding constraints, but result might be unreliable.")
|
209 |
+
# Pad with zero rows if necessary, or indicate underspecified system. For now, proceed cautiously.
|
210 |
+
matrix_A_sys = primal_A_std[binding_constraints_indices][:, active_primal_vars_indices]
|
211 |
+
vector_b_sys = primal_b_std[binding_constraints_indices]
|
212 |
+
|
213 |
+
else:
|
214 |
+
# We have at least as many binding constraints as active variables.
|
215 |
+
# Select num_active_primal_vars binding constraints to form a square system (if possible).
|
216 |
+
# If more binding constraints exist, they should be consistent.
|
217 |
+
# We take the first 'num_active_primal_vars' binding constraints.
|
218 |
+
if len(binding_constraints_indices) > num_active_primal_vars:
|
219 |
+
print(f" More binding constraints ({len(binding_constraints_indices)}) than active variables ({num_active_primal_vars}). Using the first {num_active_primal_vars}.")
|
220 |
+
|
221 |
+
matrix_A_sys = primal_A_std[binding_constraints_indices[:num_active_primal_vars]][:, active_primal_vars_indices]
|
222 |
+
vector_b_sys = primal_b_std[binding_constraints_indices[:num_active_primal_vars]]
|
223 |
+
|
224 |
+
print(" System Ax = b to solve:")
|
225 |
+
for r in range(matrix_A_sys.shape[0]):
|
226 |
+
print(f" {' + '.join(f'{matrix_A_sys[r, c]:.2f}*x{active_primal_vars_indices[c]+1}' for c in range(num_active_primal_vars))} = {vector_b_sys[r]:.2f}")
|
227 |
+
|
228 |
+
|
229 |
+
# Solve the system if possible
|
230 |
+
if num_active_primal_vars > 0:
|
231 |
+
try:
|
232 |
+
# Use numpy.linalg.solve for square systems, lstsq for potentially non-square
|
233 |
+
if matrix_A_sys.shape[0] == matrix_A_sys.shape[1]:
|
234 |
+
solved_active_vars = np.linalg.solve(matrix_A_sys, vector_b_sys)
|
235 |
+
elif matrix_A_sys.shape[0] > matrix_A_sys.shape[1]: # Overdetermined
|
236 |
+
print(" System is overdetermined. Using least squares solution.")
|
237 |
+
solved_active_vars, residuals, rank, s = np.linalg.lstsq(matrix_A_sys, vector_b_sys, rcond=None)
|
238 |
+
# Check if residuals are close to zero for consistency
|
239 |
+
if residuals and np.sum(residuals**2) > TOLERANCE * len(vector_b_sys):
|
240 |
+
print(f" Warning: Least squares solution has significant residuals ({np.sqrt(np.sum(residuals**2)):.4f}), CS conditions might be inconsistent?")
|
241 |
+
else: # Underdetermined
|
242 |
+
# Cannot uniquely solve. This shouldn't happen if dual was optimal and non-degenerate.
|
243 |
+
# Could use lstsq which gives one possible solution (minimum norm).
|
244 |
+
print(" System is underdetermined. Using least squares (minimum norm) solution.")
|
245 |
+
solved_active_vars, residuals, rank, s = np.linalg.lstsq(matrix_A_sys, vector_b_sys, rcond=None)
|
246 |
+
|
247 |
+
|
248 |
+
# Assign solved values back to the full primal_x_star vector
|
249 |
+
primal_x_star = np.zeros(num_primal_vars)
|
250 |
+
for i, active_idx in enumerate(active_primal_vars_indices):
|
251 |
+
primal_x_star[active_idx] = solved_active_vars[i]
|
252 |
+
|
253 |
+
print("\n Solved values for active primal variables:")
|
254 |
+
for i, active_idx in enumerate(active_primal_vars_indices):
|
255 |
+
print(f" x{active_idx+1}* = {solved_active_vars[i]:.6f}")
|
256 |
+
|
257 |
+
except np.linalg.LinAlgError:
|
258 |
+
print(" Error: Could not solve the system of equations derived from binding constraints (matrix may be singular).")
|
259 |
+
# Attempt to use linprog on the original primal as a fallback/check
|
260 |
+
print(" Attempting to solve primal directly with linprog as a fallback...")
|
261 |
+
primal_fallback_status, _, primal_fallback_sol, _, primal_fallback_obj = solve_primal_directly(
|
262 |
+
original_obj_type, c, A, relations, b)
|
263 |
+
if primal_fallback_status == 0:
|
264 |
+
print(" Fallback solution found.")
|
265 |
+
return 0, "Solved primal using fallback direct method after CS failure", primal_fallback_sol, dual_solution_dict, primal_fallback_obj
|
266 |
+
else:
|
267 |
+
return 1, "Failed to solve system from CS, and fallback primal solve also failed.", None, dual_solution_dict, None
|
268 |
+
|
269 |
+
|
270 |
+
# Assemble final primal solution dictionary
|
271 |
+
primal_solution_dict = {f'x{j+1}': primal_x_star[j] for j in range(num_primal_vars)}
|
272 |
+
|
273 |
+
# --- Step 6: Verify Primal Feasibility and Objective Value ---
|
274 |
+
print("\n--- Step 6: Verify Primal Solution ---")
|
275 |
+
feasible = True
|
276 |
+
print(" Checking primal constraints:")
|
277 |
+
for i in range(num_primal_constraints):
|
278 |
+
lhs_val = primal_A[i] @ primal_x_star
|
279 |
+
rhs_val = primal_b[i]
|
280 |
+
rel = primal_relations[i]
|
281 |
+
constraint_met = False
|
282 |
+
if rel == '<=':
|
283 |
+
constraint_met = lhs_val <= rhs_val + TOLERANCE
|
284 |
+
elif rel == '>=':
|
285 |
+
constraint_met = lhs_val >= rhs_val - TOLERANCE
|
286 |
+
elif rel == '=':
|
287 |
+
constraint_met = abs(lhs_val - rhs_val) < TOLERANCE
|
288 |
+
|
289 |
+
status_str = "Satisfied" if constraint_met else "VIOLATED"
|
290 |
+
print(f" Constraint {i+1}: {lhs_val:.4f} {rel} {rhs_val:.4f} -> {status_str}")
|
291 |
+
if not constraint_met:
|
292 |
+
feasible = False
|
293 |
+
|
294 |
+
print(" Checking non-negativity (x >= 0):")
|
295 |
+
non_negative = np.all(primal_x_star >= -TOLERANCE)
|
296 |
+
print(f" All x_j >= 0: {non_negative}")
|
297 |
+
if not non_negative:
|
298 |
+
feasible = False
|
299 |
+
print(f"Violating variables: {[f'x{j+1}={primal_x_star[j]:.4f}' for j in range(len(primal_x_star)) if primal_x_star[j] < -TOLERANCE]}")
|
300 |
+
|
301 |
+
final_primal_obj = primal_c @ primal_x_star # Using original primal c
|
302 |
+
print(f"\n Calculated Primal Objective Value: {final_primal_obj:.6f}")
|
303 |
+
print(f" Expected Primal Objective Value (from dual): {expected_primal_obj:.6f}")
|
304 |
+
|
305 |
+
if abs(final_primal_obj - expected_primal_obj) > TOLERANCE * (1 + abs(expected_primal_obj)):
|
306 |
+
print(" Warning: Calculated primal objective value significantly differs from the dual objective value!")
|
307 |
+
feasible = False # Consider this a failure if strong duality doesn't hold
|
308 |
+
|
309 |
+
if feasible:
|
310 |
+
print("\n--- Primal Solution Found Successfully via Dual ---")
|
311 |
+
return 0, "Optimal solution found via dual.", primal_solution_dict, dual_solution_dict, final_primal_obj
|
312 |
+
else:
|
313 |
+
print("\n--- Failed to Find Feasible Primal Solution via Dual ---")
|
314 |
+
print(" The derived primal solution violates constraints or non-negativity, or strong duality failed.")
|
315 |
+
# You might want to return the possibly incorrect solution for inspection or None
|
316 |
+
return 1, "Derived primal solution is infeasible or inconsistent.", primal_solution_dict, dual_solution_dict, final_primal_obj
|
317 |
+
|
maths/university/operations_research/solve_primal_directly.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
from scipy.optimize import linprog # Using SciPy for robust LP solving
|
4 |
+
import warnings
|
5 |
+
|
6 |
+
def solve_primal_directly(objective_type, c, A, relations, b):
|
7 |
+
"""Helper to solve primal directly using linprog for comparison/fallback"""
|
8 |
+
primal_c = np.array(c, dtype=float)
|
9 |
+
primal_A = np.array(A, dtype=float)
|
10 |
+
primal_b = np.array(b, dtype=float)
|
11 |
+
num_vars = len(c)
|
12 |
+
|
13 |
+
A_ub, b_ub, A_eq, b_eq = [], [], [], []
|
14 |
+
|
15 |
+
sign_flip = 1.0
|
16 |
+
if objective_type.lower() == 'max':
|
17 |
+
primal_c = -primal_c
|
18 |
+
sign_flip = -1.0
|
19 |
+
|
20 |
+
for i in range(len(relations)):
|
21 |
+
if relations[i] == '<=':
|
22 |
+
A_ub.append(primal_A[i])
|
23 |
+
b_ub.append(primal_b[i])
|
24 |
+
elif relations[i] == '>=':
|
25 |
+
A_ub.append(-primal_A[i]) # Convert >= to <= for linprog A_ub
|
26 |
+
b_ub.append(-primal_b[i])
|
27 |
+
elif relations[i] == '=':
|
28 |
+
A_eq.append(primal_A[i])
|
29 |
+
b_eq.append(primal_b[i])
|
30 |
+
|
31 |
+
# Convert lists to arrays, handling empty cases
|
32 |
+
A_ub = np.array(A_ub) if A_ub else None
|
33 |
+
b_ub = np.array(b_ub) if b_ub else None
|
34 |
+
A_eq = np.array(A_eq) if A_eq else None
|
35 |
+
b_eq = np.array(b_eq) if b_eq else None
|
36 |
+
|
37 |
+
try:
|
38 |
+
result_primal = linprog(primal_c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=[(0, None)] * num_vars, method='highs') # 'highs' is default and robust
|
39 |
+
|
40 |
+
if result_primal.success:
|
41 |
+
primal_sol_vals = result_primal.x
|
42 |
+
primal_obj_val = result_primal.fun * sign_flip # Adjust obj value back if Max
|
43 |
+
primal_sol_dict = {f'x{j+1}': primal_sol_vals[j] for j in range(num_vars)}
|
44 |
+
return 0, "Solved directly", primal_sol_dict, None, primal_obj_val
|
45 |
+
else:
|
46 |
+
return result_primal.status, f"Direct solve failed: {result_primal.message}", None, None, None
|
47 |
+
except Exception as e:
|
48 |
+
return -1, f"Direct solve exception: {e}", None, None, None
|
49 |
+
|
maths/university/tests/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
# This file makes 'tests' a Python package.
|
2 |
+
# It can be empty.
|
maths/university/tests/test_differential_equations.py
ADDED
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import numpy as np
|
3 |
+
import math
|
4 |
+
from numpy.testing import assert_array_almost_equal
|
5 |
+
|
6 |
+
# Adjust import path as necessary
|
7 |
+
from maths.university.differential_equations import solve_first_order_ode, solve_second_order_ode
|
8 |
+
|
9 |
+
class TestDifferentialEquations(unittest.TestCase):
|
10 |
+
|
11 |
+
def test_solve_simple_first_order_ode_constant_rate(self):
|
12 |
+
"""Test dy/dt = k, e.g., dy/dt = 2 with y(0)=1. Solution y(t) = 2t + 1."""
|
13 |
+
k = 2.0
|
14 |
+
y_initial = 1.0
|
15 |
+
|
16 |
+
def ode_func(t, y):
|
17 |
+
# For a system, y is an array. Here, it's a single equation, so y[0]
|
18 |
+
# solve_ivp always passes y as at least a 1-element array.
|
19 |
+
return k
|
20 |
+
|
21 |
+
t_span = (0, 5)
|
22 |
+
t_eval_count = 50
|
23 |
+
y0 = [y_initial] # Initial condition must be a list or 1D array for solve_ivp
|
24 |
+
|
25 |
+
result = solve_first_order_ode(ode_func, t_span, y0, t_eval_count=t_eval_count)
|
26 |
+
|
27 |
+
self.assertTrue(result['success'], msg=f"Solver failed: {result['message']}")
|
28 |
+
self.assertIsInstance(result['t'], np.ndarray)
|
29 |
+
self.assertIsInstance(result['y'], np.ndarray)
|
30 |
+
|
31 |
+
self.assertEqual(result['t'].shape, (t_eval_count,))
|
32 |
+
# solve_ivp returns y with shape (n_equations, n_timepoints)
|
33 |
+
self.assertEqual(result['y'].shape, (1, t_eval_count))
|
34 |
+
|
35 |
+
# Check analytical solution: y(t) = k*t + y0
|
36 |
+
expected_y = k * result['t'] + y_initial
|
37 |
+
assert_array_almost_equal(result['y'][0], expected_y, decimal=5,
|
38 |
+
err_msg="Numerical solution does not match analytical solution for dy/dt=k.")
|
39 |
+
|
40 |
+
def test_solve_simple_first_order_ode_exponential_decay(self):
|
41 |
+
"""Test dy/dt = -k*y, e.g., dy/dt = -0.5*y with y(0)=10. Solution y(t) = 10*exp(-0.5t)."""
|
42 |
+
k = 0.5
|
43 |
+
y_initial = 10.0
|
44 |
+
|
45 |
+
def ode_func(t, y):
|
46 |
+
return -k * y[0] # y is [y_val]
|
47 |
+
|
48 |
+
t_span = (0, 3)
|
49 |
+
t_eval_count = 30
|
50 |
+
y0 = [y_initial]
|
51 |
+
|
52 |
+
result = solve_first_order_ode(ode_func, t_span, y0, t_eval_count=t_eval_count)
|
53 |
+
|
54 |
+
self.assertTrue(result['success'], msg=f"Solver failed for exponential decay: {result['message']}")
|
55 |
+
self.assertEqual(result['t'].shape, (t_eval_count,))
|
56 |
+
self.assertEqual(result['y'].shape, (1, t_eval_count))
|
57 |
+
|
58 |
+
# Check analytical solution: y(t) = y0 * exp(-k*t)
|
59 |
+
expected_y = y_initial * np.exp(-k * result['t'])
|
60 |
+
assert_array_almost_equal(result['y'][0], expected_y, decimal=5,
|
61 |
+
err_msg="Numerical solution does not match analytical for exponential decay.")
|
62 |
+
|
63 |
+
|
64 |
+
def test_solve_simple_second_order_ode_constant_acceleration(self):
|
65 |
+
"""Test d²y/dt² = a, e.g., d²y/dt² = 2 with y(0)=1, y'(0)=0.5.
|
66 |
+
Solution y'(t) = a*t + y'(0) => y'(t) = 2t + 0.5
|
67 |
+
Solution y(t) = 0.5*a*t² + y'(0)*t + y(0) => y(t) = t² + 0.5t + 1
|
68 |
+
"""
|
69 |
+
accel = 2.0
|
70 |
+
y_initial = 1.0
|
71 |
+
dy_dt_initial = 0.5
|
72 |
+
|
73 |
+
def ode_func_second_order(t, y_val, dy_dt_val):
|
74 |
+
return accel # d²y/dt² = constant
|
75 |
+
|
76 |
+
t_span = (0, 4)
|
77 |
+
t_eval_count = 40
|
78 |
+
|
79 |
+
result = solve_second_order_ode(
|
80 |
+
ode_func_second_order, t_span, y_initial, dy_dt_initial, t_eval_count=t_eval_count
|
81 |
+
)
|
82 |
+
|
83 |
+
self.assertTrue(result['success'], msg=f"Solver failed for 2nd order const accel: {result['message']}")
|
84 |
+
self.assertIsInstance(result['t'], np.ndarray)
|
85 |
+
self.assertIsInstance(result['y'], np.ndarray)
|
86 |
+
self.assertIsInstance(result['dy_dt'], np.ndarray)
|
87 |
+
|
88 |
+
self.assertEqual(result['t'].shape, (t_eval_count,))
|
89 |
+
self.assertEqual(result['y'].shape, (t_eval_count,)) # Output y is 1D array
|
90 |
+
self.assertEqual(result['dy_dt'].shape, (t_eval_count,)) # Output dy_dt is 1D array
|
91 |
+
|
92 |
+
# Check analytical solution for y(t)
|
93 |
+
expected_y = 0.5 * accel * result['t']**2 + dy_dt_initial * result['t'] + y_initial
|
94 |
+
assert_array_almost_equal(result['y'], expected_y, decimal=5,
|
95 |
+
err_msg="Numerical y(t) does not match analytical for const accel.")
|
96 |
+
|
97 |
+
# Check analytical solution for dy/dt(t)
|
98 |
+
expected_dy_dt = accel * result['t'] + dy_dt_initial
|
99 |
+
assert_array_almost_equal(result['dy_dt'], expected_dy_dt, decimal=5,
|
100 |
+
err_msg="Numerical dy/dt(t) does not match analytical for const accel.")
|
101 |
+
|
102 |
+
def test_solve_damped_oscillator_second_order(self):
|
103 |
+
"""Test d²y/dt² = -b*(dy/dt) - k*y (damped harmonic oscillator).
|
104 |
+
This is more complex and won't have a trivial analytical solution form for all parameters,
|
105 |
+
so we mainly check if it runs and produces output of the correct shape.
|
106 |
+
"""
|
107 |
+
damping_b = 0.5 # Damping coefficient
|
108 |
+
spring_k = 2.0 # Spring constant (normalized for m=1)
|
109 |
+
y_initial = 1.0
|
110 |
+
dy_dt_initial = 0.0
|
111 |
+
|
112 |
+
def damped_oscillator_func(t, y, dy_dt):
|
113 |
+
return -damping_b * dy_dt - spring_k * y
|
114 |
+
|
115 |
+
t_span = (0, 10)
|
116 |
+
t_eval_count = 100
|
117 |
+
|
118 |
+
result = solve_second_order_ode(
|
119 |
+
damped_oscillator_func, t_span, y_initial, dy_dt_initial, t_eval_count=t_eval_count
|
120 |
+
)
|
121 |
+
|
122 |
+
self.assertTrue(result['success'], msg=f"Solver failed for damped oscillator: {result['message']}")
|
123 |
+
self.assertEqual(result['t'].shape, (t_eval_count,))
|
124 |
+
self.assertEqual(result['y'].shape, (t_eval_count,))
|
125 |
+
self.assertEqual(result['dy_dt'].shape, (t_eval_count,))
|
126 |
+
# For a damped oscillator starting from rest at y=1, y should decrease initially (or oscillate around 0)
|
127 |
+
# This is just a sanity check, not a strict assertion of values.
|
128 |
+
if len(result['y']) > 1:
|
129 |
+
self.assertTrue(result['y'][0] > result['y'][-1] or result['y'][0] < result['y'][-1] or result['y'][0] != result['y'][1])
|
130 |
+
|
131 |
+
|
132 |
+
if __name__ == '__main__':
|
133 |
+
unittest.main(argv=['first-arg-is-ignored'], exit=False)
|
maths/university/tests/test_linear_algebra.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import unittest
|
2 |
+
import numpy as np
|
3 |
+
from numpy.testing import assert_array_almost_equal, assert_raises
|
4 |
+
|
5 |
+
# Adjust the import path based on your project structure if necessary
|
6 |
+
# Assuming 'maths' is a top-level directory in PYTHONPATH or current working directory
|
7 |
+
from maths.university.linear_algebra import (
|
8 |
+
matrix_add, matrix_subtract, matrix_multiply,
|
9 |
+
matrix_determinant, matrix_inverse,
|
10 |
+
vector_add, vector_subtract, vector_dot_product,
|
11 |
+
vector_cross_product, solve_linear_system
|
12 |
+
)
|
13 |
+
|
14 |
+
class TestLinearAlgebra(unittest.TestCase):
|
15 |
+
|
16 |
+
def test_matrix_add(self):
|
17 |
+
m1 = np.array([[1, 2], [3, 4]])
|
18 |
+
m2 = np.array([[5, 6], [7, 8]])
|
19 |
+
expected = np.array([[6, 8], [10, 12]])
|
20 |
+
assert_array_almost_equal(matrix_add(m1, m2), expected)
|
21 |
+
assert_array_almost_equal(matrix_add([[1,2],[3,4]], [[5,6],[7,8]]), expected)
|
22 |
+
|
23 |
+
def test_matrix_multiply(self):
|
24 |
+
m1 = np.array([[1, 2], [3, 4]])
|
25 |
+
m2 = np.array([[5, 6], [7, 8]]) # 2x2
|
26 |
+
expected = np.array([[19, 22], [43, 50]]) # (1*5+2*7), (1*6+2*8) etc.
|
27 |
+
assert_array_almost_equal(matrix_multiply(m1, m2), expected)
|
28 |
+
|
29 |
+
m3 = np.array([[1,2,3],[4,5,6]]) # 2x3
|
30 |
+
m4 = np.array([[7,8],[9,10],[11,12]]) # 3x2
|
31 |
+
expected2 = np.array([[58,64],[139,154]])
|
32 |
+
assert_array_almost_equal(matrix_multiply(m3,m4),expected2)
|
33 |
+
|
34 |
+
# Test incompatible shapes
|
35 |
+
with assert_raises(ValueError):
|
36 |
+
matrix_multiply(m1, m3) # 2x2 and 2x3 not compatible like this
|
37 |
+
|
38 |
+
def test_matrix_determinant(self):
|
39 |
+
m1 = np.array([[1, 2], [3, 4]]) # det = 1*4 - 2*3 = 4 - 6 = -2
|
40 |
+
self.assertAlmostEqual(matrix_determinant(m1), -2.0)
|
41 |
+
|
42 |
+
m2 = np.array([[3, 1, 0], [2, 0, 1], [0, 2, 4]])
|
43 |
+
# Det = 3(0-2) - 1(8-0) + 0 = -6 - 8 = -14
|
44 |
+
self.assertAlmostEqual(matrix_determinant(m2), -14.0)
|
45 |
+
|
46 |
+
m_singular = np.array([[1,2],[2,4]]) # det = 0
|
47 |
+
self.assertAlmostEqual(matrix_determinant(m_singular), 0.0)
|
48 |
+
|
49 |
+
def test_matrix_inverse(self):
|
50 |
+
m1 = np.array([[1, 2], [3, 7]]) # det = 7-6 = 1
|
51 |
+
# inv = 1/1 * [[7, -2], [-3, 1]]
|
52 |
+
expected_inv1 = np.array([[7, -2], [-3, 1]])
|
53 |
+
assert_array_almost_equal(matrix_inverse(m1), expected_inv1)
|
54 |
+
|
55 |
+
m_non_square = np.array([[1,2,3],[4,5,6]])
|
56 |
+
with assert_raises(ValueError): # Specific to our implementation if it checks before numpy
|
57 |
+
matrix_inverse(m_non_square)
|
58 |
+
|
59 |
+
m_singular = np.array([[1, 2], [2, 4]]) # Determinant is 0
|
60 |
+
with assert_raises(np.linalg.LinAlgError): # Numpy's error for singular matrix
|
61 |
+
matrix_inverse(m_singular)
|
62 |
+
|
63 |
+
def test_vector_dot_product(self):
|
64 |
+
v1 = np.array([1, 2, 3])
|
65 |
+
v2 = np.array([4, 5, 6])
|
66 |
+
# 1*4 + 2*5 + 3*6 = 4 + 10 + 18 = 32
|
67 |
+
self.assertAlmostEqual(vector_dot_product(v1, v2), 32.0)
|
68 |
+
self.assertAlmostEqual(vector_dot_product([1,0,-1], [1,1,1]), 0.0) # Orthogonal
|
69 |
+
|
70 |
+
def test_vector_cross_product(self):
|
71 |
+
v1 = np.array([1, 0, 0]) # i
|
72 |
+
v2 = np.array([0, 1, 0]) # j
|
73 |
+
expected_ij = np.array([0, 0, 1]) # k
|
74 |
+
assert_array_almost_equal(vector_cross_product(v1, v2), expected_ij)
|
75 |
+
|
76 |
+
v3 = np.array([1, 2, 3])
|
77 |
+
v4 = np.array([4, 5, 6])
|
78 |
+
# (2*6 - 3*5, 3*4 - 1*6, 1*5 - 2*4)
|
79 |
+
# (12 - 15, 12 - 6, 5 - 8)
|
80 |
+
# (-3, 6, -3)
|
81 |
+
expected_v3v4 = np.array([-3, 6, -3])
|
82 |
+
assert_array_almost_equal(vector_cross_product(v3, v4), expected_v3v4)
|
83 |
+
|
84 |
+
# Test non-3D vectors
|
85 |
+
v_2d_1 = [1,2]
|
86 |
+
v_2d_2 = [3,4]
|
87 |
+
with assert_raises(ValueError):
|
88 |
+
vector_cross_product(v_2d_1, v_2d_2)
|
89 |
+
with assert_raises(ValueError):
|
90 |
+
vector_cross_product(v1, v_2d_1) # One 3D, one 2D
|
91 |
+
|
92 |
+
def test_solve_linear_system(self):
|
93 |
+
# System 1:
|
94 |
+
# 2x + y = 5
|
95 |
+
# x - y = 1
|
96 |
+
# Solution: x=2, y=1
|
97 |
+
A1 = np.array([[2, 1], [1, -1]])
|
98 |
+
B1 = np.array([5, 1])
|
99 |
+
expected_X1 = np.array([2, 1])
|
100 |
+
assert_array_almost_equal(solve_linear_system(A1, B1), expected_X1, decimal=6)
|
101 |
+
|
102 |
+
# System 2:
|
103 |
+
# x + y + z = 6
|
104 |
+
# 2y + 5z = -4
|
105 |
+
# 2x + 5y - z = 27
|
106 |
+
# Solution: x=5, y=3, z=-2 (from online calculator)
|
107 |
+
A2 = np.array([[1, 1, 1], [0, 2, 5], [2, 5, -1]])
|
108 |
+
B2 = np.array([6, -4, 27])
|
109 |
+
expected_X2 = np.array([5, 3, -2])
|
110 |
+
assert_array_almost_equal(solve_linear_system(A2, B2), expected_X2, decimal=6)
|
111 |
+
|
112 |
+
# Singular system (no unique solution)
|
113 |
+
A_singular = np.array([[1, 1], [2, 2]])
|
114 |
+
B_singular = np.array([1, 3]) # Inconsistent
|
115 |
+
with assert_raises(np.linalg.LinAlgError):
|
116 |
+
solve_linear_system(A_singular, B_singular)
|
117 |
+
|
118 |
+
# Non-square coefficient matrix
|
119 |
+
A_non_square = np.array([[1,1,1],[0,2,5]])
|
120 |
+
B_non_square = np.array([6,-4])
|
121 |
+
with assert_raises(ValueError):
|
122 |
+
solve_linear_system(A_non_square, B_non_square)
|
123 |
+
|
124 |
+
# Incompatible dimensions B vector
|
125 |
+
A_valid = np.array([[1,1],[0,2]])
|
126 |
+
B_invalid_dim = np.array([1,2,3])
|
127 |
+
with assert_raises(ValueError): # Or LinAlgError depending on numpy's internal checks order
|
128 |
+
solve_linear_system(A_valid, B_invalid_dim)
|
129 |
+
|
130 |
+
|
131 |
+
if __name__ == '__main__':
|
132 |
+
unittest.main(argv=['first-arg-is-ignored'], exit=False)
|
requirements.txt
CHANGED
@@ -64,4 +64,17 @@ typing_extensions==4.14.0
|
|
64 |
tzdata==2025.2
|
65 |
urllib3==2.4.0
|
66 |
uvicorn==0.34.3
|
67 |
-
websockets==15.0.1
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
64 |
tzdata==2025.2
|
65 |
urllib3==2.4.0
|
66 |
uvicorn==0.34.3
|
67 |
+
websockets==15.0.1
|
68 |
+
scipy==1.15.3
|
69 |
+
cffi==1.17.1
|
70 |
+
clarabel==0.11.0
|
71 |
+
cvxpy==1.6.5
|
72 |
+
joblib==1.5.1
|
73 |
+
networkx==3.5
|
74 |
+
osqp==1.0.4
|
75 |
+
pycparser==2.22
|
76 |
+
scs==3.2.7.post2
|
77 |
+
setuptools==80.9.0
|
78 |
+
tabulate==0.9.0
|
79 |
+
mpmath==1.3.0
|
80 |
+
sympy==1.14.0
|