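"""Gradio demo that translates an input sentence with the base IndicTrans2 model
and with a LoRA variant fine-tuned on conversational data, then shows both
outputs together with their inference times."""
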
import gradio as gr
import time
from config import model_repo_id, src_lang, tgt_lang
from indictrans2 import initialize_model_and_tokenizer, batch_translate
from examples import example_sentences


def load_models():
    """Build the tokenizer, the base IndicTrans2 model, and its LoRA variant, and return them in a dict."""
    model_dict = {}

    print("\tLoading model: %s" % model_repo_id)

    # build the tokenizer, the base model, and the LoRA-adapted model
    en_indic_tokenizer, en_indic_model, en_indic_lora_model = (
        initialize_model_and_tokenizer()
    )

    model_dict["_tokenizer"] = en_indic_tokenizer
    model_dict["_model"] = en_indic_model
    model_dict["_lora_model"] = en_indic_lora_model

    return model_dict


def translation(text):
    """Translate `text` with the base model and the LoRA model, returning both outputs and timings."""
    start_time = time.time()

    tokenizer = model_dict["_tokenizer"]
    model = model_dict["_model"]
    lora_model = model_dict["_lora_model"]

    # translation with the base IndicTrans2 model
    org_translation = batch_translate(
        [text],
        model=model,
        tokenizer=tokenizer,
    )
    org_output = org_translation[0]
    end_time = time.time()

    # translation with the LoRA fine-tuned model
    lora_translation = batch_translate(
        [text],
        model=lora_model,
        tokenizer=tokenizer,
    )
    lora_output = lora_translation[0]
    end_time2 = time.time()

    result = {
        "source": src_lang,
        "target": tgt_lang,
        "input": text,
        "it2_result": org_output,  # base model translation
        "it2_conv_result": lora_output,  # LoRA model translation
        "it2_inference_time": end_time - start_time,  # seconds
        "it2_conv_inference_time": end_time2 - end_time,  # seconds
    }

    return result
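
# Example of the returned payload (placeholder values, purely illustrative):
#   {"source": src_lang, "target": tgt_lang, "input": "...",
#    "it2_result": "<base model translation>",
#    "it2_conv_result": "<LoRA model translation>",
#    "it2_inference_time": <seconds for the base pass>,
#    "it2_conv_inference_time": <seconds for the LoRA pass>}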


print("\tinit models")

# loaded once at import time; translation() reads this as a module-level global
model_dict = load_models()

inputs = gr.Textbox(lines=5, label="Input text")
outputs = gr.JSON(container=True)
submit_btn = gr.Button("Translate", variant="primary")

title = "IndicTrans2 fine-tuned on conversation"
description = (
    "Note: the LoRA adapter is trained only on the En-Hi pair.\n"
    "Details: https://github.com/AI4Bharat/IndicTrans2\n"
    "LoRA Model: https://huggingface.co/sam749/IndicTrans2-Conv"
)

gr.Interface(
    fn=translation,
    inputs=inputs,
    outputs=outputs,
    title=title,
    description=description,
    submit_btn=submit_btn,
    examples=example_sentences,
    examples_per_page=10,
    cache_examples=False,
).launch(share=True)
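
# share=True additionally exposes the demo through a temporary public Gradio link
# alongside the local URL.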