鱼C论坛

 找回密码
 立即注册
查看: 2083|回复: 1

[技术交流] 《用Python动手学习强化学习》【MDP(马尔可夫决策过程)】

[复制链接]
发表于 2021-9-19 22:35:54 | 显示全部楼层 |阅读模式

马上注册,结交更多好友,享用更多功能^_^

您需要 登录 才可以下载或查看,没有账号?立即注册

x
本帖最后由 糖逗 于 2021-9-19 22:37 编辑

代码出处:《用Python动手学习强化学习》第一章:强化学习的问题设定:马尔可夫决策过程
import random
from enum import Enum
import numpy as np


class State():
    def __init__(self, row = -1, column = -1):
        self.row = row
        self.column = column
        
    def __repr__(self):
        return "<State:[{},{}]>".format(self.row, self.column)
    
    def clone(self):
        return State(self.row, self.column)

    def __hash__(self):
        return hash((self.row, self.column))
    
    def __eq__(self, other):
        return self.row == other.row and self.column == other.column
    
    
class Action(Enum):
    UP = 1
    DOWN = -1
    LEFT = 2
    RIGHT = -2
    


class Environment():#迁移函数和奖励函数
    def __init__(self, grid, move_prob = 0.8):
        '''
        0:普通格子
        -1:有危险的格子(游戏结束)
        1:有奖励的格子(游戏结束)
        9:被屏蔽的格子(无法放置智能体)
        '''
        self.grid = grid
        '''
        默认的奖励是负数,就像施加了初始位置
        '''
        self.default_reward = -0.04
        '''
        智能体能够以move_prob的概率向所选方向移动
        '''
        self.move_prob = move_prob
        self.reset()
        
    @property
    def row_length(self):
        return len(self.grid)
    
    @property
    def column_length(self):
        return len(self.grid[0])
    
    @property
    def actions(self):
        return [Action.UP, Action.DOWN, Action.LEFT, Action.RIGHT]
    
    @property
    def states(self):
        states = []
        for row in range(self.row_length):
            for column in range(self.column_length):
                if self.grid[row][column] != 9:
                    states.append(State(row, column))
        return states

    def reset(self):
        #初始位置在左下角
        self.agent_state = State(self.row_length - 1, 0)
        return self.agent_state
    
    def can_action_at(self, state):
        if self.grid[state.row][state.column] == 0:
            return True
        else:
            return False
        
    def _move(self, state, action):
        if not self.can_action_at(state):
            raise Exception("Can't move from here!")
        
        next_state = state.clone()
        
        if action == Action.UP:
            next_state.row -= 1
        elif action == Action.DOWN:
            next_state.row += 1
        elif action == Action.LEFT:
            next_state.column -= 1
        elif action == Action.RIGHT:
            next_state.column += 1
            
        #检查状态是否在grid外
        if not (0 <= next_state.row < self.row_length):
            next_state = state
        if not (0 <= next_state.column < self.column_length):
            next_state = state
        
        #检查智能体是否到达了被屏蔽的格子
        if self.grid[next_state.row][next_state.column] == 9:
            next_state = state
            
        return next_state
        
    def transit_func(self, state, action):
        transition_probs = {}
        if not self.can_action_at(state):
            #游戏结束
            return transition_probs
        
        opposite_direction = Action(action.value * -1)
        
        for a in self.actions:
            prob = 0
            if a == action:
                prob = self.move_prob
            elif a != opposite_direction:
                prob = (1 - self.move_prob) / 2
            
            next_state = self._move(state, a)
            if next_state not in transition_probs:
                #求期望
                transition_probs[next_state] = prob
            else:
                transition_probs[next_state] += prob
        return transition_probs
    
    def reward_func(self, state):
        reward = self.default_reward
        done = False
        #检查下一种状态的属性
        attribute = self.grid[state.row][state.column]
        if attribute == 1:
            #获得奖励,游戏结束
            reward = 1
            done = True
        elif attribute == -1:
            #遇到危险,游戏结束
            reward = -1
            done = True
        return reward, done
    
    def transit(self, state, action):
        transition_probs = self.transit_func(state, action)
        if len(transition_probs) == 0:
            return None, None, True
        
        next_states = []
        probs = []
        for s in transition_probs:
            next_states.append(s)
            probs.append(transition_probs[s])
        
        next_state = np.random.choice(next_states, p = probs)
        reward, done = self.reward_func(next_state)
        return next_state, reward, done
    
    def step(self, action):
        next_state, reward, done = self.transit(self.agent_state, action)
        if next_state is not None:
            self.agent_state = next_state
        return next_state, reward, done
    
    
class Agent():
    def __init__(self, env):
        self.actions = env.actions
    
    def policy(self, state):
        return random.choice(self.actions)
    

def main():
    grid = [[0, 0, 0, 1], [0, 9, 0, -1], [0, 0, 0, 0]]
    env = Environment(grid)#环境包括迁移函数和奖励函数
    agent = Agent(env)
    
    for i in range(10):
        state = env.reset()
        total_reward = 0
        done = False
        
        while not done:
            action = agent.policy(state)
            next_state, reward, done = env.step(action)
            total_reward += reward
            state = next_state
            
        print("Episode{}:Agent gets {} reward.".format(i, total_reward))

if __name__ == "__main__":
    main()
    

本帖被以下淘专辑推荐:

想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

 楼主| 发表于 2021-9-19 22:36:31 | 显示全部楼层
想知道小甲鱼最近在做啥?请访问 -> ilovefishc.com
回复

使用道具 举报

您需要登录后才可以回帖 登录 | 立即注册

本版积分规则

小黑屋|手机版|Archiver|鱼C工作室 ( 粤ICP备18085999号-1 | 粤公网安备 44051102000585号)

GMT+8, 2025-1-13 07:49

Powered by Discuz! X3.4

© 2001-2023 Discuz! Team.

快速回复 返回顶部 返回列表