diff --git a/drivers/bus/mhi/host/boot.c b/drivers/bus/mhi/host/boot.c index 1c69feee1703..d2a19b07ccb8 100644 --- a/drivers/bus/mhi/host/boot.c +++ b/drivers/bus/mhi/host/boot.c @@ -391,6 +391,7 @@ void mhi_fw_load_handler(struct mhi_controller *mhi_cntrl) { const struct firmware *firmware = NULL; struct device *dev = &mhi_cntrl->mhi_dev->dev; + enum mhi_pm_state new_state; const char *fw_name; void *buf; dma_addr_t dma_addr; @@ -508,14 +509,18 @@ void mhi_fw_load_handler(struct mhi_controller *mhi_cntrl) } error_fw_load: - mhi_cntrl->pm_state = MHI_PM_FW_DL_ERR; - wake_up_all(&mhi_cntrl->state_event); + write_lock_irq(&mhi_cntrl->pm_lock); + new_state = mhi_tryset_pm_state(mhi_cntrl, MHI_PM_FW_DL_ERR); + write_unlock_irq(&mhi_cntrl->pm_lock); + if (new_state == MHI_PM_FW_DL_ERR) + wake_up_all(&mhi_cntrl->state_event); } int mhi_download_amss_image(struct mhi_controller *mhi_cntrl) { struct image_info *image_info = mhi_cntrl->fbc_image; struct device *dev = &mhi_cntrl->mhi_dev->dev; + enum mhi_pm_state new_state; int ret; if (!image_info) @@ -526,8 +531,11 @@ int mhi_download_amss_image(struct mhi_controller *mhi_cntrl) &image_info->mhi_buf[image_info->entries - 1]); if (ret) { dev_err(dev, "MHI did not load AMSS, ret:%d\n", ret); - mhi_cntrl->pm_state = MHI_PM_FW_DL_ERR; - wake_up_all(&mhi_cntrl->state_event); + write_lock_irq(&mhi_cntrl->pm_lock); + new_state = mhi_tryset_pm_state(mhi_cntrl, MHI_PM_FW_DL_ERR); + write_unlock_irq(&mhi_cntrl->pm_lock); + if (new_state == MHI_PM_FW_DL_ERR) + wake_up_all(&mhi_cntrl->state_event); } return ret;