query_analysis / app.py
chicham
Add a print for entity extraction (#3)
442640c unverified
raw
history blame
9.16 kB
"""Demo gradio app for some text/query augmentation."""
from __future__ import annotations
import functools
from collections import defaultdict
from itertools import chain
from typing import Any
from typing import Callable
from typing import Mapping
from typing import Sequence
import attr
import environ
import fasttext # not working with python3.9
import gradio as gr
from tokenizers.pre_tokenizers import Whitespace
from transformers.pipelines import pipeline
from transformers.pipelines.base import Pipeline
from transformers.pipelines.token_classification import AggregationStrategy
def compose(*functions) -> Callable:
    """
    Chain single-argument functions into one left-to-right pipeline.

    Args:
        functions: callables applied in the order they are given; each one
            receives the return value of the previous one.

    Returns:
        A callable that feeds its argument through every function in turn.
        When no function is given, the callable returns its argument as is.
    """

    def composed(value):
        # The first function listed runs first, the last one produces the
        # final result.
        for step in functions:
            value = step(value)
        return value

    return composed
def mapped(fn) -> Callable:
"""
Decorator to apply map/filter to a function
"""
def inner(func):
partial_fn = functools.partial(fn, func)
@functools.wraps(func)
def wrapper(*args, **kwargs):
return partial_fn(*args, **kwargs)
return wrapper
return inner
@attr.frozen
class Prediction:
    """Frozen attrs class holding one prediction: a label and its score."""

    # Label returned by the model (``predict`` strips the "__label__" prefix
    # from it for language identification results).
    label: str
    # Score the model attached to ``label``.
    score: float
@attr.frozen
class Models:
    """Bundle of the predictors consumed by ``predict``."""

    # Language identification; called by ``predict`` as ``(query, k=176)``.
    identification: Predictor
    # Translation; called by ``predict`` as ``(query, source_language)``.
    translation: Predictor
    # Query classification; called by ``predict`` as ``(query, categories)``.
    classification: Predictor
    # Entity extraction applied by ``predict`` to the original query.
    ner: Predictor
    # Entity extraction applied by ``predict`` to the translated query.
    recipe: Predictor
@attr.frozen
class Predictor:
    """Pair a model loader with the function used to query the loaded model.

    The model is loaded once, when the instance is created, and every call to
    the instance is forwarded to ``predict_fn`` together with that model.
    """

    # Zero-argument callable returning the model object.
    load_fn: Callable
    # Called as ``predict_fn(model, *args, **kwds)``; by default the model is
    # simply called with the query.
    predict_fn: Callable = attr.field(default=lambda model, query: model(query))
    # Not an ``__init__`` argument: filled in by ``__attrs_post_init__``.
    model: Any = attr.field(init=False)

    def __attrs_post_init__(self) -> None:
        """Load the model right after attrs has initialised the instance."""
        # The class is frozen, so the attribute is set through
        # ``object.__setattr__`` instead of a plain assignment.
        object.__setattr__(self, "model", self.load_fn())

    def __call__(self, *args: Any, **kwds: Any) -> Any:
        """Forward the arguments to ``predict_fn``, prepending the model."""
        return self.predict_fn(self.model, *args, **kwds)
@environ.config(prefix="QUERY_INTERPRETATION")
class AppConfig:
    """Application configuration, declared with ``environ`` under the
    ``QUERY_INTERPRETATION`` prefix; every variable has a default value."""

    @environ.config
    class Identification:
        """Identification model configuration."""

        # Path of the language identification model file.
        model = environ.var(default="./models/lid.176.ftz")
        # Number of languages shown in the interface (converted to int).
        max_results = environ.var(default=3, converter=int)

    @environ.config
    class Translation:
        """Translation models configuration."""

        # NOTE(review): ``main`` in this file does not read this value; the
        # translation checkpoints are chosen there — confirm it is still needed.
        model = environ.var(default="t5-small")
        # Comma-separated source language codes.
        sources = environ.var(default="de,fr")
        # Target language code.
        target = environ.var(default="en")

    @environ.config
    class Classification:
        """Classification model configuration."""

        # Model name given to the zero-shot classification pipeline.
        model = environ.var(default="typeform/distilbert-base-uncased-mnli")
        # Number of categories shown in the interface (converted to int).
        max_results = environ.var(default=5, converter=int)

    @environ.config
    class NER:
        """Named entity recognition models configuration."""

        # Model used for the generic entity extraction.
        general = environ.var(
            default="asahi417/tner-xlm-roberta-large-uncased-wnut2017",
        )
        # Model used for the recipe-specific entity extraction.
        recipe = environ.var(default="adamlin/recipe-tag-model")

    # One configuration group per model family.
    identification: Identification = environ.group(Identification)
    translation: Translation = environ.group(Translation)
    classification: Classification = environ.group(Classification)
    ner: NER = environ.group(NER)
def predict(
    models: Models,
    query: str,
    categories: Sequence[str],
    supported_languages: tuple[str, ...] = ("fr", "de"),
) -> tuple[
    Mapping[str, float],
    str,
    Mapping[str, float],
    Sequence[tuple[str, str | None]],
    Sequence[tuple[str, str | None]],
]:
    """Run the whole analysis pipeline on a textual query.

    Steps:
        - identify the language of the query,
        - translate the query when its language is one of
          ``supported_languages``,
        - score the translated query against ``categories``,
        - extract entities from the original query with ``models.ner`` and
          from the translated query with ``models.recipe``.

    Args:
        models: predictors used by each step.
        query: text to analyse.
        categories: candidate categories for the classification step.
        supported_languages: language codes that can be translated.

    Returns:
        A tuple ``(languages, translation, classifications, general_entities,
        recipe_entities)`` where ``languages`` and ``classifications`` map
        labels to scores, ``translation`` is the (possibly untranslated)
        query and both entity results are sequences of ``(text, entity)``
        pairs, ``entity`` being ``None`` for untagged text.
    """
    # Only these labels are reported: the translatable languages plus "en",
    # which needs no translation. Built once instead of once per prediction.
    reported_languages = frozenset((*supported_languages, "en"))

    def predict_lang(query) -> Mapping[str, float]:
        def predict_fn(query) -> Sequence[Prediction]:
            # The identification predictor returns two parallel sequences
            # (labels, scores). k=176 presumably asks for every label of the
            # default "lid.176" model — TODO confirm.
            return tuple(
                Prediction(label=label, score=score)
                for label, score in zip(*models.identification(query, k=176))
            )

        @mapped(map)
        def format_label(prediction: Prediction) -> Prediction:
            # Labels come prefixed with "__label__"; keep the bare code only.
            return attr.evolve(
                prediction,
                label=prediction.label.replace("__label__", ""),
            )

        def filter_labels(prediction: Prediction) -> bool:
            return prediction.label in reported_languages

        def format_output(predictions: Sequence[Prediction]) -> dict:
            return {pred.label: pred.score for pred in predictions}

        apply_fn = compose(
            predict_fn,
            format_label,
            functools.partial(filter, filter_labels),
            format_output,
        )
        return apply_fn(query)

    def translate_query(query: str, languages: Mapping[str, float]) -> str:
        # Without any candidate language there is nothing to decide on:
        # max() would raise ValueError, so keep the query untouched.
        if not languages:
            return query
        best_language = max(languages.items(), key=lambda item: item[1])[0]
        if best_language not in supported_languages:
            return query
        return models.translation(query, best_language)[0]["translation_text"]

    def classify_query(query, categories) -> Mapping[str, float]:
        predictions = models.classification(query, categories)
        return dict(zip(predictions["labels"], predictions["scores"]))

    def extract_entities(
        predict_fn: Callable,
        query: str,
    ) -> Sequence[tuple[str, str | None]]:
        # Predictions expose their tag under "entity" or "entity_group".
        # Tags are looked up by word, so a word predicted several times keeps
        # its last tag.
        entity_by_word = {
            pred["word"]: pred.get("entity", pred.get("entity_group", None))
            for pred in predict_fn(query)
        }
        query_processed = Whitespace().pre_tokenize_str(query)
        # Every word is followed by an untagged space so that the pairs can
        # be displayed as running text.
        res = tuple(
            chain.from_iterable(
                ((word, entity_by_word.get(word)), (" ", None))
                for word, _ in query_processed
            ),
        )
        # Trace of the extracted entities on standard output.
        print(res)
        return res

    languages = predict_lang(query)
    translation = translate_query(query, languages)
    classifications = classify_query(translation, categories)
    general_entities = extract_entities(models.ner, query)
    recipe_entities = extract_entities(models.recipe, translation)
    return languages, translation, classifications, general_entities, recipe_entities
def main():
    """Load the models, build the Gradio interface and launch it."""
    cfg: AppConfig = AppConfig.from_environ()

    def extract_commas_separated_values(value: str) -> tuple[str, ...]:
        """Split on commas, trimming whitespace and dropping empty items."""
        stripped = (item.strip() for item in value.split(","))
        return tuple(item for item in stripped if item)

    def load_translation_models(
        sources: Sequence[str],
        target: str,
    ) -> Mapping[str, Pipeline]:
        """Load one translation pipeline per source language.

        The checkpoint name is derived from the language pair so that each
        source language is always paired with its own model, whatever the
        order or number of configured sources. With the default configuration
        this loads "Helsinki-NLP/opus-mt-de-en" and
        "Helsinki-NLP/opus-mt-fr-en".
        """
        return {
            src: pipeline(
                f"translation_{src}_to_{target}",
                f"Helsinki-NLP/opus-mt-{src}-{target}",
            )
            for src in sources
        }

    # NOTE(review): cfg.translation.model is not used here; the translation
    # checkpoints are derived from the configured language pair instead.
    sources = extract_commas_separated_values(cfg.translation.sources)

    models = Models(
        identification=Predictor(
            load_fn=lambda: fasttext.load_model(cfg.identification.model),
            predict_fn=lambda model, query, k: model.predict(query, k=k),
        ),
        translation=Predictor(
            load_fn=functools.partial(
                load_translation_models,
                sources=sources,
                target=cfg.translation.target,
            ),
            # The loaded "model" is the mapping of pipelines keyed by source
            # language.
            predict_fn=lambda models, query, src: models[src](query),
        ),
        classification=Predictor(
            load_fn=lambda: pipeline(
                "zero-shot-classification",
                model=cfg.classification.model,
            ),
            predict_fn=lambda model, query, categories: model(query, categories),
        ),
        ner=Predictor(
            load_fn=lambda: pipeline(
                "ner",
                model=cfg.ner.general,
                aggregation_strategy=AggregationStrategy.MAX,
            ),
        ),
        recipe=Predictor(
            load_fn=lambda: pipeline("ner", model=cfg.ner.recipe),
        ),
    )

    # NOTE(review): gr.inputs / gr.outputs look like the legacy Gradio
    # component namespaces — confirm the pinned Gradio version provides them.
    iface = gr.Interface(
        fn=lambda query, categories: predict(
            models,
            query.strip(),
            extract_commas_separated_values(categories),
            # Only languages with a loaded translation pipeline may be
            # translated, so predict must use the configured sources.
            supported_languages=sources,
        ),
        examples=[["gateau au chocolat paris"], ["Newyork LA flight"]],
        inputs=[
            gr.inputs.Textbox(label="Query"),
            gr.inputs.Textbox(
                label="categories (commas separated and in english)",
                default="cooking and recipe,traveling,location,information,buy or sell",
            ),
        ],
        outputs=[
            gr.outputs.Label(
                num_top_classes=cfg.identification.max_results,
                type="auto",
                label="Language identification",
            ),
            gr.outputs.Textbox(
                label="English query",
                type="auto",
            ),
            gr.outputs.Label(
                num_top_classes=cfg.classification.max_results,
                type="auto",
                label="Predicted categories",
            ),
            gr.outputs.HighlightedText(label="NER generic"),
            gr.outputs.HighlightedText(label="NER Recipes"),
        ],
        interpretation="default",
    )
    iface.launch(debug=True)
# Start the demo only when this file is executed as a script.
if __name__ == "__main__":
    main()