【语音识别逻辑修改】

This commit is contained in:
weizhihong 2023-04-10 17:13:16 +08:00
parent 53fd77af5c
commit 621cd8befa
3 changed files with 48 additions and 40 deletions

View File

@ -89,10 +89,10 @@ public class VoiceParseServiceImpl implements VoiceParseService {
} }
// 匹配指令 // 匹配指令
for (VoiceDiscriminateRule rule : ruleList) { for (VoiceDiscriminateRule rule : ruleList) {
List<String> matchGroupList = keyWordsMatch(rule.getKeyWordRules(), result.getMatchOriginContent()); Matcher matcher = keyWordsMatch(rule.getKeyWordRules(), result.getMatchOriginContent());
if (!CollectionUtils.isEmpty(matchGroupList)) { if (matcher != null) {
result.setRule(rule); result.setRule(rule);
result.setMatchGroupList(matchGroupList); result.setMatcher(matcher);
return; return;
} }
} }
@ -107,16 +107,13 @@ public class VoiceParseServiceImpl implements VoiceParseService {
* @param content 语音内容 * @param content 语音内容
* @return groupList * @return groupList
*/ */
private List<String> keyWordsMatch(String patternStr, String content) { private Matcher keyWordsMatch(String patternStr, String content) {
List<String> groupList = new ArrayList<>();
Pattern pattern = Pattern.compile(translateToPinYin(patternStr));// 匹配的模式 Pattern pattern = Pattern.compile(translateToPinYin(patternStr));// 匹配的模式
Matcher matcher = pattern.matcher(content); Matcher matcher = pattern.matcher(content);
if (matcher.find()) { if (matcher.find()) {
for (int index = 1, size = matcher.groupCount(); index <= size; index++) { return matcher;
groupList.add(matcher.group(index));
} }
} return null;
return groupList;
} }
/** /**
@ -126,12 +123,11 @@ public class VoiceParseServiceImpl implements VoiceParseService {
* @param result 结果信息 * @param result 结果信息
*/ */
private void paramExtract(Simulation simulation, VoiceDiscriminateResult result) { private void paramExtract(Simulation simulation, VoiceDiscriminateResult result) {
List<String> groupList = result.getMatchGroupList();
List<ParamExtractRule> paramsRules = result.getRule().getParamsRules(); List<ParamExtractRule> paramsRules = result.getRule().getParamsRules();
List<ParamExtractResult> paramExtractResults = new ArrayList<>(paramsRules.size()); List<ParamExtractResult> paramExtractResults = new ArrayList<>(paramsRules.size());
ParamExtractResult extractResult = null; ParamExtractResult extractResult = null;
String[] groupStrArr = null; String[] groupStrArr = null;
int groupSize = groupList.size(); int groupSize = result.getGroupCount();
for (ParamExtractRule rule : paramsRules) { for (ParamExtractRule rule : paramsRules) {
if (rule.getIndexArr() != null) { // 如果定位信息不为空 if (rule.getIndexArr() != null) { // 如果定位信息不为空
groupStrArr = new String[rule.getIndexArr().length]; groupStrArr = new String[rule.getIndexArr().length];
@ -141,7 +137,7 @@ public class VoiceParseServiceImpl implements VoiceParseService {
result.setMsg("提取参数出错"); result.setMsg("提取参数出错");
return; return;
} }
groupStrArr[index] = groupList.get(rule.getIndexArr()[index]); groupStrArr[index] = result.getGroup(rule.getIndexArr()[index]);
} }
} }
extractResult = new ParamExtractResult(); extractResult = new ParamExtractResult();

View File

@ -28,7 +28,7 @@ public enum ExtractRule {
STATION_NAME_EXTRACT("根据车站名称匹配车站") { STATION_NAME_EXTRACT("根据车站名称匹配车站") {
@Override @Override
public Object matchParam(Simulation simulation, String... sourceStr) { public Object matchParam(Simulation simulation, String[] sourceStr) {
if (sourceStr == null || sourceStr.length == 0) { if (sourceStr == null || sourceStr.length == 0) {
throw new IllegalArgumentException("定位参数不缺失"); throw new IllegalArgumentException("定位参数不缺失");
} }
@ -36,37 +36,22 @@ public enum ExtractRule {
// 这里可能需要做对入参做一些处理 // 这里可能需要做对入参做一些处理
String stationName = sourceStr[0]; String stationName = sourceStr[0];
return findDevice(simulation, MapElement.DeviceType.STATION,Station.class,stationName); return findDevice(simulation, MapElement.DeviceType.STATION,Station.class,stationName);
/* Optional<Station> stationOptional = simulation.getRepository().getStationList().stream()
.filter(station -> Objects.equals(stationName, PinYinUtil.toPinYin(station.getName())))
.findFirst();
if (stationOptional.isPresent()) {
return stationOptional.get();
}
throw new IllegalArgumentException("未找到【" + stationName + "】车站");*/
} }
}, },
ROUTE_NAME_EXTRACT("根据进路名称匹配进路") { ROUTE_NAME_EXTRACT("根据进路名称匹配进路") {
@Override @Override
public Object matchParam(Simulation simulation, String... sourceStr) { public Object matchParam(Simulation simulation, String[] sourceStr) {
if (sourceStr == null || sourceStr.length == 0) { if (sourceStr == null || sourceStr.length == 0) {
throw new IllegalArgumentException("定位参数不缺失"); throw new IllegalArgumentException("定位参数不缺失");
} }
// 这里可能对XS做处理 // 这里可能对XS做处理
String routeName = sourceStr[0]; String routeName = sourceStr[0];
return findDevice(simulation, MapElement.DeviceType.ROUTE,Route.class,routeName); return findDevice(simulation, MapElement.DeviceType.ROUTE,Route.class,routeName);
/* Optional<Route> routeOptional = simulation.getRepository().getRouteList().stream()
.filter(route -> Objects.equals(route.getName(), routeName))
.findFirst();
if (routeOptional.isPresent()) {
return routeOptional.get();
}
throw new IllegalArgumentException("未找到【" + routeName + "】进路");*/
} }
}, },
ROUTE_SIGNAL_EXTRACT("根据起始、终点信号机匹配进路") { ROUTE_SIGNAL_EXTRACT("根据起始、终点信号机匹配进路") {
@Override @Override
public Object matchParam(Simulation simulation, String... sourceStr) { public Object matchParam(Simulation simulation, String[] sourceStr) {
if (sourceStr == null || sourceStr.length < 2) { if (sourceStr == null || sourceStr.length < 2) {
throw new IllegalArgumentException("定位参数缺失"); throw new IllegalArgumentException("定位参数缺失");
} }
@ -85,19 +70,18 @@ public enum ExtractRule {
if (routeOptional.isPresent()) { if (routeOptional.isPresent()) {
return routeOptional.get(); return routeOptional.get();
} }
// throw new IllegalArgumentException("未找到【" + startSignalName + "-" + endSignalName + "】进路");
throw new IllegalArgumentException(String.format("未找到信号机始端[%s-%s],终端[%s-%s]",startSignalName, startSignal.getName(), endSignalName, endSinal.getName())); throw new IllegalArgumentException(String.format("未找到信号机始端[%s-%s],终端[%s-%s]",startSignalName, startSignal.getName(), endSignalName, endSinal.getName()));
} }
}, },
STAND_STATION_UP_DOWN_EXTRACT("根据车站、上下行匹配站台") { STAND_STATION_UP_DOWN_EXTRACT("根据车站、上下行匹配站台") {
@Override @Override
public Object matchParam(Simulation simulation, String... sourceStr) { public Object matchParam(Simulation simulation, String[] sourceStr) {
if (sourceStr == null || sourceStr.length < 2) { if (sourceStr == null || sourceStr.length < 2) {
throw new IllegalArgumentException("定位参数缺失"); throw new IllegalArgumentException("定位参数缺失");
} }
String stationName = sourceStr[0], upDown = sourceStr[1]; String stationName = sourceStr[0], upDown = sourceStr[1];
Station station = (Station) STATION_NAME_EXTRACT.matchParam(simulation, stationName); Station station = (Station) STATION_NAME_EXTRACT.matchParam(simulation, new String[] {stationName});
Boolean right = (Boolean) UP_DOWN_WAY.matchParam(simulation,upDown); Boolean right = (Boolean) UP_DOWN_WAY.matchParam(simulation,new String[]{upDown});
List<Stand> stands = station.getStandOf(right); List<Stand> stands = station.getStandOf(right);
if(CollectionUtils.isEmpty(stands)){ if(CollectionUtils.isEmpty(stands)){
throw new IllegalArgumentException(String.format("不能获取对应的站台 车站[%s-%s],上下行[%s-%s]",stationName,station.getCode(),upDown,right)); throw new IllegalArgumentException(String.format("不能获取对应的站台 车站[%s-%s],上下行[%s-%s]",stationName,station.getCode(),upDown,right));
@ -107,7 +91,7 @@ public enum ExtractRule {
} }
},UP_DOWN_WAY("车辆上下行"){ },UP_DOWN_WAY("车辆上下行"){
@Override @Override
public Object matchParam(Simulation simulation, String... sourceStr) { public Object matchParam(Simulation simulation, String[] sourceStr) {
String way = sourceStr[0]; String way = sourceStr[0];
if(StringUtils.containsIgnoreCase(way,"shang")){ if(StringUtils.containsIgnoreCase(way,"shang")){
return true; return true;
@ -118,7 +102,7 @@ public enum ExtractRule {
} }
},SWITCH_NAME("道岔名称"){ },SWITCH_NAME("道岔名称"){
@Override @Override
public Object matchParam(Simulation simulation, String... sourceStr) { public Object matchParam(Simulation simulation, String[] sourceStr) {
String swtichName = sourceStr[0]; String swtichName = sourceStr[0];
return ExtractRule.findDevice(simulation, MapElement.DeviceType.SWITCH,Switch.class,swtichName); return ExtractRule.findDevice(simulation, MapElement.DeviceType.SWITCH,Switch.class,swtichName);
} }
@ -134,7 +118,7 @@ public enum ExtractRule {
this.description = description; this.description = description;
} }
public abstract Object matchParam(Simulation simulation, String... sourceStr); public abstract Object matchParam(Simulation simulation, String[] sourceStr);
private static MapNamedElement findDevice(Simulation simulation, MapElement.DeviceType dt, Class<? extends MapNamedElement> eleClass,String matchVal){ private static MapNamedElement findDevice(Simulation simulation, MapElement.DeviceType dt, Class<? extends MapNamedElement> eleClass,String matchVal){
List<? extends MapNamedElement> eleList = simulation.getRepository().getListByType(dt,eleClass); List<? extends MapNamedElement> eleList = simulation.getRepository().getListByType(dt,eleClass);

View File

@ -4,6 +4,8 @@ import club.joylink.rtss.simulation.cbtc.member.SimulationMember;
import lombok.Data; import lombok.Data;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.regex.Matcher;
/** /**
* 语音解析结果 * 语音解析结果
@ -33,9 +35,14 @@ public class VoiceDiscriminateResult {
private VoiceDiscriminateRule rule; private VoiceDiscriminateRule rule;
/** /**
* 匹配指令结果 * 原始匹配指令结果
*/ */
private List<String> matchGroupList; private Matcher matcher;
/**
* 纠正过的group map
*/
private Map<Integer, String> correctGroupStr;
/** /**
* 参数提取集合 * 参数提取集合
@ -53,4 +60,25 @@ public class VoiceDiscriminateResult {
public String getMatchOriginContent() { public String getMatchOriginContent() {
return this.originPinYin; return this.originPinYin;
} }
/**
* 获取匹配到group数量
* @return 数量
*/
public int getGroupCount() {
return matcher == null ? 0 : matcher.groupCount();
}
/**
* 根据索引获取group 原始
*
* @param index 位置
* @return group
*/
public String getGroup(int index) {
if (index > getGroupCount()) {
throw new IllegalArgumentException("out of index");
}
return matcher.group(index);
}
} }