SivaResearch committed
Commit 5808915 · verified · 1 Parent(s): fa7bdde

Create app.py

Files changed (1)
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
+ import torch
+ from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ import gradio as gr
+
+
+ model = "ai4bharat/Airavata"
+
+ tokenizer = AutoTokenizer.from_pretrained(model, padding_side="left")
+ # tokenizer.pad_token = tokenizer.eos_token
+ # model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16).to(device)
+
+ llama_pipeline = pipeline(
+     "text-generation",
+     model=model,
+     torch_dtype=torch.float16,
+     device_map="auto",
+ )
+ SYSTEM_PROMPT = """<s>[INST] <<SYS>>
+ You are a helpful bot. Your answers are clear and concise.
+ <</SYS>>
+
+ """
+
+ # Formatting function for message and history
+ def format_message(message: str, history: list, memory_limit: int = 3) -> str:
+     """
+     Formats the message and history for the Llama model.
+
+     Parameters:
+         message (str): Current message to send.
+         history (list): Past conversation history.
+         memory_limit (int): Limit on how many past interactions to consider.
+
+     Returns:
+         str: Formatted message string
+     """
+     # always keep len(history) <= memory_limit
+     if len(history) > memory_limit:
+         history = history[-memory_limit:]
+
+     if len(history) == 0:
+         return SYSTEM_PROMPT + f"{message} [/INST]"
+
+     formatted_message = SYSTEM_PROMPT + f"{history[0][0]} [/INST] {history[0][1]} </s>"
+
+     # Handle conversation history
+     for user_msg, model_answer in history[1:]:
+         formatted_message += f"<s>[INST] {user_msg} [/INST] {model_answer} </s>"
+
+     # Handle the current message
+     formatted_message += f"<s>[INST] {message} [/INST]"
+
+     return formatted_message
+
+ # Generate a response from the Llama model
+ def get_llama_response(message: str, history: list) -> str:
+     """
+     Generates a conversational response from the Llama model.
+
+     Parameters:
+         message (str): User's input message.
+         history (list): Past conversation history.
+
+     Returns:
+         str: Generated response from the Llama model.
+     """
+     query = format_message(message, history)
+     response = ""
+
+     sequences = llama_pipeline(
+         query,
+         do_sample=True,
+         top_k=10,
+         num_return_sequences=1,
+         eos_token_id=tokenizer.eos_token_id,
+         max_length=1024,
+     )
+
+     generated_text = sequences[0]['generated_text']
+     response = generated_text[len(query):]  # Remove the prompt from the output
+
+     print("Chatbot:", response.strip())
+     return response.strip()
+
+
+
+ gr.ChatInterface(get_llama_response).launch()
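
For reference, a minimal standalone sketch of the Llama-2-style prompt string that format_message builds for a one-turn history; the history pair and message below are made up for illustration, and no model loading is needed:

# Mirrors the formatting logic in app.py; history values are hypothetical.
SYSTEM_PROMPT = """<s>[INST] <<SYS>>
You are a helpful bot. Your answers are clear and concise.
<</SYS>>

"""

history = [("Hello", "Hi there! How can I help?")]  # hypothetical past turn
message = "Tell me about Airavata."                 # hypothetical current message

prompt = SYSTEM_PROMPT + f"{history[0][0]} [/INST] {history[0][1]} </s>"
prompt += f"<s>[INST] {message} [/INST]"
print(prompt)

Launching the app itself likely requires accelerate alongside torch, transformers, and gradio, since device_map="auto" relies on it.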