---
license: apache-2.0
---
## Introduction

This is a model that generates SQL from natural language (NL2SQL/Text2SQL). It is the most basic of the NL2SQL models we have developed in-house; the more advanced versions will be open-sourced later.

The model is based on the BART architecture: we frame the NL2SQL problem as a Seq2Seq task, similar to machine translation. Its main strengths are a relatively small parameter count combined with high SQL-generation accuracy.

## Usage

The input for the NL2SQL task consists of the user's query text plus the database table schema. The model's input text is currently concatenated in the following format:
```
Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes <sep>
```
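The format above can be assembled with a small helper. This is a sketch; `build_nl2sql_input` is a hypothetical name for illustration, not part of the model's API:

```python
def build_nl2sql_input(question, tables):
    """Concatenate a question and table schemas into the model's input format.

    question: the user's natural-language query.
    tables: dict mapping table name -> list of column names.
    """
    # Tables are separated by " ; ", columns by ", ", as in the format above
    table_str = " ; ".join(
        "{}: {}".format(name, ", ".join(cols)) for name, cols in tables.items()
    )
    return "Question: {} <sep> Tables: {} <sep>".format(question, table_str)

example = build_nl2sql_input(
    "名人堂一共有多少球员",
    {"hall_of_fame": ["player_id", "yearid", "votedby"]},
)
print(example)
# Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby <sep>
```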
See the example below for concrete usage:
```python
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

device = 'cuda'
model_path = 'DMetaSoul/nl2sql-chinese-basic'
sampling = False  # True: nucleus sampling; False: beam search

tokenizer = AutoTokenizer.from_pretrained(model_path, src_lang='zh_CN')
model = AutoModelForSeq2SeqLM.from_pretrained(model_path)
model = model.half()  # fp16 inference on GPU
model.to(device)

input_texts = [
    "Question: 所有章节的名称和描述是什么? <sep> Tables: sections: section id , course id , section name , section description , other details <sep>",
    "Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes ; player_award_vote: award_id, year, league_id, player_id, points_won, points_max, votes_first ; salary: year, team_id, league_id, player_id, salary ; player: player_id, birth_year, birth_month, birth_day, birth_country, birth_state, birth_city, death_year, death_month, death_day, death_country, death_state, death_city, name_first, name_last, name_given, weight <sep>"
]

inputs = tokenizer(input_texts, max_length=512, return_tensors="pt",
                   padding=True, truncation=True)
# The Seq2Seq model does not take token_type_ids
inputs = {k: v.to(device) for k, v in inputs.items() if k != "token_type_ids"}

with torch.no_grad():
    if sampling:
        outputs = model.generate(**inputs, do_sample=True, top_k=50, top_p=0.95,
                                 temperature=1.0, num_return_sequences=1,
                                 max_length=512, return_dict_in_generate=True,
                                 output_scores=True)
    else:
        outputs = model.generate(**inputs, num_beams=4, num_return_sequences=1,
                                 max_length=512, return_dict_in_generate=True,
                                 output_scores=True)

output_ids = outputs.sequences
results = tokenizer.batch_decode(output_ids, skip_special_tokens=True,
                                 clean_up_tokenization_spaces=True)
for question, sql in zip(input_texts, results):
    print(question)
    print('SQL: {}'.format(sql))
    print()
```
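The script above hard-codes `device = 'cuda'` and casts the weights to fp16, which fails on a CPU-only machine. A minimal fallback (our assumption, not from the model card) selects the device dynamically and only uses fp16 on GPU:

```python
import torch

# Pick GPU when available; otherwise fall back to CPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# fp16 matmuls are poorly supported on most CPUs, so only halve weights on GPU
use_half = (device == 'cuda')
print(device, use_half)
```

Apply `model.half()` only when `use_half` is true; the rest of the script is unchanged.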
The output is as follows:
```
Question: 所有章节的名称和描述是什么? <sep> Tables: sections: section id , course id , section name , section description , other details <sep>
SQL: SELECT section name, section description FROM sections

Question: 名人堂一共有多少球员 <sep> Tables: hall_of_fame: player_id, yearid, votedby, ballots, needed, votes, inducted, category, needed_note ; player_award: player_id, award_id, year, league_id, tie, notes ; player_award_vote: award_id, year, league_id, player_id, points_won, points_max, votes_first ; salary: year, team_id, league_id, player_id, salary ; player: player_id, birth_year, birth_month, birth_day, birth_country, birth_state, birth_city, death_year, death_month, death_day, death_country, death_state, death_city, name_first, name_last, name_given, weight <sep>
SQL: SELECT count(*) FROM hall_of_fame
```