cody82 committed on
Commit f49a41a · verified · 1 Parent(s): 12dd231

Update app.py

Files changed (1)
  1. app.py +4 -4
app.py CHANGED
@@ -2,7 +2,7 @@ import torch
 import gradio as gr
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 
-model_id = "google/flan-t5-base"
+model_id = "google/flan-t5-base"  # could also try flan-t5-large
 tokenizer = AutoTokenizer.from_pretrained(model_id)
 model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
 device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -12,19 +12,19 @@ def respond(message, history=None):
     if history is None:
         history = []
 
-    prompt = f"Question: {message} Answer:"
+    prompt = f"Answer the following question about Innopolis University clearly and concisely.\nQuestion: {message}\nAnswer:"
     inputs = tokenizer(prompt, return_tensors="pt").to(device)
 
     with torch.no_grad():
         outputs = model.generate(
             **inputs,
-            max_new_tokens=50,
+            max_new_tokens=100,
             do_sample=False,
             eos_token_id=tokenizer.eos_token_id
         )
     answer = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
-    # Strip the prompt prefix from the answer if it is left over
+    # Trim it if the model included the prompt in the answer
     if answer.lower().startswith(prompt.lower()):
         answer = answer[len(prompt):].strip()
 
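
For reference, a minimal sketch of app.py as it stands after this commit. The model setup and the body of respond() are taken from the hunks above; the model.to(device) call, the return statement, and the gr.ChatInterface launch wiring are assumptions, since the diff context does not show those lines:

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

model_id = "google/flan-t5-base"  # could also try flan-t5-large
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForSeq2SeqLM.from_pretrained(model_id)
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)  # assumed: not visible in the diff context

def respond(message, history=None):
    if history is None:
        history = []

    prompt = f"Answer the following question about Innopolis University clearly and concisely.\nQuestion: {message}\nAnswer:"
    inputs = tokenizer(prompt, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=100,  # raised from 50 in this commit
            do_sample=False,     # greedy decoding
            eos_token_id=tokenizer.eos_token_id
        )
    answer = tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Trim it if the model included the prompt in the answer
    if answer.lower().startswith(prompt.lower()):
        answer = answer[len(prompt):].strip()

    return answer  # assumed: the hunk ends before the function returns

# Assumed wiring: any standard Gradio chat entry point would work here
demo = gr.ChatInterface(respond)
demo.launch()

Since do_sample=False keeps decoding greedy, raising max_new_tokens from 50 to 100 only affects answers that were previously being truncated before the model emitted its EOS token.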