diff --git a/ohmpi/hardware_system.py b/ohmpi/hardware_system.py
index f3fe7ba648a64a416e331b2c4185b4811de9c4ca..7a65b0642606fe3ac632c1f83b8b87aee32a1abe 100644
--- a/ohmpi/hardware_system.py
+++ b/ohmpi/hardware_system.py
@@ -139,12 +139,25 @@ class OhmPiHardware:
         self._start_time = None
         self._pulse = 0
 
-    def _gain_auto(self):  # TODO: improve _gain_auto
+    def _gain_auto(self, polarities=[1, -1]):  # TODO: improve _gain_auto
         self.exec_logger.event(f'OhmPiHardware\ttx_rx_gain_auto\tbegin\t{datetime.datetime.utcnow()}')
         self.tx_sync.wait()
-        self.tx.adc_gain_auto()
-        self.rx.adc_gain_auto()
-        self.rx.voltage_gain_auto()
+        current, voltage = 0, 0
+        assert isinstance(polarities, )
+        for pol in polarities:
+            self.tx.polarity = pol
+            # set gains automatically
+            injection = Thread(target=self._inject, kwargs={'injection_duration': 0.2, 'polarity': polarity})
+            readings = Thread(target=self._read_values)
+            readings.start()
+            injection.start()
+            readings.join()
+            injection.join()
+            current = max(current,np.mean(self.readings[v, 3]))
+            voltage = max(voltage,np.abs(np.mean(self.readings[v, 2] * (self.readings[v, 4]))))
+
+        self.tx.gain_auto(current)
+        self.rx.gain_auto(voltage)
         self.exec_logger.event(f'OhmPiHardware\ttx_rx_gain_auto\tend\t{datetime.datetime.utcnow()}')
 
     def _inject(self, polarity=1, injection_duration=None):  # TODO: deal with voltage or current pulse
@@ -344,15 +357,7 @@ class OhmPiHardware:
     def vab_square_wave(self, vab, cycle_duration, sampling_rate=None, cycles=3, polarity=1, duty_cycle=1.,
                         append=False):
         self.exec_logger.event(f'OhmPiHardware\tvab_square_wave\tbegin\t{datetime.datetime.utcnow()}')
-        self.tx.polarity = polarity  ### TODO: inject on both polarities for gain auto?
-        durations = [cycle_duration/2]*2*cycles
-        # set gains automatically
-        gain_auto = Thread(target=self._gain_auto)
-        injection = Thread(target=self._inject, kwargs={'injection_duration': 0.2, 'polarity': polarity})
-        gain_auto.start()
-        injection.start()
-        gain_auto.join()
-        injection.join()
+        self._gain_auto()
         assert 0. <= duty_cycle <= 1.
         if duty_cycle < 1.:
             durations = [cycle_duration/2 * duty_cycle, cycle_duration/2*(1.-duty_cycle)] * 2 * cycles