SivaResearch committed on
Commit 0d58b13 · verified · 1 Parent(s): c4807ae

Update app.py

Files changed (1)
  1. app.py +91 -58
app.py CHANGED
@@ -1,65 +1,98 @@
 import gradio as gr
-import torch
 from transformers import AutoTokenizer, AutoModelForCausalLM

-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
-    formatted_text = ""
-    for message in messages:
-        if message["role"] == "system":
-            formatted_text += "<|system|>\n" + message["content"] + "\n"
-        elif message["role"] == "user":
-            formatted_text += "<|user|>\n" + message["content"] + "\n"
-        elif message["role"] == "assistant":
-            formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
-        else:
-            raise ValueError(
-                "Tulu chat template only supports 'system', 'user', and 'assistant' roles. Invalid role: {}.".format(
-                    message["role"]
-                )
-            )
-    formatted_text += "<|assistant|>\n"
-    formatted_text = bos + formatted_text if add_bos else formatted_text
-    return formatted_text
-
-def inference(input_prompts, model, tokenizer):
-    input_prompts = [
-        create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
-        for input_prompt in input_prompts
-    ]
-
-    encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
-    encodings = encodings.to(device)
-
-    with torch.no_grad():
-        outputs = model.generate(encodings.input_ids, do_sample=False, max_length=250)
-
-    output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
-
-    input_prompts = [
-        tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
-    ]
-    output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
-    return output_texts
-
-model_name = "ai4bharat/Airavata"
-tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
-tokenizer.pad_token = tokenizer.eos_token
-model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
-examples = [
-    ["मुझे अपने करियर के बारे में सुझाव दो", "मैं कैसे अध्ययन कर सकता हूँ?"],
-    ["कृपया मुझे एक कहानी सुनाएं", "ताजमहल के बारे में कुछ बताएं"],
-    ["मेरा नाम क्या है?", "आपका पसंदीदा फिल्म कौन सी है?"],
-]
-
-iface = gr.Chat(
-    model_fn=lambda input_prompts: inference(input_prompts, model, tokenizer),
-    inputs=["text"],
+tokenizer = AutoTokenizer.from_pretrained("ai4bharat/Airavata")
+model = AutoModelForCausalLM.from_pretrained("ai4bharat/Airavata")
+
+def generate_response(prompt):
+    input_ids = tokenizer.encode(prompt, return_tensors="pt", max_length=50)
+    output_ids = model.generate(input_ids, max_length=100, num_beams=5, no_repeat_ngram_size=2)
+    response = tokenizer.decode(output_ids[0], skip_special_tokens=True)
+    return response
+
+iface = gr.Interface(
+    fn=generate_response,
+    inputs="text",
     outputs="text",
-    examples=examples,
-    title="Airavata Chatbot",
-    theme="light",  # Optional: Set a light theme
+    live=True,
+    title="Airavata LLMs Chatbot",
+    description="Ask me anything, and I'll generate a response!",
+    theme="light",
 )

 iface.launch()
+
+
+
+
+
+
+
+
+
+
+# import gradio as gr
+# import torch
+# from transformers import AutoTokenizer, AutoModelForCausalLM
+
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# def create_prompt_with_chat_format(messages, bos="<s>", eos="</s>", add_bos=True):
+#     formatted_text = ""
+#     for message in messages:
+#         if message["role"] == "system":
+#             formatted_text += "<|system|>\n" + message["content"] + "\n"
+#         elif message["role"] == "user":
+#             formatted_text += "<|user|>\n" + message["content"] + "\n"
+#         elif message["role"] == "assistant":
+#             formatted_text += "<|assistant|>\n" + message["content"].strip() + eos + "\n"
+#         else:
+#             raise ValueError(
+#                 "Tulu chat template only supports 'system', 'user', and 'assistant' roles. Invalid role: {}.".format(
+#                     message["role"]
+#                 )
+#             )
+#     formatted_text += "<|assistant|>\n"
+#     formatted_text = bos + formatted_text if add_bos else formatted_text
+#     return formatted_text
+
+# def inference(input_prompts, model, tokenizer):
+#     input_prompts = [
+#         create_prompt_with_chat_format([{"role": "user", "content": input_prompt}], add_bos=False)
+#         for input_prompt in input_prompts
+#     ]
+
+#     encodings = tokenizer(input_prompts, padding=True, return_tensors="pt")
+#     encodings = encodings.to(device)
+
+#     with torch.no_grad():
+#         outputs = model.generate(encodings.input_ids, do_sample=False, max_length=250)
+
+#     output_texts = tokenizer.batch_decode(outputs.detach(), skip_special_tokens=True)
+
+#     input_prompts = [
+#         tokenizer.decode(tokenizer.encode(input_prompt), skip_special_tokens=True) for input_prompt in input_prompts
+#     ]
+#     output_texts = [output_text[len(input_prompt) :] for input_prompt, output_text in zip(input_prompts, output_texts)]
+#     return output_texts
+
+# model_name = "ai4bharat/Airavata"
+# tokenizer = AutoTokenizer.from_pretrained(model_name, padding_side="left")
+# tokenizer.pad_token = tokenizer.eos_token
+# model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
+# examples = [
+#     ["मुझे अपने करियर के बारे में सुझाव दो", "मैं कैसे अध्ययन कर सकता हूँ?"],
+#     ["कृपया मुझे एक कहानी सुनाएं", "ताजमहल के बारे में कुछ बताएं"],
+#     ["मेरा नाम क्या है?", "आपका पसंदीदा फिल्म कौन सी है?"],
+# ]
+
+# iface = gr.Chat(
+#     model_fn=lambda input_prompts: inference(input_prompts, model, tokenizer),
+#     inputs=["text"],
+#     outputs="text",
+#     examples=examples,
+#     title="Airavata Chatbot",
+#     theme="light",  # Optional: Set a light theme
+# )
+
+# iface.launch()
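
For reference, the removed inference() path wrapped every user message in the Tulu-style chat template that Airavata was instruction-tuned on, while the new generate_response() feeds the raw prompt straight to the tokenizer. A minimal sketch of the difference (the sample prompt is illustrative, not from the commit):

# Sketch only: what the old create_prompt_with_chat_format() produced for one
# user turn with add_bos=False, versus what the new code now sends to the model.
user_prompt = "ताजमहल के बारे में कुछ बताएं"
old_model_input = "<|user|>\n" + user_prompt + "\n" + "<|assistant|>\n"  # Tulu chat format
new_model_input = user_prompt  # generate_response() tokenizes the prompt as-is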
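A quick way to smoke-test the updated pipeline before launching the Gradio interface is to call the new function directly; the sample prompt below is borrowed from the old examples list, and the snippet itself is not part of the commit:

# Hypothetical local check: exercise generate_response() once from a Python shell.
sample = "मुझे अपने करियर के बारे में सुझाव दो"
print(generate_response(sample))  # beam search (num_beams=5); max_length=100 counts prompt tokens too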