From f88e26b452ed8799d30c87c36133902ac447d5aa Mon Sep 17 00:00:00 2001 From: louistiti Date: Sat, 16 Nov 2024 21:13:37 +0800 Subject: [PATCH] feat(python tcp server): improve city entity accuracy --- tcp_server/src/lib/nlp.py | 47 ++++++++++++++++++++------------------- 1 file changed, 24 insertions(+), 23 deletions(-) diff --git a/tcp_server/src/lib/nlp.py b/tcp_server/src/lib/nlp.py index 66bc5bee..4db61688 100644 --- a/tcp_server/src/lib/nlp.py +++ b/tcp_server/src/lib/nlp.py @@ -27,7 +27,7 @@ spacy_model_mapping = { } } -geonamescache = GeonamesCache() +geonamescache = GeonamesCache(min_city_population=5000) countries = geonamescache.get_countries() cities = geonamescache.get_cities() @@ -85,30 +85,31 @@ def extract_spacy_entities(utterance: str) -> list[dict]: delete_unneeded_country_data(resolution['data']) break - city_population = 0 - for city in cities: - alternatenames = [name.casefold() for name in cities[city]['alternatenames']] - if cities[city]['name'].casefold() == ent.text.casefold() or ent.text.casefold() in alternatenames: - if city_population == 0: - entity += ':city' + if ':country' not in entity: + city_population = 0 + for city in cities: + alternatenames = [name.casefold() for name in cities[city]['alternatenames']] + if cities[city]['name'].casefold() == ent.text.casefold() or ent.text.casefold() in alternatenames: + if city_population == 0: + entity += ':city' - if cities[city]['population'] > city_population: - resolution['data'] = copy.deepcopy(cities[city]) - city_population = cities[city]['population'] + if cities[city]['population'] > city_population: + resolution['data'] = copy.deepcopy(cities[city]) + city_population = cities[city]['population'] - for country in countries: - if countries[country]['iso'] == cities[city]['countrycode']: - resolution['data']['country'] = copy.deepcopy(countries[country]) - break - try: - del resolution['data']['geonameid'] - del resolution['data']['alternatenames'] - del resolution['data']['admin1code'] - delete_unneeded_country_data(resolution['data']['country']) - except BaseException: - pass - else: - continue + for country in countries: + if countries[country]['iso'] == cities[city]['countrycode']: + resolution['data']['country'] = copy.deepcopy(countries[country]) + break + try: + del resolution['data']['geonameid'] + del resolution['data']['alternatenames'] + del resolution['data']['admin1code'] + delete_unneeded_country_data(resolution['data']['country']) + except BaseException: + pass + else: + continue entities.append({ 'start': ent.start_char,