diff --git a/videocr/models.py b/videocr/models.py
index 3ce8af6..3ad7afa 100644
--- a/videocr/models.py
+++ b/videocr/models.py
@@ -1,6 +1,7 @@
 from __future__ import annotations
 from typing import List
 from dataclasses import dataclass
+from fuzzywuzzy import fuzz
 
 
 CONF_THRESHOLD = 60
@@ -45,3 +46,46 @@ class PredictedFrame:
         self.confidence = sum(word.confidence for word in self.words)
         self.text = ''.join(word.text + ' ' for word in self.words).strip()
 
+    def is_similar_to(self, other: PredictedFrame, threshold=60) -> bool:
+        """Return True if both frames have text and it fuzzy-matches.
+
+        A frame with empty text is never similar to anything. Otherwise the
+        result is whether ``fuzz.ratio`` of the two texts is >= ``threshold``.
+        """
+        if not self.text or not other.text:
+            return False
+        return fuzz.ratio(self.text, other.text) >= threshold
+
+
+class PredictedSubtitle:
+    """Wrap a list of predicted frames; expose one text and an index range.
+
+    Frames whose confidence is not greater than 0 are discarded on creation.
+    """
+
+    frames: List[PredictedFrame]
+
+    def __init__(self, frames: List[PredictedFrame]):
+        self.frames = [f for f in frames if f.confidence > 0]
+
+    @property
+    def text(self) -> str:
+        """Text of the most confident frame, or '' if there are no frames."""
+        if not self.frames:
+            return ''
+        # max() returns the first frame that reaches the top confidence.
+        return max(self.frames, key=lambda f: f.confidence).text
+
+    @property
+    def index_start(self) -> int:
+        """Index of the first kept frame, or 0 if there are no frames."""
+        if self.frames:
+            return self.frames[0].index
+        return 0
+
+    @property
+    def index_end(self) -> int:
+        """Index of the last kept frame, or 0 if there are no frames."""
+        if self.frames:
+            return self.frames[-1].index
+        return 0