View Javadoc
1   /*
2    * Licensed to the Hipparchus project under one or more
3    * contributor license agreements.  See the NOTICE file distributed with
4    * this work for additional information regarding copyright ownership.
5    * The ASF licenses this file to You under the Apache License, Version 2.0
6    * (the "License"); you may not use this file except in compliance with
7    * the License.  You may obtain a copy of the License at
8    *
9    *      https://www.apache.org/licenses/LICENSE-2.0
10   *
11   * Unless required by applicable law or agreed to in writing, software
12   * distributed under the License is distributed on an "AS IS" BASIS,
13   * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14   * See the License for the specific language governing permissions and
15   * limitations under the License.
16   */
17  package org.hipparchus.ode.events;
18  
19  import org.hipparchus.CalculusFieldElement;
20  import org.hipparchus.Field;
21  import org.hipparchus.analysis.solvers.BracketedRealFieldUnivariateSolver;
22  import org.hipparchus.analysis.solvers.FieldBracketingNthOrderBrentSolver;
23  import org.hipparchus.complex.Complex;
24  import org.hipparchus.complex.ComplexField;
25  import org.hipparchus.ode.FieldODEStateAndDerivative;
26  import org.hipparchus.ode.nonstiff.interpolators.ClassicalRungeKuttaFieldStateInterpolator;
27  import org.hipparchus.ode.sampling.FieldODEStateInterpolator;
28  import org.hipparchus.util.Binary64;
29  import org.hipparchus.util.Binary64Field;
30  import org.hipparchus.util.MathArrays;
31  import org.junit.jupiter.api.Assertions;
32  import org.junit.jupiter.api.Test;
33  import org.junit.jupiter.params.ParameterizedTest;
34  import org.junit.jupiter.params.provider.ValueSource;
35  import org.mockito.Mockito;
36  
37  class FieldDetectorBasedEventStateTest {
38  
39      // Unit test reproducing https://gitlab.orekit.org/orekit/orekit/-/issues/1808
40      @Test
41      void testEpochComparisonAtLeastSignificantBit() throws NoSuchFieldException, IllegalAccessException {
42          final Binary64Field field = Binary64Field.getInstance();
43          final Binary64 zero = field.getZero();
44          // Epoch of event
45          final Binary64 eventTime = zero.add(17016.237999999998);
46  
47          // Get the interpolated state at event time
48          // It will return globalCurrent at 17016.238 sec since the difference between current and previous state times is smaller than the least bit
49          final Binary64[] array = MathArrays.buildArray(field, 2);
50          final FieldODEStateAndDerivative<Binary64> globalCurrent = new FieldODEStateAndDerivative<>(zero.add(17016.238), array, array);
51          final FieldODEStateAndDerivative<Binary64> globalPrevious = new FieldODEStateAndDerivative<>(zero.add(17016.237999999998), array, array);
52          final ClassicalRungeKuttaFieldStateInterpolator<Binary64> interpolator = new ClassicalRungeKuttaFieldStateInterpolator<>(field, true, MathArrays.buildArray(field, 2, 2),
53                                                                                                                                   globalPrevious, globalCurrent,
54                                                                                                                                   globalPrevious, globalCurrent, null);
55          final FieldODEStateAndDerivative<Binary64> interpolatedState = interpolator.getInterpolatedState(eventTime);
56          Assertions.assertEquals(interpolatedState.getTime().getReal(), globalCurrent.getTime().getReal());
57          Assertions.assertNotEquals(interpolatedState.getTime().getReal(), globalPrevious.getTime().getReal());
58  
59          // Configure the event state (failing before the fix)
60          // Since detecting the event causing the numerical issue is tricky; we access the private field to simplify the workflow and directly set the necessary values causing the issue
61          final FieldDetectorBasedEventState<Binary64> es = new FieldDetectorBasedEventState<>(new TestFieldDetector<>(field, true));
62          final java.lang.reflect.Field pendingEvent = FieldDetectorBasedEventState.class.getDeclaredField("pendingEvent");
63          pendingEvent.setAccessible(true);
64          pendingEvent.set(es, true);
65          final java.lang.reflect.Field pendingEventTime = FieldDetectorBasedEventState.class.getDeclaredField("pendingEventTime");
66          pendingEventTime.setAccessible(true);
67          pendingEventTime.set(es, eventTime);
68          final java.lang.reflect.Field afterG = FieldDetectorBasedEventState.class.getDeclaredField("afterG");
69          afterG.setAccessible(true);
70          afterG.set(es, zero); // Dummy value (this value is not interesting in that case)
71  
72          // Action & verify
73          Assertions.assertNotNull(es.doEvent(interpolatedState));
74      }
75  
76      @Test
77      void testDoEventThrowsIfTimeMismatch() throws NoSuchFieldException, IllegalAccessException {
78          final Binary64Field field = Binary64Field.getInstance();
79          final Binary64 zero = field.getZero();
80          // Initialization
81          final FieldODEEventDetector<Binary64> detector = new DummyDetector<>(field);
82          final FieldDetectorBasedEventState<Binary64> eventState = new FieldDetectorBasedEventState<>(detector);
83          java.lang.reflect.Field pendingEvent = FieldDetectorBasedEventState.class.getDeclaredField("pendingEvent");
84          pendingEvent.setAccessible(true);
85          pendingEvent.set(eventState, true);
86          final java.lang.reflect.Field pendingEventTime = FieldDetectorBasedEventState.class.getDeclaredField("pendingEventTime");
87          pendingEventTime.setAccessible(true);
88          pendingEventTime.set(eventState, field.getOne());
89          final Binary64[] array = MathArrays.buildArray(field, 1);
90          final FieldODEStateAndDerivative<Binary64> state = new FieldODEStateAndDerivative<>(zero.add(1.0001), array, array);
91          // Action & verify
92          Assertions.assertThrows(org.hipparchus.exception.MathRuntimeException.class, () -> {
93              eventState.doEvent(state);
94          });
95      }
96  
97      @ParameterizedTest
98      @ValueSource(booleans = {true, false})
99      void testNextCheck(final boolean isForward) {
100         // GIVEN
101         final TestFieldDetector<Complex> detector = new TestFieldDetector<>(ComplexField.getInstance(), isForward);
102         final FieldDetectorBasedEventState<Complex> eventState = new FieldDetectorBasedEventState<>(detector);
103         final FieldODEStateInterpolator<Complex> mockedInterpolator = Mockito.mock(FieldODEStateInterpolator.class);
104         final FieldODEStateAndDerivative<Complex> stateAndDerivative1 = getStateAndDerivative(1);
105         final FieldODEStateAndDerivative<Complex> stateAndDerivative2 = getStateAndDerivative(-1);
106         if (isForward) {
107             Mockito.when(mockedInterpolator.getCurrentState()).thenReturn(stateAndDerivative1);
108             Mockito.when(mockedInterpolator.getPreviousState()).thenReturn(stateAndDerivative2);
109         } else {
110             Mockito.when(mockedInterpolator.getCurrentState()).thenReturn(stateAndDerivative2);
111             Mockito.when(mockedInterpolator.getPreviousState()).thenReturn(stateAndDerivative1);
112         }
113         Mockito.when(mockedInterpolator.isForward()).thenReturn(isForward);
114         Mockito.when(mockedInterpolator.getInterpolatedState(new Complex(0.))).thenReturn(getStateAndDerivative(0.));
115         eventState.init(mockedInterpolator.getPreviousState(), mockedInterpolator.getPreviousState().getTime());
116         eventState.reinitializeBegin(mockedInterpolator);
117         // WHEN & THEN
118         final AssertionError error = Assertions.assertThrows(AssertionError.class, () ->
119                 eventState.evaluateStep(mockedInterpolator));
120         Assertions.assertEquals(isForward ? "forward" : "backward", error.getMessage());
121     }
122 
123     private static FieldODEStateAndDerivative<Complex> getStateAndDerivative(final double time) {
124         final Complex[] state = MathArrays.buildArray(ComplexField.getInstance(), 1);
125         state[0] = new Complex(time);
126         final Complex[] derivative = MathArrays.buildArray(ComplexField.getInstance(), 1);
127         derivative[0] = Complex.ONE;
128         return new FieldODEStateAndDerivative<>(state[0], state, derivative);
129     }
130 
131     private static class TestFieldDetector<T extends CalculusFieldElement<T>> implements FieldODEEventDetector<T> {
132 
133         private final Field<T> field;
134         private final boolean failOnForward;
135 
136         TestFieldDetector(final Field<T> field, final boolean failOnForward) {
137             this.field = field;
138             this.failOnForward = failOnForward;
139         }
140 
141         @Override
142         public FieldAdaptableInterval<T> getMaxCheckInterval() {
143             return (state, isForward) -> {
144                 if (isForward && failOnForward) {
145                     throw new AssertionError("forward");
146                 } else if (!isForward && !failOnForward) {
147                     throw new AssertionError("backward");
148                 }
149                 return 1.;
150             };
151         }
152 
153         @Override
154         public int getMaxIterationCount() {
155             return 10;
156         }
157 
158         @Override
159         public BracketedRealFieldUnivariateSolver<T> getSolver() {
160             return new FieldBracketingNthOrderBrentSolver<>(field.getOne(), field.getOne(), field.getOne(), 2);
161         }
162 
163         @Override
164         public FieldODEEventHandler<T> getHandler() {
165             return (s, e, d) -> Action.CONTINUE;
166         }
167 
168         @Override
169         public T g(FieldODEStateAndDerivative<T> state) {
170             return state.getTime();
171         }
172     }
173 
174     private static class DummyDetector<T extends CalculusFieldElement<T>> implements FieldODEEventDetector<T> {
175 
176         private final Field<T> field;
177 
178         public DummyDetector(final Field<T> field) {
179             this.field = field;
180         }
181 
182         @Override
183         public FieldAdaptableInterval<T> getMaxCheckInterval() {
184             return (state, isForward) -> 1.0;
185         }
186 
187         @Override
188         public int getMaxIterationCount() {
189             return 10;
190         }
191 
192         @Override
193         public FieldBracketingNthOrderBrentSolver<T> getSolver() {
194             return new FieldBracketingNthOrderBrentSolver<>(field.getZero(), field.getZero(), field.getZero(), 2);
195         }
196 
197         @Override
198         public FieldODEEventHandler<T> getHandler() {
199             return (state, detector, increasing) -> Action.CONTINUE;
200         }
201 
202         @Override
203         public T g(FieldODEStateAndDerivative<T> state) {
204             return state.getTime().getField().getZero();
205         }
206     }
207 
208 }