triggon ソースコード(日本語版)

自作アプリ/ツール紹介

投稿日:

自作Pythonライブラリ「triggon」の日本語版ソースコードです。(ver. 0.1.0)

triggon.py

from types import FrameType
from typing import Any

from ._exceptions import (
    SYMBOL,
    _ExitEarly,
    InvalidArgumentError, 
    _check_label_type,
    _compare_value_counts,
    _count_symbol,
    _handle_arg_types,
)
from . import _debug
from . import _var_analysis
from . import _var_update
from .trig_func import TrigFunc


class Triggon:
    debug: bool
    _debug_var: dict[str, tuple[int, str] | list[tuple[int, str]]]
    _trigger_flag: dict[str, bool]
    _new_value: dict[str, tuple[Any]]
    _org_value: dict[str, tuple[Any]]
    _var_list: dict[str, tuple[str]]
    _disable_label: dict[str, bool]
    _id_list: dict[str, int | tuple[int]]
    _return_value: tuple[bool, Any] | None = None
    _lineno: int = None
    _frame: FrameType = None

    def __init__(
        self, label: str | dict[str, Any], /, new: Any=None, 
        *, debug: bool=False,
    ) -> None:
      """
      ラベルとその値を登録します。

      値は配置位置に基づいてインデックスが割り振られます。
      配列を1つの値として設定したい場合は、さらに別の配列に入れてください。

      対応形式:
      - label, value
      - label, [value]
      - label, (value,)
      - {label: value}
      - {label: [value]}
      - {label: (value,)}
      """

      self.debug = debug
      self._trigger_flag = {}
      self._new_value = {}
      self._org_value = {}
      self._var_list = {}  
      self._disable_label = {}   
      self._id_list = {}

      change_list = _handle_arg_types(label, new)
      self._scan_dict(change_list)

    def _scan_dict(self, arg_dict: dict[str, Any]) -> None:      
      for key, value in arg_dict.items():          
          label = key.lstrip(SYMBOL)
          index = _count_symbol(key)

          if index != 0:
              raise InvalidArgumentError(
                  f" `{key}`の先頭にある '*' をすべて取り除いてください。 " 
                  "インデックスを指定するために, "
                  "リスト/タプル内に任意のインデックスの位置に値を配置してください。"
              )

          try:
            self._new_value[label]

            raise InvalidArgumentError(
              f"`{label}`はすでに登録済みです" 
              "この関数では重複したラベルは登録できません。"
            ) 
          except KeyError:
            self._add_new_label(label, value)

    def _add_new_label(self, label: str, value: Any, /) -> None:
      if isinstance(value, (list, tuple)):
        length = len(value)
      else:
        length = 1

      if isinstance(value, list) and length > 1: 
        self._new_value[label] = tuple(value)
      elif isinstance(value, tuple) and length > 1:
        self._new_value[label] = value
      elif isinstance(value, list) and length == 1:
        if isinstance(value[0], (list, tuple)):
          self._new_value[label] = value
        else:
          self._new_value[label] = (value[0],)
      elif isinstance(value, tuple) and length == 1:
        if isinstance(value[0], (list, tuple)):
          self._new_value[label] = value 
        else:
          self._new_value[label] = value
      else:
        # 配列ではない単一の値
        self._new_value[label] = (value,)
        
      self._trigger_flag[label] = False
      self._disable_label[label] = False

      # 送られた値のインデックスの数だけNoneを設定する
      self._org_value[label] = [None] * length
      self._var_list[label] = [None] * length
      self._id_list[label] = [None] * length

    def set_trigger(
        self, label: str | list[str] | tuple[str, ...], /,
    ) -> None:
      """
      引数のラベルのフラグをTrueに設定します。

      `alter_var()`によって変数が登録されてる場合, 
      この関数内で値が更新されます。
      """

      if (
        isinstance(label, (list, tuple)) 
        and all(isinstance(val, str) for val in label)
      ):
        for v in label:
          self._check_label_flag(v)
      elif isinstance(label, str):
        self._check_label_flag(label)       
      else:
        raise InvalidArgumentError(
          "'ラベルは単一の文字列、"
          "またはリスト/タプル内に文字列のみ入れてください。"
        )

    def alter_literal(
        self, label: str, /, org: Any, *, index: int=None,
    ) -> Any:
      """
      引数のラベルのフラグがTrueの場合、その値を変更します。

      変数以外の式やリテラルのみ対応しています。
      """

      _check_label_type(label)
      
      name = label.lstrip(SYMBOL)
      self._check_exist_label(name)

      if index is None:
        index = _count_symbol(label)
      _compare_value_counts(self._new_value[name], index)

      self._org_value[name][index] = org
      flag = self._trigger_flag[name]

      if self.debug:
        self._get_target_frame("alter_literal")
        self._print_val_debug(name, index, flag, org)

      if not flag:
        return self._org_value[name][index]

      return self._new_value[name][index]  

    def alter_var(
          self, label: str | dict[str, Any], var: Any=None, /, 
          *, index: int=None,
    ) -> None:
        """
        引数のラベルのフラグがTrueの場合、
        その変数の値を変更します。

        変数のみ対応しています。(式やリテラル以外)

        対応変数の種類:
        - グローバル変数
        - クラス変数
        """

        (change_list, arg_type) = _handle_arg_types(label, var, index, True)
        init_flag = False

        if len(change_list) == 1:
          # 単一のラベルの場合
          label = next(iter(change_list))
          _check_label_type(label)

          name = label.lstrip(SYMBOL)
          self._check_exist_label(name)

          if index is None:
            index = _count_symbol(label)
          _compare_value_counts(self._new_value[name], index)

          if (
            self._var_list[name][index] is None 
            or self._is_new_var(name, index, var)
          ):
            self._store_org_value(name, index, change_list[label])
 
            self._get_target_frame("alter_var")
            self._lineno = self._frame.f_lineno     

            # 変数保存の初回処理
            self._init_arg_list(change_list, arg_type, index)
            init_flag = True

          if not init_flag:
            return

          trig_flag = self._trigger_flag[name]
          vars = self._var_list[name][index]

          if not trig_flag:
            if self.debug:
              self._get_target_frame("alter_var")

              self._print_var_debug(
                vars, name, index, trig_flag, change_list[label],
              )   
            self._drop_debug_info()

            return
          
          self._update_var_value(vars, self._new_value[name][index])  

          if self.debug:
            self._print_var_debug(
              vars, name, index, trig_flag, change_list[label], 
              self._new_value[name][index], change=True,
            )
        else:
           # 複数のラベルの場合(辞書)
          if index is not None:
            raise InvalidArgumentError(
              "Cannot use the `index` keyword with a dictionary. " 
              "Use '*' in the label instead." 
            )
          
          for key, val in change_list.items():
            _check_label_type(key)
            
            name = key.lstrip(SYMBOL)
            index = _count_symbol(key)

            if self._org_value[name][index] is None:
              self._store_org_value(name, index, val)

            if (
              not init_flag
              and (self._var_list[name][index] is None 
              or self._is_new_var(name, index, val))
            ):    
              self._get_target_frame("alter_var")
              self._lineno = self._frame.f_lineno

               # 変数保存の初回処理
              self._init_arg_list(change_list, arg_type)
              init_flag = True
            
            if not init_flag:
              continue

            trig_flag = self._trigger_flag[name]
            vars = self._var_list[name][index]  

            if not trig_flag:
              if self.debug:
                self._get_target_frame("alter_var")
                self._print_var_debug(vars, name, index, trig_flag, val)

              continue          

            self._update_var_value(vars, self._new_value[name][index])

            if self.debug:
              self._get_target_frame("alter_var")
              self._print_var_debug(
                vars, name, index, trig_flag, val, 
                self._new_value[name][index], change=True,
              )
            
        self._drop_debug_info()

    def revert(self, label: str, /, *, disable: bool=False) -> None:
      """
      引数のラベルのフラグをFalseに設定します。
      'disable'がTrueに設定された場合、永続的にフラグをFalseにします。
      """
      
      name = label.lstrip(SYMBOL)
      self._check_exist_label(name)

      if not self._trigger_flag[name]:
        return

      if disable:
        state = "disable" # デバッグ用
        self._disable_label[name] = True
      else:
        state = "inactive" # デバッグ用
      self._trigger_flag[name] = False

      self._label_has_var(name, "revert", True)

      if self.debug:
        self._get_target_frame("revert")
        self._print_flag_debug(name, state)
    
    def exit_point(self, label: str, func: TrigFunc, /) -> None | Any:
      """
      引数と同じラベルの`trigger_return()`によって実行された早期リターンは、
      `func`に渡された関数のところまで処理が戻ります。 
    
      """

      name = label.lstrip(SYMBOL)
      self._check_exist_label(name)

      try:
          return func()
      except _ExitEarly:
          if not self._return_value[0]:
              return self._return_value[1]
          
          print(self._return_value[1])

    def trigger_return(
        self, label: str, /, *, index: int=None, do_print: bool=False,
    ) -> None | Any:
        """
        引数のラベルのフラグがTrueの場合、
        早期リターンを実行と共に設定された値を返します。

        `do_print`がTrueの場合、早期リターンで設定された値を出力します。
        値が文字列でない場合、`InvalidArgumentError`が発生します。
        """

        name = label.lstrip(SYMBOL)
        self._check_exist_label(name)

        if index is None:
           index = _count_symbol(label)

        if not self._trigger_flag[name]:
            return 
            
        if do_print:
            if not isinstance(self._new_value[name][index], str):
              raise InvalidArgumentError(
                 "値は文字列である必要がありますが、"
                 f"`{type(self._new_value[name][index]).__name__}`が渡されました。"
              )         
        self._return_value = (do_print, self._new_value[name][index])

        if self.debug:
          self._get_target_frame("trigger_return")
          self._print_trig_debug(name, "Return")

        raise _ExitEarly 
        
    def trigger_func(self, label: str, func: TrigFunc, /) -> None | Any:
        """
        引数のラベルのフラグがTrueの場合、`func`に渡された関数を実行します。
        """

        name = label.lstrip(SYMBOL)
        self._check_exist_label(name)

        if self._trigger_flag[name]:
            if self.debug:
              self._get_target_frame("trigger_func")
              self._print_trig_debug(name, "Trigger a function") 
                    
            return func()


for name, func in vars(_var_analysis).items():
    if callable(func):
        setattr(Triggon, name, func)


for name, func in vars(_debug).items():
    if callable(func):
        setattr(Triggon, name, func)


for name, func in vars(_var_update).items():
    if callable(func):
        setattr(Triggon, name, func)

tirg_func.py

import inspect
from typing import Any, Callable


class TrigFunc:
  """
  `trigger_return()` と `trigger_func()`を使う際に、
   引数に入れる関数の実行を遅延させます。

  対象関数を包んで引数に渡してください。

  必ずクラスインスタンス変数を作成してから使ってください。
  (例: F = TrigFunc()) 
  """

  _func: Callable | None

  def __init__(self, func: Callable=None) -> None:
    self._func = func

  def __call__(self, *args, **kwargs) -> "TrigFunc":
    if self._func is None:
      raise ValueError("`func` is None")
    
    return TrigFunc(self._func(*args, **kwargs))
    
  def __getattr__(self, name: str) -> Callable[[], Any]:    
    if self._func is not None:
      target = getattr(self._func, name)
    else:
      frame = inspect.currentframe().f_back
      target = frame.f_locals.get(name) or frame.f_globals.get(name)

    if target is None:
        raise AttributeError(f"'{name}' is not a callable function")
    elif not callable(target):
        return TrigFunc(target)
      
    def _wrapper(*args, **kwargs) -> Callable[[], Any]:
      return lambda: target(*args, **kwargs)
    return _wrapper

_var_update.py

import inspect
from typing import Any

from ._exceptions import SYMBOL


def _store_org_value(
    self, label: str, index: int, org_value: Any,
) -> None:
    if isinstance(org_value, (list, tuple)):
      self._org_value[label][index] = []  

      for v in org_value:
          self._org_value[label][index].append(v)
    else:
      self._org_value[label][index] = org_value

def _update_var_value(
    self, 
    var_ref: tuple[Any] | list[tuple[Any]], update_value: Any,
) -> None:
    # 変数情報:
    # - グローバル変数 -> (行番号, 変数名)
    # - クラス変数 -> (行番号, 属性名, クラスインスタンス)
    # またはリストにこれらが複数入っている
    if isinstance(var_ref, list):
        # 複数の値の場合は、リストに入れられている
        for value in var_ref:
            if len(value) == 3:
                setattr(value[2], value[1], update_value)
            else:
                self._frame.f_globals[value[1]] = update_value    
    else:     
        if len(var_ref) == 3:
            setattr(var_ref[2], var_ref[1], update_value)
        else:  
            self._frame.f_globals[var_ref[1]] = update_value

def _check_exist_label(self, label: str) -> None:
    try:
        self._new_value[label]
    except KeyError:
        raise KeyError(f"`{label}` has not been set.")

def _check_label_flag(self, label: str) -> None:
    name = label.lstrip(SYMBOL)
    self._check_exist_label(name)

    if self._disable_label[name] or self._trigger_flag[name]:
        return
    
    self._trigger_flag[name] = True
    self._label_has_var(name, "set_trigger", False)

    if self.debug:
        self._get_target_frame("set_trigger")
        self._print_flag_debug(name, "active", reset=False)   

def _label_has_var(
    self, label: str, called_func: str, to_org: bool,
) -> None:
    if self._var_list[label] is None:
        return
    
    if to_org:
        update_value = self._org_value[label]
    else:
        update_value = self._new_value[label]

    self._get_target_frame(called_func)

    # 特定ラベルの全ての登録されてる変数を、
    # 元の値または、設定された値に更新する
    for i in range(len(self._var_list[label])):
        arg = self._var_list[label][i]

        if arg is None:
            continue
        elif isinstance(arg, list):
            for i_2, v in enumerate(arg):
                if not isinstance(update_value[i], tuple):
                    self._update_var_value(v, update_value[i])
                    continue                              
                self._update_var_value(v, update_value[i][i_2])
        else:
            self._update_var_value(arg, update_value[i])     

def _is_new_var(self, label: str, index: int, value: Any) -> bool:
    if self._trigger_flag[label]:
        if self._new_value[label][index] == value:
            return False
    
        return True
    
    if self._org_value[label][index] == value:
        return False
    
    return True

def _get_target_frame(self, target_name: str) -> None:
   if self._frame is not None:
      return
   
   frame = inspect.currentframe()

   while frame:
      if frame.f_code.co_name == target_name:
         self._frame = frame.f_back
         return
      frame = frame.f_back

_var_analysis.py

import ast
import linecache
from itertools import count

from ._exceptions import (
    SYMBOL, 
    InvalidArgumentError, 
    _compare_value_counts,
    _count_symbol, 
)


LABEL_ERROR = "ラベルは文字列で渡してください。"
VALUE_TYPE_ERROR = "`value`には変数を入れてください。"
NEST_ERROR = "この関数では配列内でネストすることは出来ません。"
VAR_ERROR = "ローカル変数は対応していません。"


def _init_arg_list(
      self, change_list, arg_type: ast.AST, index: int=None,
) -> None:
    if index is None:
       has_index = False
    else:
       has_index = True

    # 変数を照合するためのIDを保存
    for key, val in change_list.items():   
        name = key.lstrip(SYMBOL)
        self._check_exist_label(name)

        if not has_index:
          index = _count_symbol(key)
        _compare_value_counts(self._new_value[name], index)

        if isinstance(val, (list, tuple)):
            self._id_list[name][index] = []

            for v in val:
               if isinstance(v, (list, tuple, dict)):
                  raise InvalidArgumentError(NEST_ERROR)                       
               self._id_list[name][index].append(id(v))     

            continue
        elif isinstance(val, dict):
           raise InvalidArgumentError(VALUE_TYPE_ERROR)

        self._id_list[name][index] = id(val)
 
    file_name = self._frame.f_code.co_filename 
    self._trace_func_call(file_name, arg_type)

def _trace_func_call(self, file_name: str, arg_type: ast.AST) -> None:
    lines = []
    
    for i in count(self._lineno):
      line = linecache.getline(file_name, i)
      if not line:
        break
      lines.append(line.lstrip())

      try:
        line_range = ast.parse("".join(lines))

        # 呼び出されたの関数を見つけるためにASTノードを巡回
        
        for node in ast.walk(line_range):
          if not isinstance(node, ast.Call):
            continue    

          if (
             not isinstance(node.func, ast.Attribute) 
             or node.func.attr != "alter_var"
            ): 
            continue

          first_arg = node.args[0]

          # 引数の型によって処理が分岐

          if arg_type == ast.Dict and isinstance(first_arg, ast.Dict):
            result = self._arg_is_dict(first_arg)
          else:
             second_arg = node.args[1]
             _identify_arg(second_arg)

             if isinstance(first_arg, ast.Name):
                try:
                  label = self._frame.f_locals[first_arg.id]
                except KeyError:
                   label = self._frame.f_globals[first_arg.id]
             elif isinstance(first_arg, ast.Attribute):
                instance = self._frame.f_locals[first_arg.value.id]
                field = instance.__dict__[first_arg.attr]

                label = field
             elif isinstance(first_arg, ast.Constant):
                label = first_arg.value
             else:
                linecache.clearcache()
                break

             name = label.lstrip(SYMBOL)       

             # `index`キーワードが設定されてるかの確認
             if 0 < len(node.keywords) <= 3:
                index = None

                for kw in node.keywords:
                   if kw.arg != "index":
                      continue
                   
                   index_node = kw.value

                   if isinstance(index_node, ast.Constant):
                      index = index_node.value
                   else:
                      # 現在は`index`にはリテラル値のみ使えますが、
                      # 将来的に変数にも対応するかもしれません。
                      raise InvalidArgumentError(
                         "`index` キーワードにはリテラル値を入れてください。"   
                      )                    

                if index is None:
                   index = _count_symbol(name)
             else:
                index = _count_symbol(name)
          
             if (
                arg_type == ast.List 
                and isinstance(second_arg, (ast.List, ast.Tuple))
             ):
                result = self._arg_is_seq(second_arg, name, index)
             elif arg_type == ast.Name and isinstance(second_arg, ast.Name):
                result = self._arg_is_name(second_arg.id, name, index)
             elif (
                arg_type == ast.Name 
                and isinstance(second_arg, ast.Attribute)
             ):
                result = self._arg_is_attr(second_arg, name, index)
             else:
                linecache.clearcache()
                break

          # resultが1の場合は目的の関数ではない
          if result == 1:
            linecache.clearcache()
            break     
          return

      # 関数が一行で終わっていない場合、このエラーは無視。
      # (その関数の終わりの行を探すため)
      except SyntaxError:
        continue

    raise RuntimeError(
       "ソースコード内の`alter_var' が見つかりませんでした。"
    )   
  
def _arg_is_dict(self, target: ast.Dict) -> int:
    arg_list = _deduplicate_labels(target)
  
    for key, val in arg_list.items():
      label = key.lstrip(SYMBOL)

      if self._id_list.get(label) is None:
        return 1
      index = _count_symbol(key)

      # 引数の型によって処理が分岐

      if isinstance(val, (ast.List, ast.Tuple)):
        result = self._arg_is_seq(val, label, index)
        if result == 1:
          return 1
      elif isinstance(val, ast.Dict):
         raise InvalidArgumentError(VALUE_TYPE_ERROR)
      else:
        _identify_arg(val)

        if isinstance(val, ast.Name):
          result = self._arg_is_name(val.id, label, index)
        else:
          result = self._arg_is_attr(val, label, index)

        if result == 1:
          return 1

    return 0

def _arg_is_seq(
      self, target: ast.List | ast.Tuple, label: str, index: int,
) -> tuple[int, int]:
    target_index = self._var_list[label][index] 

    if target_index is None:
      self._var_list[label][index] = []
    elif not isinstance(target_index, list):
       # 値を追加するためリストに変換
       self._var_list[label][index] = [target_index]

    for i, val in enumerate(target.elts):
      _identify_arg(val)

      if isinstance(val, ast.Name):
        _ = self._arg_is_name(val.id, label, index, i)
      elif isinstance(val, ast.Attribute):
        result = self._arg_is_attr(val, label, index, i)  
        if result == 1:
          return 1
      else:
        raise InvalidArgumentError(NEST_ERROR)

    return 0

def _arg_is_name(
      self, target: str, label: str, index: int, inner_index: int=None,
) -> int:
    self._check_exist_label(label)
    _compare_value_counts(self._new_value[label], index)

    target_id = self._get_list_id(label, index, inner_index)
    if target_id is None:
       return 1

    try:
       self._frame.f_globals[target]
    except KeyError:
       raise InvalidArgumentError(VAR_ERROR)

    target_index = self._var_list[label][index]
    # (行番号, 変数名)
    if isinstance(target_index, list):
      self._var_list[label][index].append((self._lineno, target))
    elif target_index is not None:
       self._var_list[label][index] = [target_index]
       self._var_list[label][index].append((self._lineno, target))
    else:
       self._var_list[label][index] = (self._lineno, target)

    return 0

def _arg_is_attr(
      self, target: ast.Attribute, label: str, 
      index: int, inner_index: int=None,
) -> int:
    self._check_exist_label(label)
    _compare_value_counts(self._new_value[label], index)

    target_id = self._get_list_id(label, index, inner_index)
    if target_id is None:
       return 1

    instance = self._frame.f_locals[target.value.id]
    field = instance.__dict__[target.attr]

    if id(field) != target_id:
      return 1
    
    target_index = self._var_list[label][index]

    # (行番号, 属性名, クラスインスタンス)
    if isinstance(target_index, list):
      self._var_list[label][index].append(
         (self._lineno, target.attr, instance)
      )
    elif target_index is not None:
       self._var_list[label][index] = [target_index]
       self._var_list[label][index].append(
          (self._lineno, target.attr, instance)
       )
    else:
       self._var_list[label][index] = (self._lineno, target.attr, instance)

    return 0
      
def _get_list_id(
      self, label: str, index: int, inner_index: int=None,
) -> int | None:
    target_id = self._id_list[label][index]

    if isinstance(target_id, list):
       if inner_index is not None and len(target_id) > inner_index:
         return target_id[inner_index]
       return None
    elif inner_index is not None:
      return None
    else:
      return target_id
         
   
def _deduplicate_labels(target: ast.Dict) -> dict[str, ast.AST]:
    # 重複したラベルを排除する
    sorted_list = {}

    for key, val in zip(target.keys, target.values):
      sorted_list[key.value] = val

    return sorted_list


def _identify_arg(target: ast.AST) -> None:
    if isinstance(target, ast.Constant):
      raise InvalidArgumentError(VALUE_TYPE_ERROR)

現在ベータ版のため、コードが変更される可能性があります!


オリジナルソースコード(GitHub) → tsuruko12/triggon: Automatically switches values at labeled trigger points, supporting multi-value switching, an early return, and a function call.

コメント

タイトルとURLをコピーしました