Input: serio - use guard notation when acquiring mutexes and spinlocks

Using guard notation makes the code more compact and error handling
more robust by ensuring that locks are released in all code paths
when control leaves critical section.

Link: https://lore.kernel.org/r/20240905041732.2034348-20-dmitry.torokhov@gmail.com
Signed-off-by: Dmitry Torokhov <dmitry.torokhov@gmail.com>
This commit is contained in:
Dmitry Torokhov 2024-09-04 21:17:24 -07:00
parent f7d15dcc24
commit 924c5eeb17

View File

@ -38,33 +38,27 @@ static void serio_attach_driver(struct serio_driver *drv);
static int serio_connect_driver(struct serio *serio, struct serio_driver *drv)
{
int retval;
guard(mutex)(&serio->drv_mutex);
mutex_lock(&serio->drv_mutex);
retval = drv->connect(serio, drv);
mutex_unlock(&serio->drv_mutex);
return retval;
return drv->connect(serio, drv);
}
static int serio_reconnect_driver(struct serio *serio)
{
int retval = -1;
guard(mutex)(&serio->drv_mutex);
mutex_lock(&serio->drv_mutex);
if (serio->drv && serio->drv->reconnect)
retval = serio->drv->reconnect(serio);
mutex_unlock(&serio->drv_mutex);
return serio->drv->reconnect(serio);
return retval;
return -1;
}
static void serio_disconnect_driver(struct serio *serio)
{
mutex_lock(&serio->drv_mutex);
guard(mutex)(&serio->drv_mutex);
if (serio->drv)
serio->drv->disconnect(serio);
mutex_unlock(&serio->drv_mutex);
}
static int serio_match_port(const struct serio_device_id *ids, struct serio *serio)
@ -147,9 +141,8 @@ static LIST_HEAD(serio_event_list);
static struct serio_event *serio_get_event(void)
{
struct serio_event *event = NULL;
unsigned long flags;
spin_lock_irqsave(&serio_event_lock, flags);
guard(spinlock_irqsave)(&serio_event_lock);
if (!list_empty(&serio_event_list)) {
event = list_first_entry(&serio_event_list,
@ -157,7 +150,6 @@ static struct serio_event *serio_get_event(void)
list_del_init(&event->node);
}
spin_unlock_irqrestore(&serio_event_lock, flags);
return event;
}
@ -171,9 +163,8 @@ static void serio_remove_duplicate_events(void *object,
enum serio_event_type type)
{
struct serio_event *e, *next;
unsigned long flags;
spin_lock_irqsave(&serio_event_lock, flags);
guard(spinlock_irqsave)(&serio_event_lock);
list_for_each_entry_safe(e, next, &serio_event_list, node) {
if (object == e->object) {
@ -189,15 +180,13 @@ static void serio_remove_duplicate_events(void *object,
serio_free_event(e);
}
}
spin_unlock_irqrestore(&serio_event_lock, flags);
}
static void serio_handle_event(struct work_struct *work)
{
struct serio_event *event;
mutex_lock(&serio_mutex);
guard(mutex)(&serio_mutex);
while ((event = serio_get_event())) {
@ -228,8 +217,6 @@ static void serio_handle_event(struct work_struct *work)
serio_remove_duplicate_events(event->object, event->type);
serio_free_event(event);
}
mutex_unlock(&serio_mutex);
}
static DECLARE_WORK(serio_event_work, serio_handle_event);
@ -237,11 +224,9 @@ static DECLARE_WORK(serio_event_work, serio_handle_event);
static int serio_queue_event(void *object, struct module *owner,
enum serio_event_type event_type)
{
unsigned long flags;
struct serio_event *event;
int retval = 0;
spin_lock_irqsave(&serio_event_lock, flags);
guard(spinlock_irqsave)(&serio_event_lock);
/*
* Scan event list for the other events for the same serio port,
@ -253,7 +238,7 @@ static int serio_queue_event(void *object, struct module *owner,
list_for_each_entry_reverse(event, &serio_event_list, node) {
if (event->object == object) {
if (event->type == event_type)
goto out;
return 0;
break;
}
}
@ -261,16 +246,14 @@ static int serio_queue_event(void *object, struct module *owner,
event = kmalloc(sizeof(*event), GFP_ATOMIC);
if (!event) {
pr_err("Not enough memory to queue event %d\n", event_type);
retval = -ENOMEM;
goto out;
return -ENOMEM;
}
if (!try_module_get(owner)) {
pr_warn("Can't get module reference, dropping event %d\n",
event_type);
kfree(event);
retval = -EINVAL;
goto out;
return -EINVAL;
}
event->type = event_type;
@ -280,9 +263,7 @@ static int serio_queue_event(void *object, struct module *owner,
list_add_tail(&event->node, &serio_event_list);
queue_work(system_long_wq, &serio_event_work);
out:
spin_unlock_irqrestore(&serio_event_lock, flags);
return retval;
return 0;
}
/*
@ -292,9 +273,8 @@ static int serio_queue_event(void *object, struct module *owner,
static void serio_remove_pending_events(void *object)
{
struct serio_event *event, *next;
unsigned long flags;
spin_lock_irqsave(&serio_event_lock, flags);
guard(spinlock_irqsave)(&serio_event_lock);
list_for_each_entry_safe(event, next, &serio_event_list, node) {
if (event->object == object) {
@ -302,8 +282,6 @@ static void serio_remove_pending_events(void *object)
serio_free_event(event);
}
}
spin_unlock_irqrestore(&serio_event_lock, flags);
}
/*
@ -315,23 +293,19 @@ static void serio_remove_pending_events(void *object)
static struct serio *serio_get_pending_child(struct serio *parent)
{
struct serio_event *event;
struct serio *serio, *child = NULL;
unsigned long flags;
struct serio *serio;
spin_lock_irqsave(&serio_event_lock, flags);
guard(spinlock_irqsave)(&serio_event_lock);
list_for_each_entry(event, &serio_event_list, node) {
if (event->type == SERIO_REGISTER_PORT) {
serio = event->object;
if (serio->parent == parent) {
child = serio;
break;
}
if (serio->parent == parent)
return serio;
}
}
spin_unlock_irqrestore(&serio_event_lock, flags);
return child;
return NULL;
}
/*
@ -382,29 +356,27 @@ static ssize_t drvctl_store(struct device *dev, struct device_attribute *attr, c
struct device_driver *drv;
int error;
error = mutex_lock_interruptible(&serio_mutex);
if (error)
return error;
if (!strncmp(buf, "none", count)) {
serio_disconnect_port(serio);
} else if (!strncmp(buf, "reconnect", count)) {
serio_reconnect_subtree(serio);
} else if (!strncmp(buf, "rescan", count)) {
serio_disconnect_port(serio);
serio_find_driver(serio);
serio_remove_duplicate_events(serio, SERIO_RESCAN_PORT);
} else if ((drv = driver_find(buf, &serio_bus)) != NULL) {
serio_disconnect_port(serio);
error = serio_bind_driver(serio, to_serio_driver(drv));
serio_remove_duplicate_events(serio, SERIO_RESCAN_PORT);
} else {
error = -EINVAL;
scoped_cond_guard(mutex_intr, return -EINTR, &serio_mutex) {
if (!strncmp(buf, "none", count)) {
serio_disconnect_port(serio);
} else if (!strncmp(buf, "reconnect", count)) {
serio_reconnect_subtree(serio);
} else if (!strncmp(buf, "rescan", count)) {
serio_disconnect_port(serio);
serio_find_driver(serio);
serio_remove_duplicate_events(serio, SERIO_RESCAN_PORT);
} else if ((drv = driver_find(buf, &serio_bus)) != NULL) {
serio_disconnect_port(serio);
error = serio_bind_driver(serio, to_serio_driver(drv));
serio_remove_duplicate_events(serio, SERIO_RESCAN_PORT);
if (error)
return error;
} else {
return -EINVAL;
}
}
mutex_unlock(&serio_mutex);
return error ? error : count;
return count;
}
static ssize_t serio_show_bind_mode(struct device *dev, struct device_attribute *attr, char *buf)
@ -526,9 +498,9 @@ static void serio_add_port(struct serio *serio)
int error;
if (parent) {
serio_pause_rx(parent);
guard(serio_pause_rx)(parent);
list_add_tail(&serio->child_node, &parent->children);
serio_continue_rx(parent);
}
list_add_tail(&serio->node, &serio_list);
@ -560,9 +532,9 @@ static void serio_destroy_port(struct serio *serio)
serio->stop(serio);
if (serio->parent) {
serio_pause_rx(serio->parent);
guard(serio_pause_rx)(serio->parent);
list_del_init(&serio->child_node);
serio_continue_rx(serio->parent);
serio->parent = NULL;
}
@ -701,10 +673,10 @@ EXPORT_SYMBOL(__serio_register_port);
*/
void serio_unregister_port(struct serio *serio)
{
mutex_lock(&serio_mutex);
guard(mutex)(&serio_mutex);
serio_disconnect_port(serio);
serio_destroy_port(serio);
mutex_unlock(&serio_mutex);
}
EXPORT_SYMBOL(serio_unregister_port);
@ -715,12 +687,12 @@ void serio_unregister_child_port(struct serio *serio)
{
struct serio *s, *next;
mutex_lock(&serio_mutex);
guard(mutex)(&serio_mutex);
list_for_each_entry_safe(s, next, &serio->children, child_node) {
serio_disconnect_port(s);
serio_destroy_port(s);
}
mutex_unlock(&serio_mutex);
}
EXPORT_SYMBOL(serio_unregister_child_port);
@ -784,10 +756,10 @@ static void serio_driver_remove(struct device *dev)
static void serio_cleanup(struct serio *serio)
{
mutex_lock(&serio->drv_mutex);
guard(mutex)(&serio->drv_mutex);
if (serio->drv && serio->drv->cleanup)
serio->drv->cleanup(serio);
mutex_unlock(&serio->drv_mutex);
}
static void serio_shutdown(struct device *dev)
@ -850,7 +822,7 @@ void serio_unregister_driver(struct serio_driver *drv)
{
struct serio *serio;
mutex_lock(&serio_mutex);
guard(mutex)(&serio_mutex);
drv->manual_bind = true; /* so serio_find_driver ignores it */
serio_remove_pending_events(drv);
@ -866,15 +838,14 @@ void serio_unregister_driver(struct serio_driver *drv)
}
driver_unregister(&drv->driver);
mutex_unlock(&serio_mutex);
}
EXPORT_SYMBOL(serio_unregister_driver);
static void serio_set_drv(struct serio *serio, struct serio_driver *drv)
{
serio_pause_rx(serio);
guard(serio_pause_rx)(serio);
serio->drv = drv;
serio_continue_rx(serio);
}
static int serio_bus_match(struct device *dev, const struct device_driver *drv)
@ -935,14 +906,14 @@ static int serio_resume(struct device *dev)
struct serio *serio = to_serio_port(dev);
int error = -ENOENT;
mutex_lock(&serio->drv_mutex);
if (serio->drv && serio->drv->fast_reconnect) {
error = serio->drv->fast_reconnect(serio);
if (error && error != -ENOENT)
dev_warn(dev, "fast reconnect failed with error %d\n",
error);
scoped_guard(mutex, &serio->drv_mutex) {
if (serio->drv && serio->drv->fast_reconnect) {
error = serio->drv->fast_reconnect(serio);
if (error && error != -ENOENT)
dev_warn(dev, "fast reconnect failed with error %d\n",
error);
}
}
mutex_unlock(&serio->drv_mutex);
if (error) {
/*
@ -989,21 +960,17 @@ EXPORT_SYMBOL(serio_close);
irqreturn_t serio_interrupt(struct serio *serio,
unsigned char data, unsigned int dfl)
{
unsigned long flags;
irqreturn_t ret = IRQ_NONE;
guard(spinlock_irqsave)(&serio->lock);
spin_lock_irqsave(&serio->lock, flags);
if (likely(serio->drv))
return serio->drv->interrupt(serio, data, dfl);
if (likely(serio->drv)) {
ret = serio->drv->interrupt(serio, data, dfl);
} else if (!dfl && device_is_registered(&serio->dev)) {
if (!dfl && device_is_registered(&serio->dev)) {
serio_rescan(serio);
ret = IRQ_HANDLED;
return IRQ_HANDLED;
}
spin_unlock_irqrestore(&serio->lock, flags);
return ret;
return IRQ_NONE;
}
EXPORT_SYMBOL(serio_interrupt);