from omni.isaac.kit import SimulationApp

# Start simulation app
simulation_app = SimulationApp({"headless": False, "renderer": "RayTracedLighting"})

import omni.replicator.core as rep
from pxr import UsdLux, Gf, UsdGeom, Usd, Sdf
import omni.usd
import numpy as np
import cv2
from ultralytics import YOLO
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import torch
import time

# Select the CUDA device if available, otherwise fall back to CPU
def assign_cuda_device():
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        torch.cuda.set_device(device)
        print(f"Using CUDA device: {torch.cuda.current_device()} ({torch.cuda.get_device_name(0)})")
        return device
    else:
        print("CUDA is not available. Falling back to CPU.")
        return torch.device("cpu")

device = assign_cuda_device()

# Load stage
rep.new_layer()
stage_path = "D:/isaaclab/_isaac_sim/Humanoid_sim/scena1.usd"
omni.usd.get_context().open_stage(stage_path)

# Robot path
robot_path = "/World/g1"
camera_path = "/World/g1/pelvis/robot_eye_camera"  # Camera prim parented under the robot's pelvis

# Wait for stage to load
stage = omni.usd.get_context().get_stage()
while stage is None:
    simulation_app.update()
    stage = omni.usd.get_context().get_stage()
    time.sleep(0.1)

# Verify robot and camera exist
robot_prim = stage.GetPrimAtPath(robot_path)
if not robot_prim.IsValid():
    raise RuntimeError(f"Robot not found at path: {robot_path}")
print(f"Found robot at: {robot_path}")

camera_prim = stage.GetPrimAtPath(camera_path)
if not camera_prim.IsValid():
    raise RuntimeError(f"Camera not found at path: {camera_path}")
print(f"Using new robot eye camera at: {camera_path}")

# Run a few warm-up updates so materials and assets finish loading;
# without this the first rendered frames are often black
for _ in range(20):
    simulation_app.update()

# Create render product
print("Setting up render product with camera in the stage...")
height, width = 1024, 1024
render_product = rep.create.render_product(camera_path, resolution=(width, height))
if render_product is None:
    raise RuntimeError("Failed to create render product")
print(f"Successfully created render product for camera at {camera_path}")

# Annotator setup
annotator = rep.AnnotatorRegistry.get_annotator("rgb")
annotator.attach([render_product])
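# Note: "distance_to_camera" reports the Euclidean distance from the camera to
# each surface point along its ray, not the perpendicular Z-depth; Replicator's
# "distance_to_image_plane" annotator provides the latter if you need it.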
depth_annotator = rep.AnnotatorRegistry.get_annotator("distance_to_camera")
depth_annotator.attach([render_product])

# Drive Replicator manually: orchestrator.step() is called once per frame in the
# capture loop below, so a continuous orchestrator.run() is not needed here.

# Capture an RGB frame and its matching depth map
captured_image = None
captured_depth = None

for i in range(10):
    simulation_app.update()
    rep.orchestrator.step()
    time.sleep(0.05)

    rgb_data = annotator.get_data()
    depth_data = depth_annotator.get_data()

    if rgb_data is not None and depth_data is not None and rgb_data.size > 0 and depth_data.size > 0:
        # Annotator outputs arrive as numpy arrays: RGBA uint8 for "rgb",
        # float32 metres for "distance_to_camera"; drop the alpha channel
        rgb_image = np.asarray(rgb_data, dtype=np.uint8).reshape((height, width, 4))[:, :, :3]
        depth_image = np.asarray(depth_data, dtype=np.float32).reshape((height, width))

        mean_brightness = np.mean(rgb_image)
        if mean_brightness > 10.0:
            captured_image = rgb_image
            captured_depth = depth_image
            print(f"Captured image and depth at frame {i}")
            cv2.imwrite("capture_rgb.png", cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR))
            np.save("capture_depth.npy", depth_image)
            break
    else:
        print(f"Frame {i}: No valid data received")

# Abort if no valid frame was captured
if captured_image is None or captured_depth is None:
    raise RuntimeError("Failed to capture a valid image after multiple attempts")

# Process captured image
image_rgb = captured_image.copy()
image_bgr = cv2.cvtColor(image_rgb, cv2.COLOR_RGB2BGR)
cv2.imwrite("humanoid_captured_image_fixed.png", image_bgr)
print("Saved captured image as 'humanoid_captured_image_fixed.png'")
depth_map = captured_depth
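
# Quick visual sanity check of the depth map (a sketch; the 0.1-20 m clipping
# range is an assumption for an indoor-scale scene, adjust to your stage):
finite_depth = np.nan_to_num(depth_map, nan=0.0, posinf=0.0, neginf=0.0)
clipped = np.clip(finite_depth, 0.1, 20.0)
depth_u8 = ((clipped - 0.1) / (20.0 - 0.1) * 255.0).astype(np.uint8)
cv2.imwrite("capture_depth_vis.png", cv2.applyColorMap(depth_u8, cv2.COLORMAP_JET))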

# YOLO inference
print("Loading YOLO model...")
# model = YOLO('yolov8l-oiv7.pt').to(device)
model = YOLO('yolov8x.pt').to(device)
print("Running YOLO inference...")
# Pass the BGR numpy image directly: Ultralytics handles resizing, normalization,
# and channel order itself, whereas a hand-built tensor is assumed to already be
# RGB and would be fed here with swapped channels.
results = model(image_bgr)
for box in results[0].boxes.data:
    x1, y1, x2, y2 = map(int, box[:4])
    # Clamp the box centre to valid pixel coordinates before indexing the depth map
    cx = min(max((x1 + x2) // 2, 0), width - 1)
    cy = min(max((y1 + y2) // 2, 0), height - 1)
    distance = depth_map[cy, cx]

    # boxes.data rows are [x1, y1, x2, y2, confidence, class]
    class_id = int(box[5]) if len(box) > 5 else -1
    class_name = model.names.get(class_id, "Unknown")
    text = f"{class_name}: {distance:.2f}m"

    print(f"Object at ({cx}, {cy}) has estimated distance: {distance:.2f} meters")

    # Draw the box and label on the image (coordinates are already ints)
    cv2.rectangle(image_bgr, (x1, y1), (x2, y2), (0, 255, 0), 2)
    cv2.putText(image_bgr, text, (x1, y1 - 10),
                cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 255), 2)
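
# Single-pixel depth lookups are brittle: the centre pixel can land on sky
# (inf) or on a depth discontinuity. A more robust estimate, sketched below
# under the assumption that a small window around the centre is representative,
# takes the median of the finite depths in that window (the 5 px half-width is
# an arbitrary choice):
def robust_distance(depth, cx, cy, half=5):
    h, w = depth.shape
    patch = depth[max(0, cy - half):min(h, cy + half + 1),
                  max(0, cx - half):min(w, cx + half + 1)]
    finite = patch[np.isfinite(patch)]
    return float(np.median(finite)) if finite.size else float("nan")
# Usage inside the loop above: distance = robust_distance(depth_map, cx, cy)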

# Print detection results
print("\n=== YOLO Detection Results ===")
print("YOLO classes:", model.names)
if len(results[0].boxes) > 0:
    detected_classes = [results[0].names[int(cls)] for cls in results[0].boxes.cls]
    print("Detected classes:", detected_classes)
    print("Detected boxes:", results[0].boxes.data)
    print("Confidence scores:", results[0].boxes.conf)
else:
    print("No objects detected")

# Draw results
annotated_frame = results[0].plot()
annotated_rgb = cv2.cvtColor(annotated_frame, cv2.COLOR_BGR2RGB)
plt.figure(figsize=(15, 10))
plt.imshow(annotated_rgb)
plt.axis('off')
plt.title("YOLOv8 Detection from Stage Camera in Isaac Sim")
plt.savefig("humanoid_yolo_output_fixed.png", dpi=150, bbox_inches='tight')
plt.close()
print("Saved YOLO detection result as 'humanoid_yolo_output_fixed.png'")
cv2.imwrite("humanoid_yolo_depth_overlay.png", image_bgr)
print("Saved image with YOLO detections and depth overlay")

# Save detection summary
detection_summary = f"""
Detection Summary:
- Total detections: {len(results[0].boxes)}
- Classes found: {[results[0].names[int(cls)] for cls in results[0].boxes.cls] if len(results[0].boxes) > 0 else 'None'}
- Image resolution: {width}x{height}
- Camera path: {camera_path}
- Robot path: {robot_path}
- Camera is fixed: True
"""
with open("detection_summary_fixed.txt", "w") as f:
    f.write(detection_summary)
print(detection_summary)

# Cleanup: detach both annotators before closing the stage
print("Cleaning up...")
annotator.detach()
depth_annotator.detach()
omni.usd.get_context().close_stage()
simulation_app.close()
print("Simulation closed successfully")