鱼C论坛

 找回密码
 立即注册
查看: 1890|回复: 1

[技术交流] 《用Python动手学习强化学习》【MDP(马尔可夫决策过程)】

[复制链接]
发表于 2021-9-19 22:35:54 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 糖逗 于 2021-9-19 22:37 编辑

代码出处:《用Python动手学习强化学习》第一章:强化学习的问题设定:马尔可夫决策过程

  1. import random
  2. from enum import Enum
  3. import numpy as np


  4. class State():
  5.     def __init__(self, row = -1, column = -1):
  6.         self.row = row
  7.         self.column = column
  8.         
  9.     def __repr__(self):
  10.         return "<State:[{},{}]>".format(self.row, self.column)
  11.    
  12.     def clone(self):
  13.         return State(self.row, self.column)

  14.     def __hash__(self):
  15.         return hash((self.row, self.column))
  16.    
  17.     def __eq__(self, other):
  18.         return self.row == other.row and self.column == other.column
  19.    
  20.    
  21. class Action(Enum):
  22.     UP = 1
  23.     DOWN = -1
  24.     LEFT = 2
  25.     RIGHT = -2
  26.    


  27. class Environment():#迁移函数和奖励函数
  28.     def __init__(self, grid, move_prob = 0.8):
  29.         '''
  30.         0:普通格子
  31.         -1:有危险的格子(游戏结束)
  32.         1:有奖励的格子(游戏结束)
  33.         9:被屏蔽的格子(无法放置智能体)
  34.         '''
  35.         self.grid = grid
  36.         '''
  37.         默认的奖励是负数,就像施加了初始位置
  38.         '''
  39.         self.default_reward = -0.04
  40.         '''
  41.         智能体能够以move_prob的概率向所选方向移动
  42.         '''
  43.         self.move_prob = move_prob
  44.         self.reset()
  45.         
  46.     @property
  47.     def row_length(self):
  48.         return len(self.grid)
  49.    
  50.     @property
  51.     def column_length(self):
  52.         return len(self.grid[0])
  53.    
  54.     @property
  55.     def actions(self):
  56.         return [Action.UP, Action.DOWN, Action.LEFT, Action.RIGHT]
  57.    
  58.     @property
  59.     def states(self):
  60.         states = []
  61.         for row in range(self.row_length):
  62.             for column in range(self.column_length):
  63.                 if self.grid[row][column] != 9:
  64.                     states.append(State(row, column))
  65.         return states

  66.     def reset(self):
  67.         #初始位置在左下角
  68.         self.agent_state = State(self.row_length - 1, 0)
  69.         return self.agent_state
  70.    
  71.     def can_action_at(self, state):
  72.         if self.grid[state.row][state.column] == 0:
  73.             return True
  74.         else:
  75.             return False
  76.         
  77.     def _move(self, state, action):
  78.         if not self.can_action_at(state):
  79.             raise Exception("Can't move from here!")
  80.         
  81.         next_state = state.clone()
  82.         
  83.         if action == Action.UP:
  84.             next_state.row -= 1
  85.         elif action == Action.DOWN:
  86.             next_state.row += 1
  87.         elif action == Action.LEFT:
  88.             next_state.column -= 1
  89.         elif action == Action.RIGHT:
  90.             next_state.column += 1
  91.             
  92.         #检查状态是否在grid外
  93.         if not (0 <= next_state.row < self.row_length):
  94.             next_state = state
  95.         if not (0 <= next_state.column < self.column_length):
  96.             next_state = state
  97.         
  98.         #检查智能体是否到达了被屏蔽的格子
  99.         if self.grid[next_state.row][next_state.column] == 9:
  100.             next_state = state
  101.             
  102.         return next_state
  103.         
  104.     def transit_func(self, state, action):
  105.         transition_probs = {}
  106.         if not self.can_action_at(state):
  107.             #游戏结束
  108.             return transition_probs
  109.         
  110.         opposite_direction = Action(action.value * -1)
  111.         
  112.         for a in self.actions:
  113.             prob = 0
  114.             if a == action:
  115.                 prob = self.move_prob
  116.             elif a != opposite_direction:
  117.                 prob = (1 - self.move_prob) / 2
  118.             
  119.             next_state = self._move(state, a)
  120.             if next_state not in transition_probs:
  121.                 #求期望
  122.                 transition_probs[next_state] = prob
  123.             else:
  124.                 transition_probs[next_state] += prob
  125.         return transition_probs
  126.    
  127.     def reward_func(self, state):
  128.         reward = self.default_reward
  129.         done = False
  130.         #检查下一种状态的属性
  131.         attribute = self.grid[state.row][state.column]
  132.         if attribute == 1:
  133.             #获得奖励,游戏结束
  134.             reward = 1
  135.             done = True
  136.         elif attribute == -1:
  137.             #遇到危险,游戏结束
  138.             reward = -1
  139.             done = True
  140.         return reward, done
  141.    
  142.     def transit(self, state, action):
  143.         transition_probs = self.transit_func(state, action)
  144.         if len(transition_probs) == 0:
  145.             return None, None, True
  146.         
  147.         next_states = []
  148.         probs = []
  149.         for s in transition_probs:
  150.             next_states.append(s)
  151.             probs.append(transition_probs[s])
  152.         
  153.         next_state = np.random.choice(next_states, p = probs)
  154.         reward, done = self.reward_func(next_state)
  155.         return next_state, reward, done
  156.    
  157.     def step(self, action):
  158.         next_state, reward, done = self.transit(self.agent_state, action)
  159.         if next_state is not None:
  160.             self.agent_state = next_state
  161.         return next_state, reward, done
  162.    
  163.    
  164. class Agent():
  165.     def __init__(self, env):
  166.         self.actions = env.actions
  167.    
  168.     def policy(self, state):
  169.         return random.choice(self.actions)
  170.    

  171. def main():
  172.     grid = [[0, 0, 0, 1], [0, 9, 0, -1], [0, 0, 0, 0]]
  173.     env = Environment(grid)#环境包括迁移函数和奖励函数
  174.     agent = Agent(env)
  175.    
  176.     for i in range(10):
  177.         state = env.reset()
  178.         total_reward = 0
  179.         done = False
  180.         
  181.         while not done:
  182.             action = agent.policy(state)
  183.             next_state, reward, done = env.step(action)
  184.             total_reward += reward
  185.             state = next_state
  186.             
  187.         print("Episode{}:Agent gets {} reward.".format(i, total_reward))

  188. if __name__ == "__main__":
  189.     main()
  190.    
复制代码

本帖被以下淘专辑推荐:

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2021-9-19 22:36:31 | 显示全部楼层
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2024-5-15 01:00

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表