Skip to content

Commit e619b6f

Browse files
feat: Implement isCloseTo function (#1012)
1 parent c2c2eb3 commit e619b6f

File tree

5 files changed

+188
-13
lines changed

5 files changed

+188
-13
lines changed

packages/typegpu/src/data/vectorOps.ts

+21
Original file line numberDiff line numberDiff line change
@@ -903,4 +903,25 @@ export const VectorOps = {
903903
vec4i: unary4i((value) => value - Math.floor(value)),
904904
vec4u: unary4u((value) => value - Math.floor(value)),
905905
} as Record<VecKind, <T extends vBase>(v: T) => T>,
906+
907+
isCloseToZero: {
908+
vec2f: (v: wgsl.v2f, n: number) => Math.abs(v.x) <= n && Math.abs(v.y) <= n,
909+
vec2h: (v: wgsl.v2h, n: number) => Math.abs(v.x) <= n && Math.abs(v.y) <= n,
910+
911+
vec3f: (v: wgsl.v3f, n: number) =>
912+
Math.abs(v.x) <= n && Math.abs(v.y) <= n && Math.abs(v.z) <= n,
913+
vec3h: (v: wgsl.v3h, n: number) =>
914+
Math.abs(v.x) <= n && Math.abs(v.y) <= n && Math.abs(v.z) <= n,
915+
916+
vec4f: (v: wgsl.v4f, n: number) =>
917+
Math.abs(v.x) <= n &&
918+
Math.abs(v.y) <= n &&
919+
Math.abs(v.z) <= n &&
920+
Math.abs(v.w) <= n,
921+
vec4h: (v: wgsl.v4h, n: number) =>
922+
Math.abs(v.x) <= n &&
923+
Math.abs(v.y) <= n &&
924+
Math.abs(v.z) <= n &&
925+
Math.abs(v.w) <= n,
926+
} as Record<VecKind, <T extends vBase>(v: T, n: number) => boolean>,
906927
};

packages/typegpu/src/std/index.ts

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ export {
2323
mix,
2424
pow,
2525
reflect,
26+
isCloseTo,
2627
} from './numeric.js';
2728

2829
export {

packages/typegpu/src/std/numeric.ts

+73-13
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
import { f32 } from '../data/numeric';
1+
import { bool, f32 } from '../data/numeric';
22
import { VectorOps } from '../data/vectorOps';
33
import type {
44
AnyMatInstance,
@@ -15,6 +15,19 @@ import type {
1515
vBaseForMat,
1616
} from '../data/wgslTypes';
1717
import { createDualImpl } from '../shared/generators';
18+
import type { Resource } from '../types';
19+
20+
function isNumeric(element: Resource) {
21+
const type = element.dataType.type;
22+
return (
23+
type === 'abstractInt' ||
24+
type === 'abstractFloat' ||
25+
type === 'f32' ||
26+
type === 'f16' ||
27+
type === 'i32' ||
28+
type === 'u32'
29+
);
30+
}
1831

1932
type vBase = { kind: VecKind };
2033

@@ -82,18 +95,17 @@ export const mul: MulOverload = createDualImpl(
8295
},
8396
// GPU implementation
8497
(s, v) => {
85-
const returnType =
86-
typeof s === 'number'
87-
? // Scalar * Vector/Matrix
88-
(v.dataType as AnyWgslData)
89-
: !s.dataType.type.startsWith('mat')
90-
? // Vector * Matrix
91-
(s.dataType as AnyWgslData)
92-
: !v.dataType.type.startsWith('mat')
93-
? // Matrix * Vector
94-
(v.dataType as AnyWgslData)
95-
: // Vector * Vector or Matrix * Matrix
96-
(s.dataType as AnyWgslData);
98+
const returnType = isNumeric(s)
99+
? // Scalar * Vector/Matrix
100+
(v.dataType as AnyWgslData)
101+
: !s.dataType.type.startsWith('mat')
102+
? // Vector * Matrix
103+
(s.dataType as AnyWgslData)
104+
: !v.dataType.type.startsWith('mat')
105+
? // Matrix * Vector
106+
(v.dataType as AnyWgslData)
107+
: // Vector * Vector or Matrix * Matrix
108+
(s.dataType as AnyWgslData);
97109
return { value: `(${s.value} * ${v.value})`, dataType: returnType };
98110
},
99111
);
@@ -394,3 +406,51 @@ export const reflect = createDualImpl(
394406
};
395407
},
396408
);
409+
410+
/**
411+
* Checks whether the given elements differ by at most 0.01.
412+
* Component-wise if arguments are vectors.
413+
* @example
414+
* isCloseTo(0, 0.1) // returns false
415+
* isCloseTo(vec3f(0, 0, 0), vec3f(0.002, -0.009, 0)) // returns true
416+
*
417+
* @param {number} precision argument that specifies the maximum allowed difference, 0.01 by default.
418+
*/
419+
420+
export const isCloseTo = createDualImpl(
421+
// CPU implementation
422+
<T extends v2f | v3f | v4f | v2h | v3h | v4h | number>(
423+
e1: T,
424+
e2: T,
425+
precision = 0.01,
426+
) => {
427+
if (typeof e1 === 'number' && typeof e2 === 'number') {
428+
return Math.abs(e1 - e2) < precision;
429+
}
430+
if (typeof e1 !== 'number' && typeof e2 !== 'number') {
431+
return VectorOps.isCloseToZero[e1.kind](sub(e1, e2), precision);
432+
}
433+
return false;
434+
},
435+
// GPU implementation
436+
(e1, e2, precision = { value: 0.01, dataType: f32 }) => {
437+
if (isNumeric(e1) && isNumeric(e2)) {
438+
return {
439+
value: `abs(f32(${e1.value})-f32(${e2.value})) <= ${precision.value}`,
440+
dataType: bool,
441+
};
442+
}
443+
if (!isNumeric(e1) && !isNumeric(e2)) {
444+
return {
445+
// https://www.w3.org/TR/WGSL/#vector-multi-component:~:text=Binary%20arithmetic%20expressions%20with%20mixed%20scalar%20and%20vector%20operands
446+
// (a-a)+prec creates a vector of a.length elements, all equal to prec
447+
value: `all(abs(${e1.value}-${e2.value}) <= (${e1.value} - ${e1.value})+${precision.value})`,
448+
dataType: bool,
449+
};
450+
}
451+
return {
452+
value: 'false',
453+
dataType: bool,
454+
};
455+
},
456+
);
+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
import { describe, expect, it } from 'vitest';
2+
import { vec4f } from '../../src/data';
3+
import { atan2, isCloseTo } from '../../src/std';
4+
5+
describe('atan2', () => {
6+
it('computes atan2 of two values', () => {
7+
expect(atan2(0, 1)).toBeCloseTo(0);
8+
expect(atan2(1, 0)).toBeCloseTo(Math.PI / 2);
9+
expect(atan2(0, -1)).toBeCloseTo(Math.PI);
10+
expect(atan2(-1, 0)).toBeCloseTo(-Math.PI / 2);
11+
});
12+
13+
it('computes atan2 for two vectors', () => {
14+
expect(
15+
isCloseTo(
16+
atan2(vec4f(0, 1, 0, -1), vec4f(1, 0, -1, 0)),
17+
vec4f(0, Math.PI / 2, Math.PI, -Math.PI / 2),
18+
),
19+
).toBeTruthy();
20+
});
21+
});
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
import { describe, expect, it } from 'vitest';
2+
import { vec2f, vec2h, vec3f, vec3h, vec4f, vec4h } from '../../src/data';
3+
import { isCloseTo } from '../../src/std/numeric';
4+
5+
describe('isCloseTo', () => {
6+
it('returns true for close f32 containers', () => {
7+
expect(isCloseTo(vec2f(0, 0), vec2f(0.0012, -0.009))).toBeTruthy();
8+
expect(
9+
isCloseTo(vec3f(1.05, -2.18, 1.22), vec3f(1.05, -2.17777, 1.229)),
10+
).toBeTruthy();
11+
expect(
12+
isCloseTo(
13+
vec4f(3.8, -4.87, -2.42, -1.97),
14+
vec4f(3.794, -4.861, -2.412, -1.971),
15+
),
16+
).toBeTruthy();
17+
});
18+
19+
it('returns true for close f16 containers', () => {
20+
expect(isCloseTo(vec2h(0, 0), vec2h(0.0012, -0.009))).toBeTruthy();
21+
expect(
22+
isCloseTo(vec3h(1.05, -2.18, 1.22), vec3h(1.05, -2.17777, 1.229)),
23+
).toBeTruthy();
24+
expect(
25+
isCloseTo(
26+
vec4h(3.8, -4.87, -2.42, -1.97),
27+
vec4h(3.794, -4.861, -2.412, -1.971),
28+
),
29+
).toBeTruthy();
30+
});
31+
32+
it('returns true for close numbers', () => {
33+
expect(isCloseTo(0, 0.009)).toBeTruthy();
34+
expect(isCloseTo(0, 0.0009)).toBeTruthy();
35+
});
36+
37+
it('returns false for distant f32 containers', () => {
38+
expect(isCloseTo(vec2f(0, 0), vec2f(0, 1))).toBeFalsy();
39+
expect(isCloseTo(vec3f(100, 100, 100), vec3f(101, 100, 100))).toBeFalsy();
40+
expect(
41+
isCloseTo(vec4f(1, 2, 3, 4), vec4f(1.02, 2.02, 3.02, 4.02)),
42+
).toBeFalsy();
43+
});
44+
45+
it('returns false for distant f16 containers', () => {
46+
expect(isCloseTo(vec2h(0, 0), vec2h(0, 1))).toBeFalsy();
47+
expect(isCloseTo(vec3h(100, 100, 100), vec3h(101, 100, 100))).toBeFalsy();
48+
expect(
49+
isCloseTo(vec4h(1, 2, 3, 4), vec4h(1.02, 2.02, 3.02, 4.02)),
50+
).toBeFalsy();
51+
});
52+
53+
it('returns false for distant numbers', () => {
54+
expect(isCloseTo(0, 0.9)).toBeFalsy();
55+
expect(isCloseTo(0, 0.09)).toBeFalsy();
56+
});
57+
58+
it('applies precision correctly', () => {
59+
// default precision of 0.01
60+
expect(isCloseTo(vec2h(0, 0), vec2h(0, 0.009))).toBeTruthy();
61+
expect(isCloseTo(vec2h(0, 0), vec2h(0, 0.011))).toBeFalsy();
62+
63+
expect(isCloseTo(vec2h(0, 0), vec2h(0, 0.09), 0.1)).toBeTruthy();
64+
expect(isCloseTo(vec2h(0, 0), vec2h(0, 0.11), 0.1)).toBeFalsy();
65+
66+
expect(isCloseTo(vec2h(0, 0), vec2h(0, 0.0009), 0.001)).toBeTruthy();
67+
expect(isCloseTo(vec2h(0, 0), vec2h(0, 0.0011), 0.001)).toBeFalsy();
68+
69+
expect(isCloseTo(vec2h(0, 0), vec2h(0, 9), 10)).toBeTruthy();
70+
expect(isCloseTo(vec2h(0, 0), vec2h(0, 11), 10)).toBeFalsy();
71+
});
72+
});

0 commit comments

Comments
 (0)