Spaces:
Running
Running
import gradio as gr | |
import json | |
from transformers import pipeline | |
# Hugging Face Hub id of the fine-tuned tour-generation model.
_MODEL_ID = "manhan/GPT-Tour"
# Build the transformers pipeline once at import time so every request reuses
# the same loaded model (the task is inferred by `pipeline` from the model).
tourModel = pipeline(model=_MODEL_ID)
def _income_group(income):
    """Return the household-income bucket for an annual income in dollars.

    Buckets: under $25,000 -> 'poor', under $50,000 -> 'low',
    under $75,000 -> 'medium', under $125,000 -> 'high', otherwise 'affluent'.
    A leading "$" and thousands separators are accepted ("$52,000").
    Returns None for a negative amount, which is treated as bad data so the
    income attribute is simply left out of the prompt.

    Raises:
        ValueError: if the text is not an integer amount.
    """
    dollars = int(str(income).replace(",", "").replace("$", "").strip())
    if dollars < 0:
        return None  # bad data: omit hh_inc rather than guess
    if dollars < 25_000:
        return 'poor'
    if dollars < 50_000:
        return 'low'
    if dollars < 75_000:
        return 'medium'
    if dollars < 125_000:
        return 'high'
    return 'affluent'


def _size_group(size):
    """Return the household-size bucket for a head count (must be >= 1)."""
    hh_size = int(size)
    if hh_size < 1:
        raise ValueError("Household size must be at least 1")
    if hh_size == 1:
        return 'single'
    if hh_size == 2:
        return 'couple'
    if hh_size <= 4:
        return 'small'
    return 'large'  # more than four people


def _age_group(years):
    """Return the age bucket for an age in whole years (must be >= 0)."""
    age = int(years)
    if age < 0:
        raise ValueError("Age cannot be negative")
    if age < 18:
        return 'child'
    if age < 45:
        return 'younger'
    if age < 65:
        return 'older'
    return 'senior'


def _extract_activity_list(generated, wrk):
    """Pull an acceptable activity list out of one model completion.

    The first bracketed span "[...]" is taken from the completion. It is
    accepted only if it agrees with the worker status (workers must mention
    'Work', non-workers must not), parses as JSON, is non-empty, and ends
    with 'Home'.

    Returns:
        The parsed list, or None if the completion should be retried.
    """
    start_pos = generated.find('[')
    # Search for the closing bracket only after the opening one, so a stray
    # ']' earlier in the text cannot produce an empty/garbled slice.
    end_pos = generated.find(']', start_pos) if start_pos != -1 else -1
    if end_pos == -1:
        print("Nothing extracted!")
        return None
    activity_list_str = generated[start_pos:end_pos + 1]
    print(f"Extracted: '{activity_list_str}'")
    # Substring check on the raw text, so e.g. any 'Work' label counts.
    mentions_work = 'Work' in activity_list_str
    if wrk == 'yes' and not mentions_work:
        return None
    if wrk == 'no' and mentions_work:
        return None
    try:
        activity_list = json.loads(activity_list_str)
    except json.JSONDecodeError as e:
        print("Error parsing activity list")
        print(e)
        return None
    # A tour must be non-empty and end back at home.
    if not activity_list or activity_list[-1] != 'Home':
        return None
    return activity_list


def getTour(income, size, years, sex, edu, wrk, *, max_attempts=25):
    """Generate a daily activity pattern ("tour") for a described traveler.

    The inputs are bucketed into the categorical attributes the model is
    prompted with, serialized as a JSON-like prefix ending in ", pattern: ",
    and completions are sampled until one yields an acceptable activity list
    (see ``_extract_activity_list``).

    Args:
        income: annual household income in dollars (text, e.g. "52000").
        size: number of people in the household (text).
        years: traveler age in years (text).
        sex: passed through to the prompt unchanged.
        edu: passed through to the prompt unchanged.
        wrk: worker status; 'yes'/'no' additionally constrain whether the
            generated tour may contain 'Work'.
        max_attempts: maximum number of model completions to try before
            giving up (keyword-only, so UI callers are unaffected).

    Returns:
        The activity list (ending in 'Home'), or [] if no acceptable tour was
        produced within ``max_attempts`` tries.

    Raises:
        ValueError: if income/size/years are not valid integers, or size < 1,
            or age < 0.
    """
    person = {}  # insertion order defines the attribute order in the prompt
    income_group = _income_group(income)
    if income_group is not None:
        person['hh_inc'] = income_group
    person['hh_size'] = _size_group(size)
    person['age_grp'] = _age_group(years)
    person['sex'] = sex
    person['edu'] = edu
    person['wrk'] = wrk
    # Drop the closing '}' so the model continues the record with a pattern.
    prompt = json.dumps(person)[:-1] + ", pattern: "
    print(person)
    for _ in range(max_attempts):
        # temperature only has an effect when sampling is enabled; without it
        # generation may be greedy and every retry would repeat the same
        # rejected completion. max_length counts prompt + generated tokens.
        generated = tourModel(
            prompt,
            return_full_text=False,
            max_length=250,
            do_sample=True,
            temperature=0.9,
        )[0]['generated_text']
        activity_list = _extract_activity_list(generated, wrk)
        if activity_list is not None:
            return activity_list
    print(f"No acceptable tour generated after {max_attempts} attempts")
    return []
# Web UI: the six inputs are passed positionally to getTour(income, size,
# years, sex, edu, wrk); the generated activity list is shown as JSON.
iface = gr.Interface(
    fn=getTour,
    inputs=[
        gr.inputs.Textbox(label="Annual Household Income (in dollars)"),
        gr.inputs.Textbox(label="Household Size (number of people)"),
        gr.inputs.Textbox(label="Traveler Age (years)"),
        gr.inputs.Dropdown(["unknown", "male", "female"], label="Gender/sex"),
        gr.inputs.Dropdown(["unknown", "grade school", "highschool", "associates", "bachelors", "graduate"], label="Educational attainment level"),
        gr.inputs.Dropdown(["unknown", "yes", "no"], label="Worker status"),
    ],
    outputs="json",
    title="GPT-Tour",
    description="Generate a tour for a person",
    allow_flagging=False,
    allow_screenshot=False,
    allow_embedding=False,
)

if __name__ == "__main__":
    # launch() serves the app until the server is stopped; this is the only
    # app this script should start.
    iface.launch()