Python
Numpy - np.where()
IT_달토끼
2023. 3. 30. 19:45
코드를 보는데, 처음 보는 구문이 나와서 정리해보았다.
np.where(조건, x, y): 조건문을 만족할 때 x값을 반환하고, 아니면 y값을 반환한다.
다음은 사용 예제이다.
import numpy as np
a = np.arange(10)
b = np.where(a < 5, a, 10*a)
'''
a 출력 시 => array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
b 출력 시 => array([ 0, 1, 2, 3, 4, 50, 60, 70, 80, 90])
'''
아래처럼 np.where() 메서드를 사용해서 ReLU 함수를 구현할 수 있다.
import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(-50, 100, 10)
y = np.where(x > 0, x, 0)
plt.plot(x, y)
plt.title('ReLU')
plt.show()
아래는 평소에 내가 사용하던 방법이다.
import numpy as np
import matplotlib.pyplot as plt
x = [x for x in range(-50, 100, 10)]
y = []
for _ in range(-50, 100, 10):
y.append(max(_, 0))
plt.plot(x, y)
plt.title('ReLU')
plt.show()
np.where을 사용하는 게 더 깔끔해 보인다.
결과는 두 구문 모두 아래와 같다.
np.where은 사용방법이 간단해서 자주 사용하게 될 것 같다.
파이썬을 배운 지 꽤 됐는데 이렇게 모르는 구문을 마주칠 때마다 아직 많이 부족함을 느낀다ㅠ