Spaces:
Sleeping
Sleeping
import gradio as gr | |
import json | |
from transformers import pipeline | |
# Load the fine-tuned GPT-Tour model once at module level so every request
# handled by getTour reuses the same pipeline object. The task is named
# explicitly (getTour relies on text-generation options such as
# return_full_text); otherwise pipeline() has to infer it from the model's
# Hub metadata.
tourModel = pipeline("text-generation", model="manhan/GPT-Tour")
# Exclusive upper bounds for the income / age categories written into the
# prompt; values at or above the last bound get the "top" label instead.
_INCOME_BINS = ((25000, 'poor'), (50000, 'low'), (75000, 'medium'), (125000, 'high'))
_AGE_BINS = ((18, 'child'), (45, 'younger'), (65, 'older'))


def _parse_int(value, field):
    """Parse a user-entered whole number, tolerating '$', ',' and surrounding spaces.

    Raises:
        ValueError: if *value* is still not a whole number after that clean-up.
    """
    text = str(value).strip().replace(',', '').replace('$', '')
    try:
        return int(text)
    except ValueError:
        raise ValueError(f"{field} must be a whole number, got {value!r}") from None


def _bin(value, bins, top_label):
    """Return the label of the first ``(upper, label)`` pair with ``value < upper``."""
    for upper, label in bins:
        if value < upper:
            return label
    return top_label


def _household_size_label(hh_size):
    """Map a household head-count onto the size category used in the prompt."""
    if hh_size == 1:
        return 'single'
    if hh_size == 2:
        return 'couple'
    if hh_size <= 4:
        return 'small'
    return 'large'


def _extract_activity_list(generated):
    """Pull the first ``[...]`` JSON list out of model output.

    Returns the parsed list when it is non-empty and ends with 'Home'
    (a tour must return home), otherwise None so the caller can retry.
    """
    start_pos = generated.find('[')
    # Search for the closing bracket only after the opening one.
    end_pos = generated.find(']', start_pos) if start_pos != -1 else -1
    if end_pos == -1:
        print("Nothing extracted!")
        return None
    activity_list_str = generated[start_pos:end_pos + 1]
    print(f"Extracted: '{activity_list_str}'")
    try:
        activities = json.loads(activity_list_str)
    except json.JSONDecodeError as e:
        print("Error parsing activity list")
        print(e)
        return None
    # NOTE: a filter rejecting tours whose 'Work' activities disagree with the
    # traveler's worker status was tried and left disabled; add it here if needed.
    if not activities or activities[-1] != 'Home':
        return None
    return activities


def getTour(income, size, years, sex, edu, wrk, max_attempts=10, model=None):
    """Generate a daily activity pattern (tour) for a described traveler.

    Numeric inputs are bucketed into the categories the prompt uses, the
    person record is serialized as an open JSON object followed by
    ``pattern:``, and the model is sampled until it yields a valid list.

    Args:
        income: Annual household income in dollars (str or int; '$' and ','
            are tolerated).
        size: Number of people in the household.
        years: Traveler age in years.
        sex, edu, wrk: Categorical values passed through into the prompt.
        max_attempts: Maximum number of generations to try before giving up.
        model: Text-generation callable; defaults to the module's ``tourModel``.

    Returns:
        A non-empty list of activities whose last element is 'Home'.

    Raises:
        ValueError: if a numeric input is not a whole number, or
            ``max_attempts`` is less than 1.
        RuntimeError: if no valid activity list is produced within
            ``max_attempts`` generations.
    """
    if max_attempts < 1:
        raise ValueError("max_attempts must be at least 1")
    if model is None:
        model = tourModel
    # Key order is significant: it fixes the attribute order in the prompt.
    person = {
        'hh_inc': _bin(_parse_int(income, "income"), _INCOME_BINS, 'affluent'),
        'hh_size': _household_size_label(_parse_int(size, "household size")),
        'age_grp': _bin(_parse_int(years, "age"), _AGE_BINS, 'senior'),
        'sex': sex,
        'edu': edu,
        'wrk': wrk,
    }
    # Drop the closing brace so the model continues the record with its pattern
    # (presumably the format the model was fine-tuned on).
    prompt = json.dumps(person)[:-1] + ", pattern: "
    print(person)
    for _ in range(max_attempts):
        # temperature only has an effect when sampling; without do_sample every
        # retry could return the identical (invalid) text.
        generated = model(prompt, return_full_text=False, max_length=250,
                          temperature=0.9, do_sample=True)[0]['generated_text']
        activity_list = _extract_activity_list(generated)
        if activity_list is not None:
            return activity_list
    raise RuntimeError(f"No valid activity list generated after {max_attempts} attempts")
# Web UI: the six inputs are passed positionally to getTour's parameters and
# the returned activity list is rendered as JSON. The interface is built at
# module level (so tooling can find `iface`) but only served when the file is
# run as a script.
iface = gr.Interface(
    fn=getTour,
    inputs=[
        gr.Textbox(label="Annual Household Income (in dollars)"),
        gr.Textbox(label="Household Size (number of people)"),
        gr.Textbox(label="Traveler Age (years)"),
        gr.Dropdown(["unknown", "male", "female"], label="Gender/sex"),
        gr.Dropdown(
            ["unknown", "grade school", "highschool", "associates", "bachelors", "graduate"],
            label="Educational attainment level",
        ),
        gr.Dropdown(["unknown", "yes", "no"], label="Worker status"),
    ],
    outputs=["json"],
    title="GPT-Travel",
    description="Author: Colby Brown, Manhan ([email protected])",
    allow_flagging='never',
)

if __name__ == "__main__":
    iface.launch()