diff --git a/src/tagstudio/core/library/alchemy/visitors.py b/src/tagstudio/core/library/alchemy/visitors.py
index 31a10d76..b3d173e3 100644
--- a/src/tagstudio/core/library/alchemy/visitors.py
+++ b/src/tagstudio/core/library/alchemy/visitors.py
@@ -147,7 +147,7 @@ class SQLBoolExpressionBuilder(BaseVisitor[ColumnElement[bool]]):
         # raise exception if Constraint stays unhandled
         raise NotImplementedError("This type of constraint is not implemented yet")
 
-    def visit_property(self, node: Property) -> None:
+    def visit_property(self, node: Property) -> ColumnElement[bool]:
         raise NotImplementedError("This should never be reached!")
 
     def visit_not(self, node: Not) -> ColumnElement[bool]:
diff --git a/src/tagstudio/core/query_lang/ast.py b/src/tagstudio/core/query_lang/ast.py
index 9ebab448..102203ed 100644
--- a/src/tagstudio/core/query_lang/ast.py
+++ b/src/tagstudio/core/query_lang/ast.py
@@ -1,6 +1,6 @@
 from abc import ABC, abstractmethod
 from enum import Enum
-from typing import Generic, TypeVar
+from typing import Generic, TypeVar, Union
 
 
 class ConstraintType(Enum):
@@ -12,7 +12,7 @@ class ConstraintType(Enum):
     Special = 5
 
     @staticmethod
-    def from_string(text: str) -> "ConstraintType":
+    def from_string(text: str) -> Union["ConstraintType", None]:
         return {
             "tag": ConstraintType.Tag,
             "tag_id": ConstraintType.TagID,
@@ -24,7 +24,7 @@
 
 
 class AST:
-    parent: "AST" = None
+    parent: Union["AST", None] = None
 
     def __str__(self):
         class_name = self.__class__.__name__
diff --git a/src/tagstudio/core/query_lang/tokenizer.py b/src/tagstudio/core/query_lang/tokenizer.py
index 07c40e7d..4970a5fe 100644
--- a/src/tagstudio/core/query_lang/tokenizer.py
+++ b/src/tagstudio/core/query_lang/tokenizer.py
@@ -26,19 +26,19 @@ class Token:
     start: int
     end: int
 
-    def __init__(self, type: TokenType, value: Any, start: int = None, end: int = None) -> None:
+    def __init__(self, type: TokenType, value: Any, start: int, end: int) -> None:
         self.type = type
         self.value = value
         self.start = start
         self.end = end
 
     @staticmethod
-    def from_type(type: TokenType, pos: int = None) -> "Token":
+    def from_type(type: TokenType, pos: int) -> "Token":
         return Token(type, None, pos, pos)
 
     @staticmethod
-    def EOF() -> "Token":  # noqa: N802
-        return Token.from_type(TokenType.EOF)
+    def EOF(pos: int) -> "Token":  # noqa: N802
+        return Token.from_type(TokenType.EOF, pos)
 
     def __str__(self) -> str:
         return f"Token({self.type}, {self.value}, {self.start}, {self.end})"  # pragma: nocover
@@ -50,7 +50,7 @@ class Token:
 class Tokenizer:
     text: str
     pos: int
-    current_char: str
+    current_char: str | None
 
     ESCAPABLE_CHARS = ["\\", '"', "'"]
     NOT_IN_ULITERAL = [":", " ", "[", "]", "(", ")", "=", ","]
@@ -63,7 +63,7 @@
     def get_next_token(self) -> Token:
         self.__skip_whitespace()
         if self.current_char is None:
-            return Token.EOF()
+            return Token.EOF(self.pos)
 
         if self.current_char in ("'", '"'):
             return self.__quoted_string()
@@ -119,6 +119,8 @@
         out = ""
 
         while escape or self.current_char != quote:
+            if self.current_char is None:
+                raise ParsingError(start, self.pos, "Unterminated quoted string")
             if escape:
                 escape = False
                 if self.current_char not in Tokenizer.ESCAPABLE_CHARS: