fix(USD): preserve full XCAF hierarchy with local transforms

Rewrite _traverse_xcaf → _author_xcaf_to_usd that recursively authors
USD prims mirroring the XCAF assembly tree:
- Assembly nodes become UsdGeom.Xform prims with local transforms from
  each component label's Location (not composed with parents)
- Leaf shapes get definition-space vertices (face_loc only, no instance
  placement) — the USD scene graph composes transforms hierarchically
- Coordinate swap (X,-Z,Y) now authored once as a root Xform on
  /Root/Assembly instead of per-vertex transformation
- Sharp/seam edges extracted per-part from definition shape (not global)

This fixes misplaced geometry for sub-assembly parts (e.g. KOMP-EIN
roller cages with -45° Z rotation) that were previously lost by the
flat traversal.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
2026-03-12 23:45:02 +01:00
parent 078420c5f1
commit 3dcfa7c0bd
+259 -200
View File
@@ -1,12 +1,13 @@
"""STEP → USD exporter for Schaeffler Automat.
Reads a STEP file via OCP/XCAF (preserving part names + embedded colors),
tessellates with BRepMesh, builds a USD stage with one UsdGeomMesh per leaf
part, and writes a .usd file.
tessellates with BRepMesh, builds a USD stage mirroring the full XCAF
assembly hierarchy (intermediate Xform prims with local transforms, leaf
Mesh prims with definition-space geometry), and writes a .usd file.
Coordinate system: OCC is mm Z-up. USD stage is authored in mm Y-up
(matching glTF / Blender convention). metersPerUnit=0.001 is set so Blender
handles the mm→m conversion on import — no explicit scaling applied here.
Coordinate system: OCC is mm Z-up. USD stage is Z-up with a coordinate
swap Xform on /Root/Assembly: (X_occ, Y_occ, Z_occ) → (X, -Z, Y).
metersPerUnit=0.001 is set so Blender handles mm→m on import.
Usage:
python3 export_step_to_usd.py \\
@@ -320,63 +321,253 @@ def _extract_seam_edge_pairs(shape) -> list:
return seam_pairs
# ── XCAF traversal ────────────────────────────────────────────────────────────
# ── XCAF traversal + hierarchical USD authoring ──────────────────────────────
def _traverse_xcaf(shape_tool, color_tool, label, path_prefix, existing_keys, depth=0):
"""Yield one dict per leaf shape in the XCAF hierarchy.
def _get_label_name(label) -> str:
"""Extract the source name string from a TDF_Label."""
from OCP.TDataStd import TDataStd_Name
name_attr = TDataStd_Name()
if label.FindAttribute(TDataStd_Name.GetID_s(), name_attr):
return name_attr.Get().ToExtString()
return ""
Transform composition: `GetShape_s(reference_label)` returns the shape with
the reference's own location already composed in. For standard Schaeffler flat
assemblies (12 levels deep) this is correct. Deeply nested sub-assembly
transforms (3+ levels) accumulate naturally because each recursive call
receives a component label from the *referred* definition, so each level's
location is composed by the next GetShape_s call.
def _occ_trsf_to_usd_matrix(trsf):
"""Convert an OCC gp_Trsf (column-vector) to a USD Gf.Matrix4d (row-vector).
OCC uses p' = R·p + t (column-vector convention).
USD uses p' = p·M (row-vector convention).
So M = [R^T | 0; t^T | 1].
"""
from pxr import Gf
return Gf.Matrix4d(
trsf.Value(1, 1), trsf.Value(2, 1), trsf.Value(3, 1), 0,
trsf.Value(1, 2), trsf.Value(2, 2), trsf.Value(3, 2), 0,
trsf.Value(1, 3), trsf.Value(2, 3), trsf.Value(3, 3), 0,
trsf.Value(1, 4), trsf.Value(2, 4), trsf.Value(3, 4), 1,
)
def _author_xcaf_to_usd(
stage, shape_tool, color_tool, label,
usd_parent_path, xcaf_path_prefix,
existing_keys, mat_map_lower, color_map, args,
manifest_parts, counters, prim_names_at_level,
depth=0,
):
"""Recursively author USD prims mirroring the XCAF hierarchy.
Assembly nodes → UsdGeom.Xform with local transform from component Location.
Leaf shapes → Xform + Mesh with definition-space geometry.
Sharp/seam edges are extracted per-part from the definition shape.
The local transform for each node comes from GetShape_s(label).Location(),
which returns ONLY this label's own placement (not composed with parents).
USD scene graph composition handles the full parent-to-leaf chain.
"""
from OCP.TDF import TDF_LabelSequence, TDF_Label
from OCP.TDataStd import TDataStd_Name
from OCP.XCAFDoc import XCAFDoc_ShapeTool
from OCP.TopLoc import TopLoc_Location
from pxr import UsdGeom, UsdShade, Sdf, Vt, Gf
name_attr = TDataStd_Name()
source_name = ""
if label.FindAttribute(TDataStd_Name.GetID_s(), name_attr):
source_name = name_attr.Get().ToExtString()
source_name = _get_label_name(label)
xcaf_path = (f"{xcaf_path_prefix}/{source_name}" if source_name
else f"{xcaf_path_prefix}/unnamed_{depth}")
xcaf_path = (f"{path_prefix}/{source_name}" if source_name
else f"{path_prefix}/unnamed_{depth}")
# Get local transform from this label's shape Location.
# GetShape_s(label) returns the shape with ONLY this label's own Location
# (not composed with parent locations).
label_shape = shape_tool.GetShape_s(label)
if label_shape.IsNull():
return
local_loc = label_shape.Location()
has_local_trsf = not local_loc.IsIdentity()
# Follow references to get the definition label (for sub-assembly detection)
# Follow reference to definition label
actual_label = label
if XCAFDoc_ShapeTool.IsReference_s(label):
ref_label = TDF_Label()
if XCAFDoc_ShapeTool.GetReferredShape_s(label, ref_label):
actual_label = ref_label
# Check for sub-components on the definition
components = TDF_LabelSequence()
XCAFDoc_ShapeTool.GetComponents_s(actual_label, components)
if components.Length() == 0:
shape = shape_tool.GetShape_s(label)
if shape.IsNull():
shape = shape_tool.GetShape_s(actual_label)
if shape.IsNull():
if components.Length() > 0:
# ── ASSEMBLY NODE ──────────────────────────────────────────────
raw_name = _prim_name(source_name or f"asm_{depth}")
unique_name = raw_name
n = 2
while unique_name in prim_names_at_level:
unique_name = f"{raw_name}_{n}"
n += 1
prim_names_at_level.add(unique_name)
xform_path = f"{usd_parent_path}/{unique_name}"
xform = UsdGeom.Xform.Define(stage, xform_path)
if has_local_trsf:
xform.AddTransformOp().Set(
_occ_trsf_to_usd_matrix(local_loc.Transformation()))
prim = xform.GetPrim()
prim.SetCustomDataByKey("schaeffler:sourceName", source_name)
prim.SetCustomDataByKey("schaeffler:sourceAssemblyPath", xcaf_path)
print(f" {' ' * depth}[asm] {source_name}{xform_path}"
f"{' (transform)' if has_local_trsf else ''}")
child_names: set = set()
for i in range(1, components.Length() + 1):
_author_xcaf_to_usd(
stage, shape_tool, color_tool, components.Value(i),
xform_path, xcaf_path,
existing_keys, mat_map_lower, color_map, args,
manifest_parts, counters, child_names, depth + 1,
)
else:
# ── LEAF SHAPE ─────────────────────────────────────────────────
# Get definition shape without instance location
def_shape = shape_tool.GetShape_s(actual_label)
if def_shape.IsNull():
return
# Strip any residual location so _extract_mesh yields definition-space
# vertices (face_loc only, no instance placement).
bare_def = def_shape.Located(TopLoc_Location())
part_key = _generate_part_key(xcaf_path, source_name, existing_keys)
color = _get_shape_color(color_tool, shape)
hex_color = _get_shape_color(color_tool, label_shape)
if not hex_color:
hex_color = _get_shape_color(color_tool, def_shape)
yield {
'shape': shape,
'source_name': source_name,
'xcaf_path': xcaf_path,
'part_key': part_key,
'color': color,
}
else:
for i in range(1, components.Length() + 1):
yield from _traverse_xcaf(
shape_tool, color_tool, components.Value(i),
xcaf_path, existing_keys, depth + 1,
# color_map override (substring match)
if source_name:
for map_name, map_hex in color_map.items():
if (map_name.lower() in source_name.lower()
or source_name.lower() in map_name.lower()):
hex_color = map_hex
break
if not hex_color:
hex_color = PALETTE_HEX[counters['n_parts'] % len(PALETTE_HEX)]
# Extract mesh from definition shape (face_loc only, no instance placement)
vertices, triangles = _extract_mesh(bare_def)
if not vertices or not triangles:
counters['n_empty'] += 1
return
# Ensure unique prim name at this level
raw_name = _prim_name(part_key)
unique_name = raw_name
n = 2
while unique_name in prim_names_at_level:
unique_name = f"{raw_name}_{n}"
n += 1
prim_names_at_level.add(unique_name)
part_path = f"{usd_parent_path}/{unique_name}"
# Name the Mesh prim after part_key so Blender imports it with the
# part name directly (Blender collapses single-child Xform+Mesh).
mesh_path = f"{part_path}/{part_key}"
# ── Xform prim with local transform ────────────────────────
xform = UsdGeom.Xform.Define(stage, part_path)
if has_local_trsf:
xform.AddTransformOp().Set(
_occ_trsf_to_usd_matrix(local_loc.Transformation()))
prim = xform.GetPrim()
prim.SetCustomDataByKey("schaeffler:partKey", part_key)
prim.SetCustomDataByKey("schaeffler:sourceName", source_name)
prim.SetCustomDataByKey("schaeffler:sourceAssemblyPath", xcaf_path)
prim.SetCustomDataByKey("schaeffler:sourceColor", hex_color)
prim.SetCustomDataByKey("schaeffler:tessellation:linearDeflectionMm",
args.linear_deflection)
prim.SetCustomDataByKey("schaeffler:tessellation:angularDeflectionRad",
args.angular_deflection)
if args.cad_file_id:
prim.SetCustomDataByKey("schaeffler:cadFileId", args.cad_file_id)
# ── UsdGeomMesh ────────────────────────────────────────────
mesh = UsdGeom.Mesh.Define(stage, mesh_path)
mesh.CreateSubdivisionSchemeAttr(UsdGeom.Tokens.none)
# Vertices in OCC definition space (mm, Z-up).
# The /Root/Assembly Xform carries the OCC→USD coordinate swap
# so no per-vertex (X,-Z,Y) transformation is needed here.
mesh.CreatePointsAttr(Vt.Vec3fArray([
Gf.Vec3f(x, y, z) for (x, y, z) in vertices
]))
mesh.CreateFaceVertexCountsAttr(Vt.IntArray([3] * len(triangles)))
mesh.CreateFaceVertexIndicesAttr(
Vt.IntArray([idx for tri in triangles for idx in tri])
)
r, g, b = _hex_to_rgb01(hex_color)
mesh.CreateDisplayColorAttr(Vt.Vec3fArray([Gf.Vec3f(r, g, b)]))
# ── Material metadata on mesh prim (customData) ───────────
mesh_prim = mesh.GetPrim()
mesh_prim.SetCustomDataByKey("schaeffler:partKey", part_key)
mesh_prim.SetCustomDataByKey("schaeffler:sourceName", source_name)
canonical_mat = _lookup_material(source_name, part_key, mat_map_lower)
if canonical_mat:
mesh_prim.SetCustomDataByKey(
"schaeffler:canonicalMaterialName", canonical_mat)
primvars_api = UsdGeom.PrimvarsAPI(mesh)
# ── Per-part sharp + seam edge primvars (definition space) ─
try:
sharp_pairs = _extract_sharp_edge_pairs(bare_def, args.sharp_threshold)
if sharp_pairs:
idx_pairs = _world_to_index_pairs(vertices, sharp_pairs)
if idx_pairs:
pv = primvars_api.CreatePrimvar(
"schaeffler:sharpEdgeVertexPairs",
Sdf.ValueTypeNames.Int2Array,
UsdGeom.Tokens.constant,
)
pv.Set(Vt.Vec2iArray([Gf.Vec2i(a, b) for a, b in idx_pairs]))
except Exception as exc:
print(f"WARNING: sharp edge extraction for {part_key}: {exc}")
try:
seam_pairs = _extract_seam_edge_pairs(bare_def)
if seam_pairs:
seam_idx_pairs = _world_to_index_pairs(vertices, seam_pairs)
if seam_idx_pairs:
pv_seam = primvars_api.CreatePrimvar(
"schaeffler:seamEdgeVertexPairs",
Sdf.ValueTypeNames.Int2Array,
UsdGeom.Tokens.constant,
)
pv_seam.Set(Vt.Vec2iArray([
Gf.Vec2i(a, b) for a, b in seam_idx_pairs]))
except Exception as exc:
print(f"WARNING: seam edge extraction for {part_key}: {exc}")
# ── Material binding ──────────────────────────────────────
if canonical_mat:
mat_prim_name = _prim_name(canonical_mat)
else:
mat_prim_name = (_prim_name(source_name) if source_name
else f"mat_{part_key}")
mat_usd_path = f"/Root/Looks/{mat_prim_name}"
if not stage.GetPrimAtPath(mat_usd_path):
UsdShade.Material.Define(stage, mat_usd_path)
UsdShade.MaterialBindingAPI(mesh.GetPrim()).Bind(
UsdShade.Material(stage.GetPrimAtPath(mat_usd_path))
)
manifest_parts.append({
"part_key": part_key,
"source_name": source_name,
"prim_path": part_path,
"canonical_material": canonical_mat,
})
counters['n_parts'] += 1
# ── Mesh geometry extraction ──────────────────────────────────────────────────
@@ -594,31 +785,7 @@ def main() -> None:
)
print("Tessellation complete.")
# ── Sharp edge pairs (world-space mm, Z-up) ───────────────────────────────
sharp_pairs_mm: list = []
try:
for i in range(1, free_labels.Length() + 1):
root_shape = shape_tool.GetShape_s(free_labels.Value(i))
if not root_shape.IsNull():
sharp_pairs_mm.extend(
_extract_sharp_edge_pairs(root_shape, args.sharp_threshold)
)
print(f"Total sharp segment pairs: {len(sharp_pairs_mm)}")
except Exception as exc:
print(f"WARNING: sharp edge extraction failed (non-fatal): {exc}", file=sys.stderr)
# ── Seam edge pairs (world-space mm, Z-up) ────────────────────────────────
seam_pairs_mm: list = []
try:
for i in range(1, free_labels.Length() + 1):
root_shape = shape_tool.GetShape_s(free_labels.Value(i))
if not root_shape.IsNull():
seam_pairs_mm.extend(_extract_seam_edge_pairs(root_shape))
print(f"Total seam segment pairs: {len(seam_pairs_mm)}")
except Exception as exc:
print(f"WARNING: seam edge extraction failed (non-fatal): {exc}", file=sys.stderr)
# ── Apply colors ──────────────────────────────────────────────────────────
# ── Apply colors ──────────────────────────────────────────────────────
if color_map:
try:
_apply_color_map(shape_tool, color_tool, free_labels, color_map)
@@ -632,156 +799,48 @@ def main() -> None:
except Exception as exc:
print(f"WARNING: palette colors failed (non-fatal): {exc}", file=sys.stderr)
# ── Create USD stage ──────────────────────────────────────────────────────
# ── Create USD stage ──────────────────────────────────────────────────
stage = Usd.Stage.CreateNew(str(output_path))
UsdGeom.SetStageUpAxis(stage, UsdGeom.Tokens.z)
UsdGeom.SetStageMetersPerUnit(stage, 0.001) # mm; Blender handles m conversion on import
root_prim = UsdGeom.Xform.Define(stage, "/Root")
stage.SetDefaultPrim(root_prim.GetPrim())
UsdGeom.Xform.Define(stage, "/Root/Assembly")
# /Root/Assembly carries the OCC→Blender coordinate swap.
# OCC is mm Z-up Y-forward; Blender/GLB convention is Z-up Y-backward.
# Transform: (X_occ, Y_occ, Z_occ) → (X, -Z, Y) (= Rx(-90°)).
# Authored as a USD row-vector matrix on the Assembly Xform so that all
# child XCAF transforms (authored in OCC space) are correctly composed.
assembly_xform = UsdGeom.Xform.Define(stage, "/Root/Assembly")
assembly_xform.AddTransformOp().Set(Gf.Matrix4d(
1, 0, 0, 0,
0, 0, 1, 0,
0, -1, 0, 0,
0, 0, 0, 1,
))
stage.DefinePrim("/Root/Looks", "Scope")
# ── Walk XCAF tree → author USD prims ─────────────────────────────────────
# ── Walk XCAF tree → author USD prims (hierarchical) ──────────────────
# Sharp/seam edges are extracted per-part inside _author_xcaf_to_usd
# (in definition space, matching definition-space mesh vertices).
existing_keys: set = set()
manifest_parts: list = []
n_parts = 0
n_empty = 0
counters = {"n_parts": 0, "n_empty": 0}
for root_idx in range(1, free_labels.Length() + 1):
root_label = free_labels.Value(root_idx)
from OCP.TDataStd import TDataStd_Name as _Name
_na = _Name()
root_src = ""
if root_label.FindAttribute(_Name.GetID_s(), _na):
root_src = _na.Get().ToExtString()
node_name = _prim_name(root_src or f"Root{root_idx}")
node_path = f"/Root/Assembly/{node_name}"
UsdGeom.Xform.Define(stage, node_path)
for part in _traverse_xcaf(shape_tool, color_tool, root_label, "", existing_keys):
source_name = part['source_name']
part_key = part['part_key']
hex_color = part['color']
shape = part['shape']
xcaf_path = part['xcaf_path']
# color_map override (substring match)
for map_name, map_hex in color_map.items():
if (map_name.lower() in source_name.lower()
or source_name.lower() in map_name.lower()):
hex_color = map_hex
break
if not hex_color:
hex_color = PALETTE_HEX[n_parts % len(PALETTE_HEX)]
vertices, triangles = _extract_mesh(shape)
if not vertices or not triangles:
n_empty += 1
continue
part_path = f"{node_path}/{part_key}"
# Name the Mesh prim after part_key so Blender imports it with the
# part name directly (Blender collapses single-child Xform+Mesh into
# just the Mesh object, using the mesh prim's leaf name as object name).
mesh_path = f"{part_path}/{part_key}"
# ── Xform prim ────────────────────────────────────────────────
xform = UsdGeom.Xform.Define(stage, part_path)
prim = xform.GetPrim()
prim.SetCustomDataByKey("schaeffler:partKey", part_key)
prim.SetCustomDataByKey("schaeffler:sourceName", source_name)
prim.SetCustomDataByKey("schaeffler:sourceAssemblyPath", xcaf_path)
prim.SetCustomDataByKey("schaeffler:sourceColor", hex_color)
prim.SetCustomDataByKey("schaeffler:tessellation:linearDeflectionMm",
args.linear_deflection)
prim.SetCustomDataByKey("schaeffler:tessellation:angularDeflectionRad",
args.angular_deflection)
if args.cad_file_id:
prim.SetCustomDataByKey("schaeffler:cadFileId", args.cad_file_id)
# ── UsdGeomMesh ───────────────────────────────────────────────
mesh = UsdGeom.Mesh.Define(stage, mesh_path)
mesh.CreateSubdivisionSchemeAttr(UsdGeom.Tokens.none)
# OCC is Z-up (mm) but Y-forward. Blender is Z-up, Y-backward.
# GLB export uses: Blender(X, -Z_occ, Y_occ) × 0.001
# USD stage is Z-up with metersPerUnit=0.001, so Blender applies
# only the scale. Write (X, -Z, Y) to match GLB orientation.
mesh.CreatePointsAttr(Vt.Vec3fArray([
Gf.Vec3f(x, -z, y) for (x, y, z) in vertices
]))
mesh.CreateFaceVertexCountsAttr(Vt.IntArray([3] * len(triangles)))
mesh.CreateFaceVertexIndicesAttr(
Vt.IntArray([idx for tri in triangles for idx in tri])
)
r, g, b = _hex_to_rgb01(hex_color)
mesh.CreateDisplayColorAttr(Vt.Vec3fArray([Gf.Vec3f(r, g, b)]))
# ── Material metadata on mesh prim (customData) ─────────────
# Blender's USD importer does NOT expose STRING primvars or
# customData as Python properties — but pxr can read customData
# directly from the USD file after Blender import. This is 100%
# reliable and avoids Blender importer limitations.
mesh_prim = mesh.GetPrim()
mesh_prim.SetCustomDataByKey("schaeffler:partKey", part_key)
mesh_prim.SetCustomDataByKey("schaeffler:sourceName", source_name)
canonical_mat = _lookup_material(source_name, part_key, mat_map_lower)
if canonical_mat:
mesh_prim.SetCustomDataByKey(
"schaeffler:canonicalMaterialName", canonical_mat)
primvars_api = UsdGeom.PrimvarsAPI(mesh)
# ── Index-space sharp + seam edge primvars ───────────────────
# Lookup is in OCC Z-up space; pairs are also Z-up — no swap needed.
# Both `vertices` and `*_pairs_mm` are in OCC Z-up mm space with the
# full per-shape location already applied — same coordinate frame required
# by _world_to_index_pairs for the nearest-vertex lookup (tol=0.5 mm).
if sharp_pairs_mm:
idx_pairs = _world_to_index_pairs(vertices, sharp_pairs_mm)
if idx_pairs:
pv = primvars_api.CreatePrimvar(
"schaeffler:sharpEdgeVertexPairs",
Sdf.ValueTypeNames.Int2Array,
UsdGeom.Tokens.constant,
)
pv.Set(Vt.Vec2iArray([Gf.Vec2i(a, b) for a, b in idx_pairs]))
if seam_pairs_mm:
seam_idx_pairs = _world_to_index_pairs(vertices, seam_pairs_mm)
if seam_idx_pairs:
pv_seam = primvars_api.CreatePrimvar(
"schaeffler:seamEdgeVertexPairs",
Sdf.ValueTypeNames.Int2Array,
UsdGeom.Tokens.constant,
)
pv_seam.Set(Vt.Vec2iArray([Gf.Vec2i(a, b) for a, b in seam_idx_pairs]))
# ── Material binding ──────────────────────────────────────────
# Use canonical SCHAEFFLER material name when resolved; fall back
# to source_name so Blender imports show meaningful material names
# even without the library .blend appended.
if canonical_mat:
mat_prim_name = _prim_name(canonical_mat)
else:
mat_prim_name = _prim_name(source_name) if source_name else f"mat_{part_key}"
mat_usd_path = f"/Root/Looks/{mat_prim_name}"
if not stage.GetPrimAtPath(mat_usd_path):
UsdShade.Material.Define(stage, mat_usd_path)
UsdShade.MaterialBindingAPI(mesh.GetPrim()).Bind(
UsdShade.Material(stage.GetPrimAtPath(mat_usd_path))
root_names: set = set()
_author_xcaf_to_usd(
stage, shape_tool, color_tool, root_label,
"/Root/Assembly", "",
existing_keys, mat_map_lower, color_map, args,
manifest_parts, counters, root_names,
)
manifest_parts.append({
"part_key": part_key,
"source_name": source_name,
"prim_path": part_path,
"canonical_material": canonical_mat,
})
n_parts += 1
n_parts = counters["n_parts"]
n_empty = counters["n_empty"]
stage.Save()
sz = output_path.stat().st_size // 1024 if output_path.exists() else 0