add Video class

This commit is contained in:
Yi Ge 2019-04-24 21:18:31 +02:00
parent 57d1dc7b9b
commit 63873af476
2 changed files with 62 additions and 18 deletions

View File

@@ -3,8 +3,8 @@ from typing import List
from dataclasses import dataclass from dataclasses import dataclass
CONFIDENCE_THRESHOLD = 60 CONF_THRESHOLD = 60
# predictions with lower confidence will be filtered out # word predictions with lower confidence will be filtered out
@dataclass @dataclass
@@ -15,33 +15,33 @@ class PredictedWord:
class PredictedFrame: class PredictedFrame:
index: int # 0-based index of the frame
words: List[PredictedWord] words: List[PredictedWord]
confidence: int # total confidence of all words
text: str
def __init__(self, pred_data: str): def __init__(self, index, pred_data: str):
self.index = index
self.words = [] self.words = []
block_current = 1 block = 0 # keep track of line breaks
for line in pred_data.split('\n')[1:]:
tmp = line.split() for l in pred_data.splitlines()[1:]:
if len(tmp) < 12: word_data = l.split()
if len(word_data) < 12:
# no word is predicted # no word is predicted
continue continue
_, _, block_num, *_, conf, text = tmp _, _, block_num, *_, conf, text = word_data
block_num, conf = int(block_num), int(conf) block_num, conf = int(block_num), int(conf)
# handle line breaks # handle line breaks
if block_current < block_num: if block < block_num:
block_current = block_num block = block_num
self.words.append(PredictedWord(0, '\n')) self.words.append(PredictedWord(0, '\n'))
if conf >= CONFIDENCE_THRESHOLD: if conf >= CONF_THRESHOLD:
self.words.append(PredictedWord(conf, text)) self.words.append(PredictedWord(conf, text))
@property self.confidence = sum(word.confidence for word in self.words)
def confidence(self) -> int: self.text = ''.join(word.text + ' ' for word in self.words).strip()
return sum(word.confidence for word in self.words)
@property
def text(self) -> str:
return ''.join(word.text + ' ' for word in self.words)

44
videocr/video.py Normal file
View File

@@ -0,0 +1,44 @@
from __future__ import annotations
from concurrent import futures
import pytesseract
import cv2
import timeit
from .models import PredictedFrame
class Video:
    """A video file to run OCR on, frame by frame.

    Uses OpenCV to decode frames and tesseract (via pytesseract) to
    extract text data from each frame in parallel worker processes.
    """

    path: str        # path to the video file on disk
    lang: str        # tesseract language code passed to pytesseract
    num_frames: int  # total frame count reported by OpenCV

    def __init__(self, path: str, lang: str):
        self.path = path
        self.lang = lang
        v = cv2.VideoCapture(path)
        try:
            self.num_frames = int(v.get(cv2.CAP_PROP_FRAME_COUNT))
        finally:
            # release the capture handle even if the property read raises
            v.release()

    def _frame_ocr(self, img) -> str:
        # OCR a single frame; returns tesseract's TSV prediction data,
        # which PredictedFrame knows how to parse.
        return pytesseract.image_to_data(img, lang=self.lang)

    def run_ocr(self) -> None:
        """OCR every frame of the video and print the predicted text.

        Frames are decoded sequentially (cv2.VideoCapture is not safe to
        share across processes) but OCR'd in parallel, since tesseract
        is the CPU-bound step.
        """
        v = cv2.VideoCapture(self.path)
        try:
            # Fix: iterate over the whole video instead of the
            # hard-coded first 40 frames (debug leftover).
            frames = (v.read()[1] for _ in range(self.num_frames))
            with futures.ProcessPoolExecutor() as pool:
                frames_ocr = pool.map(self._frame_ocr, frames, chunksize=1)
                for i, data in enumerate(frames_ocr):
                    pred = PredictedFrame(i, data)
                    print(pred.text)
        finally:
            v.release()
# Benchmark entry point: OCR a sample video and report wall-clock time.
# The __main__ guard is required, not optional: run_ocr uses
# ProcessPoolExecutor, and under the 'spawn' start method (default on
# Windows/macOS) worker processes re-import this module — without the
# guard they would recursively re-run the benchmark and crash.
if __name__ == '__main__':
    time_start = timeit.default_timer()
    v = Video('1.mp4', 'HanS')
    v.run_ocr()
    time_stop = timeit.default_timer()
    print(time_stop - time_start)