Dddixyy commited on
Commit
e7a45f0
·
verified ·
1 Parent(s): 3fa5eb5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -15
app.py CHANGED
@@ -1,29 +1,18 @@
1
  import gradio as gr
2
  import torch
3
  from transformers import MarianMTModel, MarianTokenizer
4
- from optimum.intel import IncQuantizer
5
 
6
- # Load and optimize the model (quantization)
7
  model_name = "Dddixyy/latin-italian-translator"
8
-
9
- # Load the quantized model if available or use a regular model (quantization shown as an example)
10
- try:
11
- # Attempt to load a quantized version if it's available
12
- quantizer = IncQuantizer.from_pretrained(model_name)
13
- model = quantizer.quantize()
14
- print("Quantized model loaded.")
15
- except Exception as e:
16
- print(f"Error loading quantized model: {e}")
17
- model = MarianMTModel.from_pretrained(model_name)
18
-
19
- # Load tokenizer
20
  tokenizer = MarianTokenizer.from_pretrained(model_name)
 
21
 
22
  # Translation function
23
  def translate_latin_to_italian(latin_text):
24
- # Truncate input to 512 tokens to avoid overload (adjust as necessary)
25
  inputs = tokenizer(latin_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
26
 
 
27
  with torch.no_grad():
28
  generated_ids = model.generate(inputs["input_ids"])
29
 
 
1
  import gradio as gr
2
  import torch
3
  from transformers import MarianMTModel, MarianTokenizer
 
4
 
5
+ # Load the MarianMT model and tokenizer
6
  model_name = "Dddixyy/latin-italian-translator"
 
 
 
 
 
 
 
 
 
 
 
 
7
  tokenizer = MarianTokenizer.from_pretrained(model_name)
8
+ model = MarianMTModel.from_pretrained(model_name)
9
 
10
  # Translation function
11
  def translate_latin_to_italian(latin_text):
12
+ # Truncate input to a maximum length of 512 tokens to avoid overload
13
  inputs = tokenizer(latin_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
14
 
15
+ # Use torch.no_grad() to speed up inference by not calculating gradients
16
  with torch.no_grad():
17
  generated_ids = model.generate(inputs["input_ids"])
18