rashedalhuniti committed
Commit 9e51e1d · verified · 1 Parent(s): 3e8a64e

Update app.py

Files changed (1): app.py (+5 -13)
app.py CHANGED
@@ -2,20 +2,12 @@ import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
 import torch
 
-# Load model and tokenizer
-model_name = "inceptionai/jais-13b"
+# Load JAIS-13B model and tokenizer
+model_name = "InceptionAI/jais-13b"
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
 
-# Define chatbot function
-def chat_with_jais(prompt):
+def generate_response(prompt, max_length=512, temperature=0.7):
     inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
-    outputs = model.generate(**inputs, max_length=512)
-    return tokenizer.decode(outputs[0], skip_special_tokens=True)
-
-# Gradio Interface
-interface = gr.Interface(fn=chat_with_jais, inputs="text", outputs="text", title="JAIS-13B Chatbot")
-
-# Launch the app
-if __name__ == "__main__":
-    interface.launch()
+    with torch.no_grad():
+        output = model
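The hunk is cut off after "output = model". A minimal sketch of how the truncated body might be completed, inferred only from the parameters the new signature declares; the model.generate() call, the do_sample=True flag, and the decoding step are assumptions, not shown in this commit:

# Hypothetical completion of generate_response; assumes the tokenizer and
# model loaded earlier in app.py are in scope.
def generate_response(prompt, max_length=512, temperature=0.7):
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    with torch.no_grad():  # inference only, no gradient tracking needed
        output = model.generate(
            **inputs,
            max_length=max_length,    # cap on total sequence length
            temperature=temperature,  # softens the sampling distribution
            do_sample=True,           # assumption: temperature has no effect without sampling
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)

Note that within the shown hunk this commit also removes the gr.Interface wiring and the launch block, so nothing visible in the new version calls generate_response.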