mojad121 commited on
Commit
5a16de6
·
verified ·
1 Parent(s): 24fd080

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -161
app.py CHANGED
@@ -1,175 +1,62 @@
 
1
  import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline, Trainer, TrainingArguments
3
- from datasets import load_dataset, Dataset
4
- import torch
5
- import pandas as pd
6
- from transformers import DataCollatorForLanguageModeling
7
- from sklearn.model_selection import train_test_split
8
 
9
- MODEL_NAME = "microsoft/DialoGPT-medium"
10
  DATASET_NAME = "embedding-data/Amazon-QA"
11
- FINETUNED_MODEL_NAME = "MujtabaShopifyChatbot"
12
 
13
- chatbot_pipe = None
14
-
15
- def show_dataset_head(dataset, num_rows=5):
16
- print("Displaying dataset preview ", dataset)
17
- if isinstance(dataset, dict):
18
- for split in dataset.keys():
19
- print("Current split ", split)
20
- df = pd.DataFrame(dataset[split][:num_rows])
21
- cols = [col for col in ['query', 'pos', 'question', 'answer'] if col in df.columns]
22
- if cols:
23
- print("Dataset columns ", cols)
24
-
25
- def load_and_preprocess_data():
26
- print("Loading dataset from ", DATASET_NAME)
27
- try:
28
- dataset = load_dataset(DATASET_NAME)
29
- show_dataset_head(dataset)
30
-
31
- df = pd.DataFrame(dataset['train'])
32
-
33
- if 'query' in df.columns and 'pos' in df.columns:
34
- df = df.rename(columns={'query': 'question', 'pos': 'answer'})
35
- elif 'question' not in df.columns or 'answer' not in df.columns:
36
- df = df.rename(columns={df.columns[0]: 'question', df.columns[1]: 'answer'})
37
-
38
- df = df[['question', 'answer']].dropna()
39
- df = df[:5000]
40
-
41
- df['answer'] = df['answer'].astype(str).str.replace(r'\[\^|\].*', '', regex=True)
42
-
43
- processed_dataset = Dataset.from_pandas(df)
44
- show_dataset_head(processed_dataset)
45
- return processed_dataset.train_test_split(test_size=0.1)
46
- except Exception as e:
47
- print("Error loading dataset ", e)
48
- raise
49
-
50
- def tokenize_data(dataset):
51
- print("Tokenizing data with model ", MODEL_NAME)
52
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
53
- tokenizer.pad_token = tokenizer.eos_token
54
-
55
- def preprocess_function(examples):
56
- inputs = [f"question: {q} answer: {a}" for q, a in zip(examples["question"], examples["answer"])]
57
-
58
- model_inputs = tokenizer(
59
- inputs,
60
- max_length=128,
61
- truncation=True,
62
- padding='max_length'
63
- )
64
-
65
- model_inputs["labels"] = model_inputs["input_ids"].copy()
66
- return model_inputs
67
-
68
- return dataset.map(preprocess_function, batched=True)
69
-
70
- def fine_tune_model(tokenized_dataset):
71
- print("Starting fine-tuning process")
72
- tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
73
- model = AutoModelForCausalLM.from_pretrained(MODEL_NAME)
74
- tokenizer.pad_token = tokenizer.eos_token
75
-
76
- data_collator = DataCollatorForLanguageModeling(
77
- tokenizer=tokenizer,
78
- mlm=False
79
- )
80
-
81
- training_args = TrainingArguments(
82
- output_dir="./results",
83
- evaluation_strategy="epoch",
84
- learning_rate=5e-5,
85
- per_device_train_batch_size=4,
86
- per_device_eval_batch_size=4,
87
- num_train_epochs=3,
88
- weight_decay=0.01,
89
- save_total_limit=3,
90
- fp16=torch.cuda.is_available(),
91
- push_to_hub=False,
92
- report_to="none",
93
- logging_steps=100,
94
- save_steps=500
95
- )
96
-
97
- trainer = Trainer(
98
- model=model,
99
- args=training_args,
100
- train_dataset=tokenized_dataset["train"],
101
- eval_dataset=tokenized_dataset["test"],
102
- data_collator=data_collator
103
- )
104
-
105
- trainer.train()
106
- print("Training completed, saving model")
107
- model.save_pretrained(FINETUNED_MODEL_NAME)
108
- tokenizer.save_pretrained(FINETUNED_MODEL_NAME)
109
- return model
110
-
111
- def initialize_chatbot():
112
- global chatbot_pipe
113
- print("Initializing chatbot with model ", FINETUNED_MODEL_NAME)
114
- try:
115
- model = AutoModelForCausalLM.from_pretrained(FINETUNED_MODEL_NAME)
116
- tokenizer = AutoTokenizer.from_pretrained(FINETUNED_MODEL_NAME)
117
- tokenizer.pad_token = tokenizer.eos_token
118
-
119
- chatbot_pipe = pipeline(
120
- "text-generation",
121
- model=model,
122
- tokenizer=tokenizer,
123
- device=0 if torch.cuda.is_available() else -1
124
- )
125
- print("Chatbot initialized successfully")
126
- except Exception as e:
127
- print("Error initializing chatbot ", e)
128
- return None
129
- return chatbot_pipe
130
 
131
  def generate_response(message, history):
132
- if chatbot_pipe is None:
133
- print("Chatbot pipeline not initialized")
134
- return "System error: Chatbot not ready"
135
-
 
 
 
 
 
 
 
 
136
  try:
137
- print("Generating response for query ", message)
138
- response = chatbot_pipe(
139
- f"question: {message} answer:",
140
- max_length=128,
141
- do_sample=True,
142
  temperature=0.7,
143
- top_p=0.9
144
- )[0]['generated_text']
145
- final_response = response.split("answer:")[-1].strip()
146
- print("Generated response ", final_response)
147
- return final_response
148
  except Exception as e:
149
- print("Error generating response ", e)
150
- return "Sorry, I encountered an error processing your request"
151
 
152
- def deploy_chatbot():
153
- print("Launching chatbot interface")
154
- demo = gr.ChatInterface(
155
  fn=generate_response,
156
- title="Mujtaba's Shopify Assistant",
157
- description="Ask about products, shipping, or store policies",
158
  examples=[
159
- "Will this work with iPhone 15?",
160
- "What's the return window?",
161
- "Do you ship to Lahore?"
162
- ],
163
- theme="soft",
164
- cache_examples=False
165
  )
166
- return demo
167
-
168
- if __name__ == "__main__":
169
- dataset = load_and_preprocess_data()
170
- tokenized_data = tokenize_data(dataset)
171
-
172
- model = fine_tune_model(tokenized_data)
173
 
174
- initialize_chatbot()
175
- deploy_chatbot().launch()
 
1
+ import os
2
  import gradio as gr
3
+ from groq import Groq
4
+ from datasets import load_dataset
 
 
 
 
5
 
6
+ GROQ_MODEL = "llama3-70b-8192"
7
  DATASET_NAME = "embedding-data/Amazon-QA"
 
8
 
9
+ def load_shopify_context():
10
+ dataset = load_dataset(DATASET_NAME)
11
+ samples = dataset['train'].select(range(3))
12
+ examples = []
13
+ for sample in samples:
14
+ question = sample['query']
15
+ if isinstance(question, list):
16
+ question = question[0] if len(question) > 0 else "No question"
17
+ question = str(question).replace('\\', '/')
18
+ answer = sample.get('pos', sample.get('answer', ["No answer"]))
19
+ if isinstance(answer, list):
20
+ answer = answer[0] if len(answer) > 0 else "No answer"
21
+ answer = str(answer).replace('\\', '/')
22
+ examples.append(f"Q: {question}\nA: {answer}")
23
+ return '\n'.join(examples)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
  def generate_response(message, history):
26
+ api_key = os.getenv("GROQ_API_KEY")
27
+ if not api_key:
28
+ return "Error: GROQ_API_KEY not set. Please add it as a secret in your Space."
29
+ client = Groq(api_key=api_key)
30
+ context = load_shopify_context()
31
+ conversation = []
32
+ for user_msg, bot_msg in history:
33
+ safe_user = str(user_msg).replace('\\', '/')
34
+ safe_bot = str(bot_msg).replace('\\', '/')
35
+ conversation.extend([f"User: {safe_user}", f"Assistant: {safe_bot}"])
36
+ safe_message = str(message).replace('\\', '/')
37
+ prompt = f"You are an expert Shopify support agent. Context examples:\n{context}\n{chr(10).join(conversation)}\nUser: {safe_message}\nAssistant:"
38
  try:
39
+ response = client.chat.completions.create(
40
+ messages=[{"role": "user", "content": prompt}],
41
+ model=GROQ_MODEL,
 
 
42
  temperature=0.7,
43
+ max_tokens=256,
44
+ top_p=0.9,
45
+ stop=["<|endoftext|>"]
46
+ )
47
+ return response.choices[0].message.content
48
  except Exception as e:
49
+ return f"Error: {str(e)}"
 
50
 
51
+ with gr.Blocks() as app:
52
+ gr.Markdown("## Shopify Q&A Assistant (Groq-powered)")
53
+ gr.ChatInterface(
54
  fn=generate_response,
 
 
55
  examples=[
56
+ "What's your return policy?",
57
+ "Do you ship internationally?",
58
+ "Is this compatible with iPhone 15?"
59
+ ]
 
 
60
  )
 
 
 
 
 
 
 
61
 
62
+ app.launch()