Source code for corr_solver.gmi_oracle
# -*- coding: utf-8 -*-
from typing import Optional, Tuple
import numpy as np
from ellalgo.oracles.ldlt_mgr import LDLTMgr
Cut = Tuple[np.ndarray, float]
# The `GMIOracle` class is an oracle for a General Matrix Inequality constraint, which evaluates the
# function and its negative gradient.
[docs]
class GMIOracle:
    """Separation oracle for a General Matrix Inequality constraint.

    The constraint has the form ``H(x) >= 0``. The matrix-valued function
    ``H`` is supplied by the caller and is used through two methods:

    * ``H.eval(row, col, x)`` -- value of the (row, col) element at ``x``;
    * ``H.neg_grad_sym_quad(ldlt_mgr, x)`` -- called with this oracle's
      ``LDLTMgr`` when the factorization does not succeed; its result
      becomes the first component of the returned cut.
    """

    def __init__(self, H, m):
        """Keep the constraint function and create the factorization manager.

        :param H: object describing the matrix function (see the class
            docstring for the methods this oracle calls on it)
        :param m: dimension of the matrix (an integer); forwarded to
            ``LDLTMgr``
        """
        self.H = H
        self.m = m
        self.ldlt_mgr = LDLTMgr(m)

    # NOTE: an ``update(t)`` method that forwarded to ``self.H.update(t)``
    # was sketched here in the past; it is currently disabled.

    def assess_feas(self, x: np.ndarray) -> Optional[Cut]:
        """Check feasibility of ``x`` and produce a cut when it is infeasible.

        The matrix elements are handed to ``self.ldlt_mgr.factor`` lazily
        through a callback that evaluates ``self.H.eval(row, col, x)``.

        :param x: the point to assess
        :type x: np.ndarray
        :return: ``None`` when ``factor`` reports success; otherwise the
            pair ``(g, ep)`` where ``ep`` comes from
            ``self.ldlt_mgr.witness()`` and ``g`` from
            ``self.H.neg_grad_sym_quad(self.ldlt_mgr, x)``
        """
        factored = self.ldlt_mgr.factor(
            lambda row, col: self.H.eval(row, col, x)
        )
        if not factored:
            # witness() is called before the gradient, as the gradient
            # routine receives the manager in its post-witness state.
            witness_value = self.ldlt_mgr.witness()
            neg_grad = self.H.neg_grad_sym_quad(self.ldlt_mgr, x)
            return neg_grad, witness_value
        return None