{ "cells": [ { "cell_type": "markdown", "id": "2832faf1-1bd3-4a95-8b0d-b3289e74d4d0", "metadata": {}, "source": [ "# Demonstration on MultigoalIntersection\n", "\n", "[![Click and Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/metadriverse/metadrive/blob/main/documentation/source/multigoal_intersection.ipynb)\n", "\n", "\n", "In this notebook, we demonstrate how to set up a multigoal intersection environment where you can access relevant stats (e.g., route completion, reward, success rate) for all four possible goals (right turn, left turn, move forward, U-turn) simultaneously.\n", "\n", "We demonstrate how to build the environment, in which we have successfully trained a SAC expert that achieves a 99% success rate, and how to access those stats in the info dict returned at each step.\n", "\n", "*Note: We pretrain the SAC expert with `use_multigoal_intersection=False` and then fine-tune it with `use_multigoal_intersection=True`.*" ] }, { "cell_type": "code", "execution_count": 1, "id": "b9733eac-9d07-47cf-bda7-4dbb8d5f2412", "metadata": {}, "outputs": [], "source": [ "import numpy as np\n", "from metadrive.envs.gym_wrapper import create_gym_wrapper\n", "from metadrive.envs.multigoal_intersection import MultiGoalIntersectionEnv\n", "import mediapy as media\n", "\n", "render = False\n", "num_scenarios = 1000\n", "start_seed = 100" ] }, { "cell_type": "code", "execution_count": 2, "id": "f5b6f059-52f8-46ee-bcfe-dee6f4d2e2e6", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[38;20m[INFO] Environment: MultiGoalIntersectionEnv\u001b[0m\n", "\u001b[38;20m[INFO] MetaDrive version: 0.4.3\u001b[0m\n", "\u001b[38;20m[INFO] Sensors: [lidar: Lidar(), side_detector: SideDetector(), lane_line_detector: LaneLineDetector()]\u001b[0m\n", "\u001b[38;20m[INFO] Render Mode: none\u001b[0m\n", "\u001b[38;20m[INFO] Horizon (Max steps per agent): 500\u001b[0m\n" ] } ], "source": [
"env_config = dict(\n", " use_render=render,\n", " manual_control=False,\n", " horizon=500, # to speed up training\n", "\n", " traffic_density=0.06,\n", " \n", " use_multigoal_intersection=True, # Set to False if want to use the same observation but with original PG scenarios.\n", " out_of_route_done=False,\n", "\n", " num_scenarios=num_scenarios,\n", " start_seed=start_seed,\n", " accident_prob=0.8,\n", " crash_vehicle_done=False,\n", " crash_object_done=False,\n", ")\n", "\n", "wrapped = create_gym_wrapper(MultiGoalIntersectionEnv)\n", "\n", "env = wrapped(env_config)\n" ] }, { "cell_type": "code", "execution_count": 3, "id": "ae2abe78-f3e3-40b9-88dd-a958fc932363", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "\u001b[38;20m[INFO] Assets version: 0.4.3\u001b[0m\n", "\u001b[38;20m[INFO] Known Pipes: glxGraphicsPipe\u001b[0m\n", "\u001b[38;20m[INFO] Start Scenario Index: 100, Num Scenarios : 1000\u001b[0m\n", "\u001b[33;20m[WARNING] env.vehicle will be deprecated soon. Use env.agent instead (base_env.py:737)\u001b[0m\n", "\u001b[38;20m[INFO] Episode ended! 
Scenario Index: 542 Reason: arrive_dest.\u001b[0m\n" ] } ], "source": [ "frames = []\n", "\n", "try:\n", " env.reset()\n", " while True:\n", " action = [0, 1]\n", " o, r, d, i = env.step(action)\n", " frame = env.render(mode=\"topdown\")\n", " frames.append(frame)\n", " if d:\n", " break\n", "finally:\n", " env.close()" ] }, { "cell_type": "code", "execution_count": 4, "id": "40ac0392-67e3-4d2d-a9bd-2065831e43ca", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Output at final step:\n", "\tacceleration: 1.000\n", "\tarrive_dest: 1.000\n", "\tarrive_dest/goals/default: 1.000\n", "\tarrive_dest/goals/go_straight: 1.000\n", "\tarrive_dest/goals/left_turn: 0.000\n", "\tarrive_dest/goals/right_turn: 0.000\n", "\tarrive_dest/goals/u_turn: 0.000\n", "\tcost: 0.000\n", "\tcrash: 0.000\n", "\tcrash_building: 0.000\n", "\tcrash_human: 0.000\n", "\tcrash_object: 0.000\n", "\tcrash_sidewalk: 0.000\n", "\tcrash_vehicle: 0.000\n", "\tcurrent_goal: go_straight\n", "\tenv_seed: 542.000\n", "\tepisode_energy: 6.565\n", "\tepisode_length: 85.000\n", "\tepisode_reward: 122.793\n", "\tmax_step: 0.000\n", "\tnavigation_command: forward\n", "\tnavigation_forward: 1.000\n", "\tnavigation_left: 0.000\n", "\tnavigation_right: 0.000\n", "\tout_of_road: 0.000\n", "\tovertake_vehicle_num: 0.000\n", "\tpolicy: EnvInputPolicy\n", "\treward/default_reward: 12.332\n", "\treward/goals/default: 12.332\n", "\treward/goals/go_straight: 12.332\n", "\treward/goals/left_turn: -10.000\n", "\treward/goals/right_turn: -10.000\n", "\treward/goals/u_turn: -10.000\n", "\troute_completion: 0.969\n", "\troute_completion/goals/default: 0.969\n", "\troute_completion/goals/go_straight: 0.969\n", "\troute_completion/goals/left_turn: 0.621\n", "\troute_completion/goals/right_turn: 0.644\n", "\troute_completion/goals/u_turn: 0.557\n", "\tsteering: 0.000\n", "\tstep_energy: 0.162\n", "\tvelocity: 22.291\n" ] } ], "source": [ "print(\"Output at final step:\")\n", "\n", "i = {k: i[k] 
for k in sorted(i.keys())}\n", "for k, v in i.items():\n", " if isinstance(v, str):\n", " s = v\n", " elif np.iterable(v):\n", " continue\n", " else:\n", " s = \"{:.3f}\".format(v)\n", " print(\"\\t{}: {}\".format(k, s))" ] }, { "cell_type": "code", "execution_count": 5, "id": "dc986e4e-f81c-4882-88b2-9eb306552fb3", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
" ], "text/plain": [ "" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "media.show_video(frames)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.15" } }, "nbformat": 4, "nbformat_minor": 5 }