add type hints

bentrevett 2021-07-13 22:47:54 +01:00
parent 0ba2261d29
commit 47bb3895bc


@@ -3,15 +3,16 @@ import re
 from SetSimilaritySearch import all_pairs
 import numpy as np
 import tqdm
+from typing import List, Set, Tuple
 
 
 class DocumentID:
-    def __init__(self, index, repo_name, file_name):
+    def __init__(self, index: int, repo_name: str, file_name: str):
         self.index = index
         self.repo_name = repo_name
         self.file_name = file_name
 
-    def __eq__(self, other):
+    def __eq__(self, other) -> bool:
         return (
             self.index == other.index
             and self.repo_name == other.repo_name
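
With the annotated __eq__, two DocumentID instances compare equal whenever all three fields match, which is what lets later stages collect duplicates of the same file together. A minimal sketch (the repo and file names below are made up):

    a = DocumentID(index=0, repo_name="someuser/somerepo", file_name="utils.py")
    b = DocumentID(index=0, repo_name="someuser/somerepo", file_name="utils.py")
    assert a == b  # all three fields match

    # note: DocumentID is stored in sets later (List[Set[DocumentID]]), so the
    # class presumably also defines a matching __hash__ in a part of the file
    # this diff does not show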
@@ -46,7 +47,7 @@ class DuplicateDetector:
             self.vocabulary[token] = token_id
         return token_id
 
-    def add_file(self, document_id: DocumentID, code: str, language=None) -> bool:
+    def add_file(self, document_id: DocumentID, code: str) -> bool:
         """Add a file to the documents to be viewed by the duplicate detector"""
         # split on non-alphanumeric characters
         tokens = re.split(r"\W+", code)
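
The tokenization in add_file is a plain re.split on runs of non-word characters. A quick sketch of what that produces (the input snippet is illustrative):

    import re

    code = "def add(x, y):\n    return x + y"
    tokens = re.split(r"\W+", code)
    # -> ['def', 'add', 'x', 'y', 'return', 'x', 'y']
    # inputs that begin or end on a non-word character also yield empty
    # strings, which presumably get filtered before token IDs are assigned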
@@ -61,7 +62,7 @@ class DuplicateDetector:
         self.document_elements.append(id_counter)
         return True
 
-    def get_duplicate_pairs(self):
+    def get_duplicate_pairs(self) -> Tuple[int, int]:
         """Find the pairs of documents that are duplicates."""
         # get all similar pairs of documents
         similar_pairs = all_pairs(
@@ -88,7 +89,7 @@ class DuplicateDetector:
         )
         return float(intersection_size) / union_size
 
-    def get_duplicate_clusters(self):
+    def get_duplicate_clusters(self) -> List[Set[DocumentID]]:
         """Compute the duplicates in the indexed documents."""
 
         # stores duplicate clusters, list of set of DocumentID
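
The float(intersection_size) / union_size return above is a Jaccard similarity (shared tokens over total distinct tokens), and get_duplicate_pairs hands the per-document token-ID sets to all_pairs from SetSimilaritySearch. A standalone sketch, with illustrative sets and threshold (keyword names as in the SetSimilaritySearch documentation):

    from SetSimilaritySearch import all_pairs

    # each indexed document is a set of token IDs, as built up in add_file
    documents = [{0, 1, 2, 3}, {0, 1, 2, 4}, {7, 8, 9}]

    # yields (x, y, similarity) for every pair of sets whose Jaccard
    # similarity meets the threshold; 0.6 here is an arbitrary choice
    pairs = list(all_pairs(documents, similarity_func_name="jaccard",
                           similarity_threshold=0.6))
    # documents[0] and documents[1] share 3 of 5 distinct tokens -> 3 / 5 = 0.6

Incidentally, the -> Tuple[int, int] hint on get_duplicate_pairs describes a single pair; given the plural name and the all_pairs call, the method presumably returns an iterable of such pairs.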
@@ -142,7 +143,7 @@ class DuplicateDetector:
         return duplicate_clusters
 
-    def print_duplicate_clusters_stats(self, duplicate_clusters):
+    def print_duplicate_clusters_stats(self, duplicate_clusters: List[Set[DocumentID]]):
         total_num_files = len(self.document_ids)
         num_cloned_files = sum(len(c) for c in duplicate_clusters)
 
         print(f"duplicated files: {num_cloned_files / total_num_files * 100}%")
@@ -154,7 +155,9 @@ class DuplicateDetector:
             f"duplication ratio: {((num_cloned_files - len(duplicate_clusters)) / total_num_files * 100)}"
         )
 
-    def get_documents_to_exclude(self, duplicate_clusters):
+    def get_documents_to_exclude(
+        self, duplicate_clusters: List[Set[DocumentID]]
+    ) -> Set[DocumentID]:
         """A set of DocumentIDs to exclude because they are duplicates."""
 
         # remove one document from each duplicate cluster to keep
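
The diff cuts off here, but the docstring and comment suggest get_documents_to_exclude keeps one member of each cluster and returns the rest for removal. A minimal sketch of that logic, written as a free function; the body below is an assumption, not the committed code:

    from typing import List, Set

    def get_documents_to_exclude(
        duplicate_clusters: List[Set[DocumentID]],
    ) -> Set[DocumentID]:
        documents_to_exclude: Set[DocumentID] = set()
        for cluster in duplicate_clusters:
            kept, *rest = cluster  # keep an arbitrary representative
            documents_to_exclude.update(rest)
        return documents_to_exclude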